From 34f504c877dc9152786493ca2df00d4435d17204 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Thu, 29 Aug 2024 03:10:06 -0700 Subject: [PATCH] [WIP] jax pytree data serialization PiperOrigin-RevId: 668858634 --- .../array_serialization/__init__.py | 5 + .../array_serialization/asyncio_utils.py | 73 ++ .../pytree_serialization.py | 701 ++++++++++++++++++ .../pytree_serialization_utils.py | 365 +++++++++ .../array_serialization/serialization.py | 403 ++-------- .../array_serialization/serialization_test.py | 471 +++++++++--- .../array_serialization/tensorstore_impl.py | 529 +++++++++++++ 7 files changed, 2095 insertions(+), 452 deletions(-) create mode 100644 jax/experimental/array_serialization/asyncio_utils.py create mode 100644 jax/experimental/array_serialization/pytree_serialization.py create mode 100644 jax/experimental/array_serialization/pytree_serialization_utils.py create mode 100644 jax/experimental/array_serialization/tensorstore_impl.py diff --git a/jax/experimental/array_serialization/__init__.py b/jax/experimental/array_serialization/__init__.py index 577c9dcb20e0..d84d9c6896a6 100644 --- a/jax/experimental/array_serialization/__init__.py +++ b/jax/experimental/array_serialization/__init__.py @@ -11,3 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from jax.experimental.array_serialization.serialization import ( + GlobalAsyncCheckpointManager) +from jax.experimental.array_serialization.pytree_serialization import ( + save, load, load_pytree, nonblocking_load, nonblocking_save) diff --git a/jax/experimental/array_serialization/asyncio_utils.py b/jax/experimental/array_serialization/asyncio_utils.py new file mode 100644 index 000000000000..46a687062b5e --- /dev/null +++ b/jax/experimental/array_serialization/asyncio_utils.py @@ -0,0 +1,73 @@ +# Copyright 2021 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import functools +import logging +from concurrent.futures import ThreadPoolExecutor + +logger = logging.getLogger(__name__) + +_PARALLEL_THREAD_POOL_EXECUTOR = ThreadPoolExecutor(max_workers=64) +_ORDERED_THREAD_EXECUTOR = ThreadPoolExecutor(max_workers=1) + +# Lifted from T5X. +class _LimitInFlightBytes: + """Limits in-flight bytes when reading/writing checkpoints per process.""" + + def __init__(self, num_bytes): + self._max_bytes = num_bytes + self._available_bytes = num_bytes + self._cv = asyncio.Condition(lock=asyncio.Lock()) + + async def wait_for_bytes(self, requested_bytes): + if requested_bytes > self._max_bytes: + self._max_bytes = requested_bytes + logger.warning("Requested more bytes than we reserved space for: %d > %d" + ". Increasing the limit to %d.", requested_bytes, + self._max_bytes, self._max_bytes) + async with self._cv: + await self._cv.wait_for(lambda: self._available_bytes > requested_bytes) + self._available_bytes -= requested_bytes + assert self._available_bytes >= 0 + + async def release_bytes(self, requested_bytes): + async with self._cv: + self._available_bytes += requested_bytes + assert self._available_bytes <= self._max_bytes + self._cv.notify_all() + + +def _maybe_run_async_sync(name, async_fn, ordered_execution: bool = False): + """Run async routine synchronously irrespective of the current environment. + + Args: + name: The name of the function. + async_fn: The function to run. + ordered_execution: If True, the function will be run in an ordered sequence + Otherwise, it will be run in a separate thread pool. + Returns: + The result of the function async_fn or raises an exception. + """ + thread_pool_executor = (_ORDERED_THREAD_EXECUTOR if ordered_execution + else _PARALLEL_THREAD_POOL_EXECUTOR) + + def wrapped_fn(*args, **kw): + return thread_pool_executor.submit( + lambda: asyncio.run(async_fn(*args, **kw))).result() + + functools.update_wrapper(wrapper=wrapped_fn, wrapped=async_fn) + wrapped_fn.__name__ = name + wrapped_fn.__qualname__ = name + return wrapped_fn diff --git a/jax/experimental/array_serialization/pytree_serialization.py b/jax/experimental/array_serialization/pytree_serialization.py new file mode 100644 index 000000000000..d1b1b92c1d97 --- /dev/null +++ b/jax/experimental/array_serialization/pytree_serialization.py @@ -0,0 +1,701 @@ +# Copyright 2021 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from os import PathLike +import os +from types import ModuleType +import re +from typing import Any +from uuid import uuid4, UUID +import json +import asyncio +import threading +from concurrent.futures import ThreadPoolExecutor +import shutil +import logging + +import jax +from jax.tree_util import PyTreeDef +from jax.util import safe_zip +from jax._src import distributed +from jax._src.layout import Layout +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import asyncio_utils +from jax.experimental.array_serialization import tensorstore_impl as ts_impl +from jax.experimental.array_serialization.pytree_serialization_utils import ( + MemKVStore, _LEAF_IDS_KEY, _TREE_REPR_KEY, _LEAF_COUNT_KEY, + serialize_pytree, deserialize_pytree, + default_serialization_context, SerializationContext) +from jax.sharding import SingleDeviceSharding +from jax._src.path import epath_installed, Path +import numpy as np + +_ORDERED_THREAD_EXECUTOR = ThreadPoolExecutor(max_workers=1) +_THREADING_SAVE_LOCK = threading.Lock() +_MAX_CONCURRENCY = 64 + +class _MISSING_TYPE: + pass +MISSING = _MISSING_TYPE() + +_REMOTE_URL_PREFIXES = ['gs://', 's3://'] +_PYTREEDEF_FILE = "pytreedef.json" +_TENSORSTORE_SUFFIX = ".tensorstore" +_LEAF_DATA_DIR = "leaf_data" +_OBJ_DATA_ARCHIVE = "obj_data.zip" +_NODE_DATA_ARCHIVE = "node_data.zip" +_TYPE_ID_LEAF_DELIMITER = " -> " +_USE_OCDBT = True # a lot of the code relies on this being True +_MAX_PATH_LENGTH = 4096 +_ARRAY_STORE_DIRNAME = f"array_store{_TENSORSTORE_SUFFIX}" +_ARRAY_TYPE_NAME = "Array" +_ARRAY_TYPE_REGEX = r"Array\[\[([0-9, ]*)\],\s*([a-zA-Z0-9_]+)\]" +_DOT_REPLACEMENT = ":" + +__all__ = ["save", "load", "load_pytree", + "nonblocking_load", "nonblocking_save"] + +PyTreeT = Any +PickleModule = ModuleType + +logger = logging.getLogger(__name__) + +def _get_sync_client() -> distributed.xla_extension.DistributedRuntimeClient: + assert jax.process_count() > 1, ( + "You are attempting to wait for other hosts, but there is only 1 host" + ) + assert distributed.global_state.client is not None, ( + "The distributed runtime is not initialized. You likely need to call" + " `jax.distributed.initialize()` first." + ) + return distributed.global_state.client + +def _get_unique_sync_key() -> str | None: + """Generate a thread-local key for ensuring all host finish (de)serializing""" + if jax.process_count() == 1: + return None + # broadcast a thread-local unique barrier name + sync_key_id = UUID(bytes=np.array(multihost_utils.broadcast_one_to_all( + np.frombuffer(uuid4().bytes, dtype=np.int32))).tobytes()) + sync_key = f"jax_sync_key_{str(sync_key_id)}" + return sync_key + +def _is_str_same_on_all_hosts(path: str | PathLike[str]) -> bool: + """All-gather the location of the checkpoint and check if it's the same.""" + if jax.process_count() <= 1: + return False + path_b = str(path).encode("utf-8") + assert len(path_b) <= _MAX_PATH_LENGTH, ( + f"Path exceeds maximum length of {_MAX_PATH_LENGTH} in multiprocess" + " case.") + path_array = np.concatenate([ + np.frombuffer(path_b, dtype=np.uint8), np.zeros( + _MAX_PATH_LENGTH - len(path_b), dtype=np.uint8)]) + all_path_arrays = multihost_utils.process_allgather(path_array) + return bool(np.all(all_path_arrays == all_path_arrays[:1, ...])) + +def _sync_on_key(key: str | None, extra_tag: str = "") -> None: + if key is None: + return + full_key = key if not extra_tag else f"{key}-{extra_tag}" + multihost_utils.sync_global_devices(full_key) + +def _is_array_like(x): + return isinstance(x, (jax.Array, np.ndarray)) + +def _leaf_to_type_desc(leaf) -> str: + if leaf is None: + return "null" + elif isinstance(leaf, (np.ndarray, jax.Array)): + return (f"{_ARRAY_TYPE_NAME}[[{', '.join(map(str, leaf.shape))}]," + + f" {leaf.dtype.name}]") + else: + return type(leaf).__name__ + +def _leaf_desc_to_leaf(leaf_desc: str) -> str | jax.ShapeDtypeStruct: + leaf_type: str = (leaf_desc.split(_TYPE_ID_LEAF_DELIMITER, 1)[0] + if _TYPE_ID_LEAF_DELIMITER in leaf_desc else leaf_desc) + if not leaf_type.startswith(_ARRAY_TYPE_NAME): + return leaf_type + shape_dtype_match = re.match(_ARRAY_TYPE_REGEX, leaf_type) + assert shape_dtype_match is not None, ( + f"Failed to parse array descriptor: {leaf_type} with pattern:" + f" {_ARRAY_TYPE_REGEX}") + shape_str, dtype_str = shape_dtype_match.groups() + shape = [int(x.strip()) for x in shape_str.strip("]").strip().split(",") + if len(x.strip()) > 0] + dtype = jax.numpy.dtype(dtype_str) + return jax.ShapeDtypeStruct(shape, dtype) + +def _join_leaf_type_and_id(leaf_type: str, leaf_id: str) -> str: + return f"{leaf_type}{_TYPE_ID_LEAF_DELIMITER}{leaf_id}" + +def _inscribe_leaf_types(pytree_repr: dict[str, Any], + leaf_id_type_map: dict[str, str]): + """Rewrite a JSON PyTree representation by adding type to leaf_id.""" + if pytree_repr["node_type"] == "leaf": + leaf_id = pytree_repr["leaf_id"] + if leaf_id is None: + return + pytree_repr["leaf_id"] = _join_leaf_type_and_id(leaf_id_type_map[leaf_id], + leaf_id) + else: + _ = [_inscribe_leaf_types(child, leaf_id_type_map) + for child in pytree_repr["children"]] + +def _inplace_add_types_to_pytree_repr(pytree_repr, leaf_ids_flat, data_flat): + # inscribe types into leaf ids in-place + leaf_ids_type_map = {leaf_id: _leaf_to_type_desc(leaf) for (leaf_id, leaf) + in safe_zip(leaf_ids_flat, data_flat)} + _inscribe_leaf_types(pytree_repr[_TREE_REPR_KEY], leaf_ids_type_map) + pytree_repr[_LEAF_IDS_KEY] = [ + _join_leaf_type_and_id(leaf_ids_type_map[leaf_id], leaf_id) + for leaf_id in leaf_ids_flat] + +def _combine_two_pytrees(dest: dict[str, Any], source: dict[str, Any]): + """Combine two pytrees in JSON format, in-place into dest.""" + assert dest["node_type"] == source["node_type"] + if dest["node_type"] == "leaf": + _, id1 = dest["leaf_id"].split(_TYPE_ID_LEAF_DELIMITER, 1) + type2, id2 = source["leaf_id"].split(_TYPE_ID_LEAF_DELIMITER, 1) + assert id1 == id2 + if type2 != "null": + dest["leaf_id"] = source["leaf_id"] + return dest + else: + dest["children"] = [ + _combine_two_pytrees(child1, child2) for (child1, child2) + in safe_zip(dest["children"], source["children"])] + return dest + + +def _is_remote_path(path: str | PathLike[str]): + # we check whether a path is remote by checking the prefix + # we need to truncate e.g., gs:// to gs:/ because pathlib.Path collapses // + return any(str(path).startswith(prefix[:-1]) + for prefix in _REMOTE_URL_PREFIXES) + +def _rm_dir(root: Path) -> None: + if _is_remote_path(root): + root.rmtree() # pytype: disable=attribute-error + else: + shutil.rmtree(root) + +def _maybe_overwrite_or_error(root: Path, overwrite: bool, partial_write: bool, + pytree_repr: dict[str, Any], + distinct_locations: bool, sync_key: str | None + ) -> dict[str, Any]: + if overwrite: + if root.exists() and len(list(root.iterdir())) > 0: + # check that we're only deleting things that come from JAX + # refuse to rm directories containing additional entries + paths_present = list(root.iterdir()) + extra_member_paths = [path for path in paths_present if path.name not in + (_PYTREEDEF_FILE, _LEAF_DATA_DIR, + _NODE_DATA_ARCHIVE)] + + assert len(extra_member_paths) == 0, ( + "Refusing to work on a directory that is not a previous checkpoint." + f" Unrecognized paths: {extra_member_paths}. Remove them manually if" + f" you're sure you want to use {root} as the checkpoint directory.") + + if partial_write and Path(_PYTREEDEF_FILE) in [path.relative_to(root) + for path in paths_present]: + try: + other_pytree = json.loads((root / _PYTREEDEF_FILE).read_text()) + assert other_pytree[_LEAF_COUNT_KEY] == pytree_repr[_LEAF_COUNT_KEY] + tree_repr = _combine_two_pytrees(other_pytree[_TREE_REPR_KEY], + pytree_repr[_TREE_REPR_KEY]) + new_leaf_ids = [y if not y.startswith("null") else x + for x, y in safe_zip(other_pytree[_LEAF_IDS_KEY], + pytree_repr[_LEAF_IDS_KEY])] + return {_TREE_REPR_KEY: tree_repr, + _LEAF_COUNT_KEY: pytree_repr[_LEAF_COUNT_KEY], + _LEAF_IDS_KEY: new_leaf_ids} + except AssertionError: + logger.warning("The previous pytree does not match, overwritting" + " existing data.") + if (jax.process_index() == 0 or distinct_locations) and root.exists(): + _rm_dir(root) + _sync_on_key(sync_key, "overwrite") + return pytree_repr + else: + if (root.exists() and len(list(root.iterdir())) > 0): # not empty + raise NotImplementedError(f"Files already exist at path: `{root}`," + f" but you specified `{overwrite=}`") + return pytree_repr + + +def _obj_serialize(archive: MemKVStore, filename_id: str | int, x: Any, + ctx: SerializationContext) -> None: + """Serialization method for NOT-array objects.""" + # we're only interested in name and suffix + filename = Path(Path(str(filename_id)).name) + serialization_registry = ctx.leaf_serialization_registry + if _is_array_like(x): + raise ValueError( + "Arrays cannot be serialized using this method for non-arrays.") + # dispatch the serialization method in a thread to yield async control + payload, method_name = serialization_registry.serialize(x) + suffix = "." + method_name.lstrip(".").replace(".", _DOT_REPLACEMENT) + archive.write(filename.with_suffix(suffix), payload) + +def _obj_deserialize(archive: MemKVStore, filename: str, + ctx: SerializationContext, best_effort: bool = False + ) -> Any: + """Deserialization method for NON-array objects.""" + serialization_registry = ctx.leaf_serialization_registry + + path = Path(filename) + payload = archive.read(path) + method_name = str(path.suffix).lstrip(".").replace(_DOT_REPLACEMENT, ".") + try: + return serialization_registry.deserialize(payload, method_name) + except ValueError as exc: + if best_effort: + logging.warning("Unrecognized data type `%s` we'll do our best and just" + " return the raw bytes", method_name) + return payload + else: + raise exc + +async def serialize_array(arr, path, extra_config, distinct_locations: bool) -> None: + arr = jax.numpy.asarray(arr, dtype=arr.dtype) + extra_ts_spec = extra_config + process_num = (jax.process_index() if ( + jax.process_count() > 1 and not distinct_locations) else None) + default_ts_spec = ts_impl.get_tensorstore_spec( + path, ocdbt=_USE_OCDBT, process_num=process_num, arr=arr) + expected_path = default_ts_spec['kvstore']['base']['path'] + ts_spec = ts_impl.merge_nested_specs(default_ts_spec, extra_ts_spec) + ts_impl.verify_tensorstore_spec(ts_spec, arr, expected_path, + check_metadata=True) + # all hosts write because they're writing to different storage locations (to + # be combined later) -> `primary_host=None` + await ts_impl.async_serialize(arr, ts_spec, primary_host=None) + +def finalize_array_store(kvstore_path, extra_config, distinct_locations: bool + ) -> None: + # only in multiprocess case and only process 0 + if distinct_locations or jax.process_count() <= 1 or jax.process_index() != 0: + return + extra_ts_spec = extra_config + dummy_key_path = os.path.join(kvstore_path, "dummy_key") + combined_ts_spec = ts_impl.merge_nested_specs(ts_impl.get_tensorstore_spec( + dummy_key_path, ocdbt=_USE_OCDBT, process_num=None), extra_ts_spec) + children_ts_spec = [ts_impl.merge_nested_specs(ts_impl.get_tensorstore_spec( + dummy_key_path, ocdbt=_USE_OCDBT, process_num=i), extra_ts_spec) + for i in range(jax.process_count())] + combined_kvstore = combined_ts_spec["kvstore"] + children_kvstores = [ts_spec["kvstore"] for ts_spec in children_ts_spec] + _ = combined_kvstore.pop("path") + _ = [kvstore.pop("path") for kvstore in children_kvstores] + asyncio.run(ts_impl.combine_kvstores(combined_kvstore, children_kvstores)) + +async def deserialize_array( + path: str | PathLike[str], sharding: jax.sharding.Sharding | Layout, + ts_spec: dict[str, Any], + byte_limiter: asyncio_utils._LimitInFlightBytes | None = None) -> jax.Array: + # every process reads from the central location + default_ts_spec = ts_impl.get_tensorstore_spec( + path, ocdbt=_USE_OCDBT, process_num=None) + expected_path = default_ts_spec['kvstore']['base']['path'] + ts_spec = ts_impl.merge_nested_specs(default_ts_spec, ts_spec) + ts_impl.verify_tensorstore_spec(ts_spec, arr=None, path=expected_path, + check_metadata=False) + return await ts_impl.async_deserialize(sharding, ts_spec, + byte_limiter=byte_limiter) + +def save(data: PyTreeT, directory: str | PathLike[str], overwrite: bool = True, + partial_write: bool = False, ts_specs: PyTreeT | None = None) -> None: + """Saves the given data structure to the provided directory path. + + This function provides functionality to serialize and save a data structure + comprising JAX arrays, NumPy arrays, Python objects, etc., along with its + structure to a given directory. It leverages `PyTree` for flattening and + reconstructing the data structure. + + Args: + data: The data structure to be saved. Arbitrary composition of JAX arrays, + NumPy arrays, and Python objects, including nested structures. + directory: The directory path where the data will be saved. A local path or + a remote URL (e.g., gs://, s3://). For remote URLs, `etils` is required. + overwrite: If True, any existing directory with the same name will be + overwritten. + Raises: + AssertionError: If attempting to save to a remote path without the `etils` + package installed. + NotImplementedError: If `overwrite` is False and a checkpoint already + exists at the provided directory. + """ + with _THREADING_SAVE_LOCK: + return _save(data, directory, overwrite, partial_write, ts_specs) + +def _save(data: PyTreeT, directory: str | PathLike[str], overwrite: bool = True, + partial_write: bool = False, ts_specs: PyTreeT | None = None) -> None: + sync_key = _get_unique_sync_key() # get a synchronization key for multi-host + + assert not _is_remote_path(directory) or epath_installed, ( + "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`" + " module installed. You can install it using `pip install etils`." + ) + data_flat, pytreedef = jax.tree.flatten(data, is_leaf=lambda x: x is None) + distinct_locations = not _is_str_same_on_all_hosts(directory) + if jax.process_count() > 1 and distinct_locations: + logger.warning("Saving to different locations on different hosts is" + " supported, but extremely fragile. Consider using a single" + " location.") + + # start serialization ################################## + futures, executor = [], ThreadPoolExecutor(max_workers=_MAX_CONCURRENCY) + serialization_ctx = default_serialization_context.copy() + + # 0. serialize the pytree + pytree_repr, leaf_count, node_data_store = serialize_pytree( + pytreedef, ctx=serialization_ctx) + leaf_ids_flat = list(range(leaf_count)) + _inplace_add_types_to_pytree_repr(pytree_repr, leaf_ids_flat, data_flat) + + # overwrite or error + root = Path(directory).resolve() + pytree_repr = _maybe_overwrite_or_error( + root, overwrite, partial_write, pytree_repr, distinct_locations, sync_key) + + if not _is_remote_path(directory): + if distinct_locations or jax.process_index() == 0: + root.mkdir(exist_ok=True) # do not make parents, that's too much + assert root.exists() and root.is_dir() + _sync_on_key(sync_key, "mkdir") + + def _write_pytree(): + if not (jax.process_index() == 0 or distinct_locations): + return + # augment the pytree representation with leaf types + (root / _PYTREEDEF_FILE).write_text(json.dumps(pytree_repr, indent=2)) + + futures.append(executor.submit(_write_pytree)) + + # 1. serialize non-array (objects) in the pytree + def _write_objects(): + if not (jax.process_index() == 0 or distinct_locations): + return + archive_path = root / _LEAF_DATA_DIR / _OBJ_DATA_ARCHIVE + archive_data = None + if overwrite and archive_path.is_file() and archive_path.exists(): + archive_data = archive_path.read_bytes() + obj_archive = MemKVStore(archive_data) + objs_and_paths = [ + (data, leaf_id) for leaf_id, data in safe_zip(leaf_ids_flat, data_flat) + if not _is_array_like(data) and data is not None] + _ = list(executor.map(lambda arg: _obj_serialize(*arg), + [(obj_archive, leaf_id, data, serialization_ctx) + for data, leaf_id in objs_and_paths])) + archive_path.parent.mkdir(exist_ok=True) + archive_path.write_bytes(obj_archive.tobytes()) # always write + + futures.append(executor.submit(_write_objects)) + + # 2. serialize arrays + arrs_and_paths = [(data, root / _LEAF_DATA_DIR / _ARRAY_STORE_DIRNAME / + str(leaf_id)) for leaf_id, data in safe_zip( + leaf_ids_flat, data_flat) if _is_array_like(data)] + ts_specs = (([None] * len(arrs_and_paths)) if ts_specs is None else + jax.tree.leaves(ts_specs, + is_leaf=ts_impl.is_tensorstore_spec_leaf)) + + async def _serialize_arrays(): + await asyncio.gather(*[serialize_array( + arr, path, extra_ts_spec, distinct_locations) + for ((arr, path), extra_ts_spec) in safe_zip(arrs_and_paths, ts_specs)]) + + futures.append(executor.submit(asyncio.run, _serialize_arrays())) + + # 3. serialize node data if permissive ################# + def _write_node_data(): + if not (jax.process_index() == 0 or distinct_locations): + return + node_data_archive = MemKVStore() + for node_data_key, node_data_bytes in node_data_store.items(): + # this KV storage only admits sequential writes + node_data_archive.write(node_data_key, node_data_bytes) + # wait on all node_data serialize futures to finish before writing to disk + archive_path = root / _NODE_DATA_ARCHIVE # always write + archive_path.write_bytes(node_data_archive.tobytes()) + + futures.append(executor.submit(_write_node_data)) + + _ = [fut.result() for fut in futures] + _sync_on_key(sync_key, "serialization") + if len(arrs_and_paths) > 0: + store_path = arrs_and_paths[0][1].parent + finalize_array_store(store_path, ts_specs[0], distinct_locations) + # we are done with all async ops here, we can block + _sync_on_key(sync_key, "end") + +def load(directory: str | PathLike[str], + shardings: PyTreeT | _MISSING_TYPE = MISSING, + pytree: PyTreeT | _MISSING_TYPE = MISSING, + ts_specs: PyTreeT | None = None, best_effort: bool = False) -> PyTreeT: + """Loads and reconstructs a data structure from a directory. + + Args: + directory: Directory path where the data is stored. + shardings: Sharding strategy for array objects. If None, defaults to + single device sharding on the default device. + pytree: Optional pre-populated PyTree for structure. If provided, must + specify a pytree with string object ids. Useful for partial reads. + best_effort: Proceed with deserialization even in the face of partial + failures. Return custom nodes as a list of children. + Returns: + Reconstructed data structure. + Raises: + AssertionError: If attempting to load from a remote path without etils + installed. + ValueError: If data for specific leaf IDs is missing in the directory. + ImportError: If supported node type (e.g., flax's FrozenDict) cannot be + imported. + """ + assert not _is_remote_path(directory) or epath_installed, ( + "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`" + " module installed. You can install it using `pip install etils`.") + + root = Path(directory) + assert root.is_dir(), f"Checkpoint directory {root} does not exist" + if not _is_remote_path(root): + root = root.resolve() + + # deserialize in 3 stages + + # 1. deserialize PyTree (if permissive inserting node_data) + if pytree is MISSING: + pytreedef = load_pytree(directory, best_effort) + # in pytreedef, None leafs indicate StaticNodes (without leaves) + # so we CANNOT flatten with is_leaf=lambda x: x is None + leaf_ids, pytreedef = jax.tree.flatten(pytreedef) + else: + leaf_ids, pytreedef = jax.tree.flatten(pytree, is_leaf=lambda x: x is None) + obj_leaf_ids = [int(leaf_id.split(_TYPE_ID_LEAF_DELIMITER, 1)[1]) for leaf_id + in leaf_ids if leaf_id is not None + and not leaf_id.startswith(_ARRAY_TYPE_NAME) + and not leaf_id.startswith("null")] + serialization_ctx = default_serialization_context.copy() + executor = ThreadPoolExecutor(max_workers=_MAX_CONCURRENCY) + + # 2. deserialize non-array objects + def _deserialize_objs(): + obj_archive = MemKVStore(data=( + root / _LEAF_DATA_DIR / _OBJ_DATA_ARCHIVE).read_bytes()) + _key2id = lambda x: int(Path(x).stem) + obj_keys = list(obj_archive.keys()) + missing_leaf_ids = set(obj_leaf_ids) - set(map(_key2id, obj_keys)) + requested_obj_keys = [obj_key for obj_key in obj_keys + if _key2id(obj_key) in obj_leaf_ids] + if len(missing_leaf_ids) > 0: + raise ValueError( + f"Values {missing_leaf_ids} are missing from the checkpoint directory.") + obj_futs = [executor.submit(_obj_deserialize, obj_archive, obj_key, + ctx=serialization_ctx, best_effort=best_effort) + for obj_key in requested_obj_keys] + obj_values = [fut.result() for fut in obj_futs] + return dict(safe_zip(map(_key2id, requested_obj_keys), obj_values)) + + objs_fut = executor.submit(_deserialize_objs) + + # 3. deserialize array objects + arr_leaf_ids = [leaf_id.split(_TYPE_ID_LEAF_DELIMITER, 1)[1] for leaf_id + in leaf_ids if leaf_id is not None and leaf_id.startswith(_ARRAY_TYPE_NAME)] + arr_paths = [root / _LEAF_DATA_DIR / _ARRAY_STORE_DIRNAME / leaf_id + for leaf_id in arr_leaf_ids] + # missing sharding assumes we want to deserialize on default device + if shardings is MISSING: + device = jax.devices()[0] # default device + shardings = [SingleDeviceSharding(device) for _ in arr_paths] + else: + shardings = jax.tree.flatten(shardings)[0] + assert len(shardings) == len(arr_paths), ( + "The sharding leaves must match the load arrays requested.") + ts_specs = (([None] * len(arr_paths)) if ts_specs is None else + jax.tree.leaves(ts_specs, + is_leaf=ts_impl.is_tensorstore_spec_leaf)) + byte_limiter = asyncio_utils._LimitInFlightBytes(100 * 1024 ** 3) # 100 GB + async def _deserialize_arrays(): + return await asyncio.gather(*[ + deserialize_array(path, sharding, ts_spec, byte_limiter) + for (path, sharding, ts_spec) + in safe_zip(arr_paths, shardings, ts_specs)]) + + arr_keys = [int(path.stem) for path in arr_paths] + arr_values_fut = executor.submit(asyncio.run, _deserialize_arrays()) + + # finally, collect the results + arrs = dict(zip(arr_keys, arr_values_fut.result())) + objs = objs_fut.result() + arr_and_objs = arrs | objs + filled_values = [arr_and_objs.get(leaf_id, None) + for leaf_id in range(len(leaf_ids))] + return jax.tree.unflatten(pytreedef, filled_values) + +def load_pytree(directory: str | PathLike[str], best_effort: bool = False + ) -> PyTreeDef: + """Loads a pytree from the given directory. + Args: + directory: Directory path to load from. + best_effort: Proceed with deserialization even in the face of partial + failures. Return custom nodes as a list of children. + Returns: + The loaded pytree. + """ + assert not _is_remote_path(directory) or epath_installed, ( + "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`" + " module installed. You can install it using `pip install etils`.") + root = Path(directory) + json_content = (root / _PYTREEDEF_FILE).read_text() + raw_tree = json.loads(json_content) + serialization_ctx = default_serialization_context.copy() + executor = ThreadPoolExecutor(max_workers=_MAX_CONCURRENCY) + if (root / _NODE_DATA_ARCHIVE).exists(): + node_data_archive = MemKVStore(data=(root / _NODE_DATA_ARCHIVE + ).read_bytes()) + node_data_futs = {k: executor.submit(node_data_archive.read, k) + for k in node_data_archive.keys()} + node_data_store = {k: v.result() for k, v in node_data_futs.items()} + else: + node_data_store = dict() + return deserialize_pytree(raw_tree, node_data_store, ctx=serialization_ctx, + best_effort=best_effort) + + +class SerializationFuture: + """Keeps track of saving/loading serialized data via: + - self.done() - non-blocking check whether the underlying coroutine finished + - self.result() - gets the result of the coroutine, raises error if not done + - self.pytree - the property describing the data overview (short leaf desc.) + + The class takes in an async_fn and args/kwargs for it and launches it + immediately in a separate Python thread. This allows it to work both in sync + as well as in an async environment. + """ + def __init__(self, fn, *args, **kw): + self._pytree = None + # create a closure which will run an asyncio routine in separate thread + # and will populate either self._retval if no errors were raised or + # self._exception if there were errors + + # running an asyncio routine in a thread is a reliable way of scheduling + # an asyncio routine in the background both in regular synchronous contexts + # but also in (unexpectedly?) asynchronous contexts like Jupyter Notebooks + self._retval, self._exception = None, None + self._done_event = threading.Event() + self._fn = fn + + def _run_in_thread(): + ret, exc = None, None + try: + ret = self._fn(*args, **kw) + except Exception as e: # pylint: disable=broad-except + exc = e + self._done_event.set() + # populate either the result or the exception + self._retval, self._exception = ret, exc + + self._thread = threading.Thread(target=_run_in_thread) + self._thread.start() + # do not join the thread + + @property + def pytree(self): + return self._pytree + + @pytree.setter + def pytree(self, new_pytree: PyTreeDef): + msg = f"You cannot set the .pytree property in {type(self)} more than once." + assert self._pytree is None or self._pytree == new_pytree, msg + self._pytree = new_pytree + + def done(self): + """Check if the underlying deserialization is done. Non-blocking.""" + # return not self._thread.is_alive() + return self._done_event.is_set() + + def result(self): + """Retrieve the result or raise an exception of the `async_fn`.""" + assert self.done() + self._thread.join() + if self._exception is not None: # exception has been raised + raise self._exception + return self._retval + + def __await__(self): + while not self.done(): + yield + self._thread.join() + return self.result() + + def join(self): + """Wait for the underlying thread to complete.""" + return self._thread.join() + +def _pytree_leaf_desc(leaf): + if isinstance(leaf, (np.ndarray, jax.Array)): + return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype) + else: + return leaf + +def nonblocking_save(data: PyTreeT, directory: str | PathLike[str], + overwrite: bool = True, partial_write: bool = False, + tensorstore_specs: PyTreeT | None = None, + ) -> SerializationFuture: + # start serialization immediately + fut = SerializationFuture(save, data, directory, overwrite, + partial_write, tensorstore_specs) + # construct a nice looking pytree representing the nodes being read + fut.pytree = jax.tree.map(_pytree_leaf_desc, data) + return fut + + +_is_desc_array = lambda x: (re.match(_ARRAY_TYPE_REGEX, + x.split(_TYPE_ID_LEAF_DELIMITER, 1)[0]) + is not None) +_none_is_leaf = lambda x: x is None + +def nonblocking_load(directory: str | PathLike[str], + shardings: PyTreeT | _MISSING_TYPE = MISSING, + pytree: PyTreeT | _MISSING_TYPE = MISSING, + tensorstore_specs: PyTreeT | None = None, + best_effort: bool = False) -> SerializationFuture: + if pytree is MISSING: + pytree = load_pytree(directory, best_effort=best_effort) + + # read in all the objects synchronously, but arrays asynchronously + # use the existing partial-read functionality + # the load here MUST BE CALLED BEFORE the async_load or it'll cause a deadlock + # on multiprocess CPU machines (that need and XLA communication fallback) + arr_pytree = jax.tree.map(lambda x: x if _is_desc_array(x) else None, pytree) + obj_pytree = jax.tree.map(lambda x: None if _is_desc_array(x) else x, pytree) + arr_shapes = jax.tree.map(_leaf_desc_to_leaf, arr_pytree) # skip None-s here + obj_data = load(directory, pytree=obj_pytree, best_effort=best_effort) + pytree_stub = jax.tree.map(lambda x, y: x if x is not None else y, + arr_shapes, obj_data, is_leaf=_none_is_leaf) + # TODO(rdyro): delete: pytree_stub = jax.tree.map(_leaf_desc_to_leaf, pytree) + + # TODO(rdyro): the output of this class is a workaround + # it should return the fully populated pytree instead of just + # jax.ShapeDtypeStruct for arrays by constructing them asynchronously + fut = SerializationFuture(load, directory, shardings, pytree, + tensorstore_specs, best_effort) + + fut.pytree = pytree_stub + return fut diff --git a/jax/experimental/array_serialization/pytree_serialization_utils.py b/jax/experimental/array_serialization/pytree_serialization_utils.py new file mode 100644 index 000000000000..878e7804c5ec --- /dev/null +++ b/jax/experimental/array_serialization/pytree_serialization_utils.py @@ -0,0 +1,365 @@ +# Copyright 2021 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# + +# # Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import os +import json +import dataclasses +import collections +import threading +import itertools +import logging +import io +import zipfile +import contextlib +from concurrent.futures import ThreadPoolExecutor +from types import ModuleType +from typing import Any, Callable + + +import jax +from jax.tree_util import PyTreeDef, default_registry, treedef_is_leaf + +PickleModule = ModuleType + +logger = logging.getLogger(__name__) + +_TREE_REPR_KEY = "__jax_tree_repr" +_LEAF_IDS_KEY = "__jax_leaf_ids" +_LEAF_COUNT_KEY = "__jax_leaf_count" +_NODE_DATA_ARCHIVE_KEY_FORMAT = "___jax_node_data_ref_{}" +_NODE_DATA_ARCHIVE_KEY_REGEX = r"___jax_node_data_ref_([0-9]+)" + +def _cls2typerepr(cls): + return f"{cls.__module__}.{cls.__name__}" + +class MemKVStore: + def __init__(self, data: bytes | None = None): + self.buffer = io.BytesIO(data) if data is not None else io.BytesIO() + self.buffer.seek(0) + self.zipfile = zipfile.ZipFile( + self.buffer, mode="a", compression=zipfile.ZIP_DEFLATED, allowZip64=True) + self._lock, self._closed = threading.Lock(), False + + def keys(self) -> list[str]: + assert not self._closed + return self.zipfile.namelist() + + def tobytes(self) -> bytes: + assert not self._closed + with self._lock: + self.zipfile.close() + self._closed = True + return self.buffer.getvalue() + + def write(self, filename: str | os.PathLike[str], data: bytes | str) -> None: + assert not self._closed + with self._lock: + self.zipfile.writestr(str(filename), data) + + def read(self, filename: str | os.PathLike[str]) -> bytes: + assert not self._closed + return self.zipfile.read(str(filename)) + +identity = lambda x: x + +class SerializationRegistry: + def __init__(self): + self._serialization_map, self._deserialization_map = {}, {} + for t in [int, float, str, complex, bool]: + self._serialization_map[_cls2typerepr(t)] = (json.dumps, "json") + self._deserialization_map["json"] = json.loads + for t in [bytes, bytearray]: + self._serialization_map[_cls2typerepr(t)] = (identity, "bin") + self._deserialization_map["bin"] = identity + self._fallback_serialize_fn: Callable[[str, Any], str | bytes] | None = None + self._fallback_deserialize_fn: Callable[[str, str | bytes], Any] | None = None + + def register_type(self, cls: type[Any], + serialize_fn: Callable[[Any], Any] | None = None, + deserialize_fn: Callable[[Any], Any] | None = None, + name: str | None = None): + name = name if name is not None else _cls2typerepr(cls) + self._serialization_map[name] = (serialize_fn, name) + self._deserialization_map[name] = deserialize_fn + + def serialize(self, obj: Any, name: str | None = None + ) -> tuple[bytes | str, str]: + cls = type(obj) + name = name if name is not None else _cls2typerepr(cls) + if name in self._serialization_map: + serialization_fn, method_name = self._serialization_map[name] + return serialization_fn(obj), method_name + else: + if self._fallback_serialize_fn is not None: + return self._fallback_serialize_fn(name, obj), "fallback" + else: + raise ValueError(f"Class `{cls}` not registered for serialization.") + + def deserialize(self, data: bytes, method_name: str) -> Any: + if method_name in self._deserialization_map: + return self._deserialization_map[method_name](data) + else: + if self._fallback_deserialize_fn is not None: + return self._fallback_deserialize_fn(method_name, data) + else: + raise ValueError(f"Extension `{method_name}` not registered for" + " deserialization.") + + def register_pickle_fallback(self, pickle_module: PickleModule): + def _fallback_serialize_fn(class_name: str, obj: Any) -> bytes: + del class_name + return pickle_module.dumps(obj) + + def _fallback_deserialize_fn(class_name: str, data: str | bytes) -> Any: + del class_name + return pickle_module.loads(data) + + self._fallback_serialize_fn = _fallback_serialize_fn + self._fallback_deserialize_fn = _fallback_deserialize_fn + + def deregister_fallback(self): + self._fallback_serialize_fn = None + self._fallback_deserialize_fn = None + + def copy(self): + new_registry = SerializationRegistry() + new_registry._serialization_map = self._serialization_map.copy() + new_registry._deserialization_map = self._deserialization_map.copy() + new_registry._fallback_serialize_fn = self._fallback_serialize_fn + new_registry._fallback_deserialize_fn = self._fallback_deserialize_fn + return new_registry + + +class NodeSerializationRegistry(SerializationRegistry): + def __init__(self): + super().__init__() + self.node_registry = { + _cls2typerepr(dict): dict, + _cls2typerepr(list): list, + _cls2typerepr(tuple): tuple, + _cls2typerepr(set): set, + _cls2typerepr(collections.OrderedDict): collections.OrderedDict, + # "flax.core.frozen_dict.FrozenDict": "flax.core.frozen_dict.FrozenDict", + } + self.in_tree_node_registry = set(self.node_registry.keys()) + + def register_node(self, node_type: type[Any], + serialize_fn: Callable[[Any], Any] | None = None, + deserialize_fn: Callable[[Any], Any] | None = None, + place_data_in_tree: bool = False): + type_repr = _cls2typerepr(node_type) + self.node_registry[type_repr] = node_type + if place_data_in_tree: + self.in_tree_node_registry.add(type_repr) + else: + super().register_type(node_type, serialize_fn, deserialize_fn, + name=type_repr) + + def copy(self): + new_registry = NodeSerializationRegistry() + new_registry._serialization_map = self._serialization_map.copy() + new_registry._deserialization_map = self._deserialization_map.copy() + new_registry._fallback_serialize_fn = self._fallback_serialize_fn + new_registry._fallback_deserialize_fn = self._fallback_deserialize_fn + new_registry.node_registry = self.node_registry.copy() + new_registry.in_tree_node_registry = self.in_tree_node_registry.copy() + return new_registry + + +@dataclasses.dataclass +class SerializationContext: + leaf_serialization_registry: SerializationRegistry = dataclasses.field( + default_factory=SerializationRegistry) + node_serialization_registry: NodeSerializationRegistry = dataclasses.field( + default_factory=NodeSerializationRegistry) + thread_pool_executor: ThreadPoolExecutor = dataclasses.field( + default_factory=lambda: ThreadPoolExecutor(max_workers=32)) + _lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) + + def copy(self) -> "SerializationContext": + with self._lock: + return SerializationContext( + leaf_serialization_registry=self.leaf_serialization_registry.copy(), + node_serialization_registry=self.node_serialization_registry.copy() + ) + + def register_custom_node( + self, node_type: type[Any], + serialize_fn: Callable[[Any], Any] | None = None, + deserialize_fn: Callable[[Any], Any] | None = None, + place_data_in_tree: bool = False) -> None: + self.node_serialization_registry.register_node( + node_type, serialize_fn, deserialize_fn, place_data_in_tree) + + def deregister_custom_node(self, node_type: type[Any]) -> None: + type_repr = _cls2typerepr(node_type) + self.node_serialization_registry.node_registry.pop(type_repr) + if type_repr in self.node_serialization_registry.in_tree_node_registry: + self.node_serialization_registry.in_tree_node_registry.remove(type_repr) + else: + self.node_serialization_registry._serialization_map.pop(type_repr) + self.node_serialization_registry._deserialization_map.pop(type_repr) + + def register_custom_leaf( + self, cls: type[Any], serialize_fn: Callable[[Any], Any] | None = None, + deserialize_fn: Callable[[Any], Any] | None = None, + name: str | None = None) -> None: + self.leaf_serialization_registry.register_type(cls, serialize_fn, + deserialize_fn, name) + + @contextlib.contextmanager + def with_fallback(self, pickle_module: PickleModule): + try: + self.leaf_serialization_registry.register_pickle_fallback(pickle_module) + self.node_serialization_registry.register_pickle_fallback(pickle_module) + yield + finally: + self.leaf_serialization_registry.deregister_fallback() + self.node_serialization_registry.deregister_fallback() + +default_serialization_context = SerializationContext() + +def _node_serialize(node_data_store: dict[str, Any], counter: itertools.count, + node_type_data: tuple[type[Any], Any], + ctx: SerializationContext) -> tuple[str, Any]: + node_type, node_data = node_type_data + type_repr = _cls2typerepr(node_type) + if type_repr in ctx.node_serialization_registry.in_tree_node_registry: + return (type_repr, node_data) + else: + id = next(counter) + name = _NODE_DATA_ARCHIVE_KEY_FORMAT.format(id) + assert node_data_store is not None, ( + "Archive must be provided for not in-tree node data.") + + node_data_store[f"{name}.{type_repr}"] = ctx.thread_pool_executor.submit( + lambda: ctx.node_serialization_registry.serialize(node_data, + name=type_repr)[0]) + return (type_repr, name) + +def _node_deserialize(node_data_store: dict[str, Any], + node_type_data: tuple[str, Any], + ctx: SerializationContext, best_effort: bool = False + ) -> tuple[type[Any], Any]: + serialization_registry = ctx.node_serialization_registry + type_repr, node_data = node_type_data + try: + # get the class from the registry or try to import it if pickle_module + if type_repr in serialization_registry.node_registry: + node_type = serialization_registry.node_registry[type_repr] + else: + raise ValueError(f"Type `{type_repr}` not registered for" + " (de)serialization.") + + # if the node has an in-tree representation, just return that data + if type_repr in serialization_registry.in_tree_node_registry: + return (node_type, node_data) + + node_id = re.match(_NODE_DATA_ARCHIVE_KEY_REGEX, node_data) + if node_id is None: + raise ValueError(f"Node type {type_repr} is not registered as in-tree," + f" but the store key {node_data} does not match the" + f" expected format. Check that the node was registered" + f" that same way as during serialization. Otherwise use" + f" use argument `best_effort=True` to reconstruct as" + f" a list of children.") + node_id = int(node_id.group(1)) # type: ignore + filename = f"{_NODE_DATA_ARCHIVE_KEY_FORMAT.format(node_id)}.{type_repr}" + payload = node_data_store[filename] + node_data = serialization_registry.deserialize(payload, type_repr) + return (node_type, node_data) + except Exception as e: # pylint: disable=broad-except + if best_effort: + logger.warning("We couldn't read the node %s, returning list of children", + type_repr) + return (list, None) + else: + raise e + + +################################################################################ + +def _serialize_pytree_helper(node, leaf_counter: itertools.count, + node_counter: itertools.count, + node_data_store: dict[str, Any], + ctx: SerializationContext): + if treedef_is_leaf(node) and node.num_leaves == 1: + return dict(node_type="leaf", leaf_id=next(leaf_counter)) + node_repr = dict() + type_repr, node_data = _node_serialize(node_data_store, node_counter, + node.node_data(), ctx) + node_repr["name"], node_repr["node_data_ref"] = type_repr, node_data + node_repr["node_type"] = "static_node" if node.num_nodes == 1 else "node" + node_repr["children"] = [_serialize_pytree_helper( # type: ignore + child, leaf_counter, node_counter, node_data_store, ctx) + for child in node.children()] + return node_repr + +@jax.tree_util.register_static +class _EmptyStaticNode: + pass + +def _deserialize_pytree_helper(node, node_data_store: dict[str, Any], + ctx: SerializationContext, + best_effort: bool = False): + assert "node_type" in node + + if node["node_type"] == "leaf": + # case 1: normal leaf node ------------------------------------------------- + node_data, pytree_children = None, () + else: + node_data = _node_deserialize(node_data_store, + (node["name"], node["node_data_ref"]), ctx, + best_effort=best_effort) + + pytree_children = [_deserialize_pytree_helper( + child, node_data_store, ctx, best_effort=best_effort) + for child in node["children"]] # type: ignore + if (node["node_type"] == "static_node" and best_effort + and node_data[0] is list): + # the node failed to deserialize and was replaced at best effort with list + return jax.tree.structure(_EmptyStaticNode()) + pt = PyTreeDef.make_from_node_data_and_children(default_registry, node_data, + pytree_children) + return pt + + +# serialize and deserialize pytree methods namespaces: permissive and strict +def serialize_pytree(node, ctx: SerializationContext | None = None + ) -> tuple[dict[str, Any], int, dict[str, Any]]: + node_data_store: dict[str, Any] = {} + ctx = ctx if ctx is not None else default_serialization_context.copy() + leaf_counter, node_counter = itertools.count(), itertools.count() + root_repr = _serialize_pytree_helper( + node, leaf_counter, node_counter, node_data_store, ctx) + leaf_count = next(leaf_counter) + tree_repr = {_TREE_REPR_KEY: root_repr, _LEAF_COUNT_KEY: leaf_count, + _LEAF_IDS_KEY: list(range(leaf_count))} + + # gather data from the thread pool executor + node_data_store = {k: v.result() for k, v in node_data_store.items()} + return tree_repr, leaf_count, node_data_store + +def deserialize_pytree(rawtree: dict[str, Any], + node_data_store: dict[str, Any] | None = None, + ctx: SerializationContext | None = None, + best_effort: bool = False) -> Any: + node_data_store = {} if node_data_store is None else node_data_store + ctx = ctx if ctx is not None else default_serialization_context.copy() + pt = _deserialize_pytree_helper(rawtree[_TREE_REPR_KEY], node_data_store, ctx, + best_effort=best_effort) + leaf_ids = rawtree[_LEAF_IDS_KEY] + return jax.tree.unflatten(pt, leaf_ids) diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index 2620f5cc760c..9ecd5f9fe313 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -17,36 +17,37 @@ import abc import asyncio -from collections.abc import Awaitable, Callable, Sequence -from functools import partial +from collections.abc import Callable, Sequence +import functools import itertools import logging -import os import re import threading import time -from typing import Any, Optional +from typing import Any import jax from jax._src import array -from jax._src import config from jax._src import distributed from jax._src import sharding -from jax._src import sharding_impls -from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import typing from jax._src import util +from jax._src.layout import Layout from jax._src.lib import xla_extension as xe -import jax.numpy as jnp -import numpy as np +from jax.experimental.array_serialization import tensorstore_impl as ts_impl +# pylint: disable=unused-import +# import tensorstore-backed methods for backward compatibility. +from jax.experimental.array_serialization.tensorstore_impl import ( + get_tensorstore_spec as _new_get_tensorstore_spec, + run_deserialization, run_serialization, + async_serialize, async_deserialize) +# pylint: enable=unused-import import tensorstore as ts -TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}}) _REMOVED_VALUE = 'Value removed' _CHECKPOINT_SUCCESS = 'checkpoint_write_success' _module_unique_count = itertools.count() -_DEFAULT_DRIVER = 'file' _DISTRIBUTED_SYSTEM_MSG = ( 'Please initialize the distributed system via ' '`jax.distributed.initialize()` at the start of your program.') @@ -55,8 +56,9 @@ {'driver': 'gcs', 'path_regex': None}, {'driver': 's3', 'path_regex': None}, ] +_DEFAULT_DRIVER = 'file' -class BarrierTimeoutException(Exception): +class BarrierTimeoutError(Exception): pass _BARRIER_TIMED_OUT_MSG = ( @@ -68,68 +70,6 @@ class BarrierTimeoutException(Exception): logger = logging.getLogger(__name__) -async def create_async_array_from_callback( - global_shape: array.Shape, - inp_sharding: jax.sharding.Sharding, - data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]], -): - device_to_index_map = inp_sharding.devices_indices_map(global_shape) - addressable_da = inp_sharding._addressable_device_assignment - future_arrays = [data_callback(device_to_index_map[d], d) - for d in addressable_da] - dbs = await asyncio.gather(*future_arrays) - return array.make_array_from_single_device_arrays( - global_shape, inp_sharding, dbs) - - -def _get_metadata(arr): - local_shape = arr.addressable_data(0).shape - return { - 'compressor': {'id': 'zstd'}, - 'shape': arr.shape, - 'chunks': np.array(np.maximum(1, local_shape)), - } - - -def _spec_has_metadata(tree): - if not isinstance(tree, dict): - return False - return 'metadata' in tree or any( - _spec_has_metadata(subtree) for _, subtree in tree.items()) - -def _get_kvstore_for_gcs(ckpt_path: str): - m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL) - if m is None: - raise ValueError('The ckpt_path should contain the bucket name and the ' - f'file path inside the bucket. Got: {ckpt_path}') - gcs_bucket = m.group(1) - path_without_bucket = m.group(2) - return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket} - -def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False): - # Normalize path to exclude trailing '/'. In GCS path case, we will need to - # fix the path prefix to add back the stripped '/'. - ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://') - is_gcs_path = ckpt_path.startswith('gs://') - spec = {'driver': 'zarr', 'kvstore': {}} - if ocdbt: - if not is_gcs_path and not os.path.isabs(ckpt_path): - raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}') - base_path = os.path.dirname(ckpt_path) - spec['kvstore'] = { - 'driver': 'ocdbt', - 'base': base_path if is_gcs_path else f'{_DEFAULT_DRIVER}://{base_path}', - 'path': os.path.basename(ckpt_path), - } - else: - if is_gcs_path: - spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path) - else: - spec['kvstore'] = {'driver': _DEFAULT_DRIVER, 'path': ckpt_path} - - return spec - - def is_remote_storage(tspec: dict[str, Any] | str) -> bool: """Detect if user is using cloud storages. @@ -159,278 +99,11 @@ def is_remote_storage(tspec: dict[str, Any] | str) -> bool: return False - -# Lifted from T5X. -class _LimitInFlightBytes: - """Limits in-flight bytes when reading/writing checkpoints per process.""" - - def __init__(self, num_bytes): - self._max_bytes = num_bytes - self._available_bytes = num_bytes - self._cv = asyncio.Condition(lock=asyncio.Lock()) - - async def wait_for_bytes(self, requested_bytes): - if requested_bytes > self._max_bytes: - raise ValueError('Requested more bytes than we reserved space for: ' - f'{requested_bytes} > {self._max_bytes}') - async with self._cv: - await self._cv.wait_for(lambda: self._available_bytes > requested_bytes) - self._available_bytes -= requested_bytes - assert self._available_bytes >= 0 - - async def release_bytes(self, requested_bytes): - async with self._cv: - self._available_bytes += requested_bytes - assert self._available_bytes <= self._max_bytes - self._cv.notify_all() - - -async def transfer_shard_to_host(shard: array.Shard) -> np.ndarray: - data = shard.data - has_pinned_host = any( - m.kind == "pinned_host" for m in shard.device.addressable_memories()) - if has_pinned_host: - # If available, transfer to pinned host memory - sharding = jax.sharding.SingleDeviceSharding(shard.device, - memory_kind="pinned_host") - data = jax.device_put(data, sharding) - else: - data.copy_to_host_async() - # Allow other transfers to be scheduled simultaneously - await asyncio.sleep(0) - # Ensure that jax.Array's internal numpy array can be zero-copied. Tensorstore - # implicitly converts the written data to a numpy array, and would otherwise - # silently copy host-to-host. - return np.array(data, copy=False) - - -async def async_serialize( - arr_inp, - tensorstore_spec, - commit_future=None, - context=TS_CONTEXT, - primary_host: int | None = 0, - replica_id: int = 0, - transaction: Optional[ts.Transaction] = None, -): - """Serialize an array using TensorStore. - - Args: - arr_inp: The array to serialize. - tensorstore_spec: The tensorstore spec to use. - commit_future: A list of futures that will be appended to. The futures can - be awaited asynchronously. If None, the futures will be awaited - synchronously by this method. - context: ts.Context instance. - primary_host: Primary host, which indicates the host that will be treated as - the "leader". If None, all hosts are treated as the primary. DO NOT USE - unless you are sure you know what you are doing. - replica_id: Allows overriding the shard replica id that will be saved. DO - NOT USE unless you are sure you know what you are doing. - transaction: TensorStore transaction to use for opening and writing the - array. If not specified, a non-transactional write will be used. - """ - if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and - arr_inp.is_fully_addressable): - raise ValueError( - f'Passing fully addressable arrays to a multiprocess ' - f'serialization is not allowed, as this may lead to a race condition ' - f'between processes. Serialization have failed for the array with ' - f'the path "{tensorstore_spec["kvstore"]["path"]}".') - - # 'metadata' may not be present at the top level (for example, if we are using - # a 'cast' driver). - if not _spec_has_metadata(tensorstore_spec): - tensorstore_spec['metadata'] = _get_metadata(arr_inp) - - # Set dtype if it's not in spec - if 'dtype' not in tensorstore_spec: - tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name - - # If primary_host is None, all hosts will checkpoint. This is used - # for checkpointing to local filesystem. - if primary_host is None or jax.process_index() == primary_host: - open_future = ts.open( - ts.Spec(tensorstore_spec), - create=True, - open=True, - context=context, - transaction=transaction, - ) - # Asynchronous case. - if commit_future is not None: - assert isinstance(commit_future, list) - commit_future.append(open_future) - else: - await open_future - - # `ts.open` runs twice for process `primary_host` because for the first time, - # we just get the future to be awaited upon in the background thread. The - # second one runs with `assume_metadata=True` which does no I/O operation and - # returns the tensorstore object. - # For every process other than `primary_host`, we open with - # `assume_metadata=True`. - t = await ts.open( - ts.Spec(tensorstore_spec), - open=True, - assume_metadata=True, - context=context, - transaction=transaction, - ) - - async def _write_array(shard): - if shard.replica_id == replica_id: - data = await transfer_shard_to_host(shard) - write_future = t[shard.index].write( - data, - # Avoid additional copy of input array into the TensorStore chunk - # cache. If `arr_inp` is a jax.Array, the result of converting - # it to a NumPy array, as is done internally by TensorStore, is - # guaranteed to be immutable and therefore it is safe to retain a - # reference indefinitely. - can_reference_source_data_indefinitely=isinstance( - arr_inp, array.ArrayImpl - ), - ) - if commit_future is not None: - assert isinstance(commit_future, list) - commit_future.append(write_future.commit) - await write_future.copy - else: - await write_future.commit - - local_shards = arr_inp.addressable_shards - future_write_state = jax.tree_util.tree_map(_write_array, local_shards) - return await asyncio.gather(*future_write_state) - - -def run_serialization(arrays, tensorstore_specs): - async def _run_serializer(): - future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs) - return await asyncio.gather(*future_writer) - asyncio.run(_run_serializer()) - - -def estimate_read_memory_footprint(t: ts.TensorStore, - domain: ts.IndexDomain) -> int: - rank = t.rank - num_bytes = t.dtype.numpy_dtype.itemsize - chunk_template = t.chunk_layout.read_chunk_template - if domain is None: - domain = t.domain - origin = domain.origin - shape = domain.shape - chunk_origin = chunk_template.origin - chunk_shape = chunk_template.shape - - # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver. - # For those, instead of returning a near-infinite memory footprint, estimate - # the footprint as the entire shape. - for i in range(rank): - if not chunk_template[i].finite: - return domain.size * num_bytes - - # Otherwise, if we have a chunked driver, estimate based on chunk size. - for i in range(rank): - origin_value = origin[i] - chunk_origin_value = chunk_origin[i] - chunk_size = chunk_shape[i] - lower = origin_value - chunk_origin_value - upper = origin_value + shape[i] - chunk_origin_value - lower_aligned = lower // chunk_size * chunk_size - upper_aligned = -(-upper // chunk_size) * chunk_size - num_bytes *= (upper_aligned - lower_aligned) - - return num_bytes - - -async def async_deserialize( - user_in_sharding: jax.sharding.Sharding | Layout, - tensorstore_spec: ts.Spec | dict[str, Any], - global_shape: Sequence[int] | None = None, - dtype=None, - byte_limiter: _LimitInFlightBytes | None = None, - context=TS_CONTEXT, - assume_metadata: bool = False, -): - in_sharding = (user_in_sharding.sharding - if isinstance(user_in_sharding, Layout) else user_in_sharding) - if not isinstance(in_sharding, jax.sharding.Sharding): - raise ValueError( - 'sharding passed to deserialization should be specified, concrete and' - f' an instance of `jax.sharding.Sharding`. Got {in_sharding}') - dll = (user_in_sharding.device_local_layout - if isinstance(user_in_sharding, Layout) else None) - t = await ts.open( - tensorstore_spec, - open=True, - assume_metadata=assume_metadata, - context=context, - ) - shape = t.shape if global_shape is None else global_shape - new_shard_shape = in_sharding.shard_shape(tuple(shape)) - - async def cb(index: array.Index, device: jax.Device): - requested_domain = ts.IndexTransform(input_shape=shape)[index].domain - restricted_domain = t.domain.intersect(requested_domain) - requested_bytes = estimate_read_memory_footprint(t, restricted_domain) - # Limit the bytes read for every shard. - if byte_limiter is not None: - await byte_limiter.wait_for_bytes(requested_bytes) - # This maybe needed because the shape the array was saved with is smaller - # than the requested shape of the array in which it will be reloaded. So - # the extra values will be filled with 0s. - out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype) - await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][ - restricted_domain].write(t[restricted_domain]) - if dtype is not None: - # Cast while reloading on process to avoid 2 copies on device if the - # casting is done on device. - out = out.astype(dtype) - # Convert to jnp array so that layouts are initialized properly for - # sub-byte dtypes. - # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to - # make this work. - if out.dtype == jnp.int4: - out = jnp.asarray(out) # type: ignore - result = jax.device_put( - out, Layout(dll, jax.sharding.SingleDeviceSharding(device))) - if byte_limiter is not None: - # NB: `out` actually might not be ready for garbage collection by the - # time we call release_bytes . Thus peak memory usage still might grow - # beyond what byte_limiter limit suggests it should. The simplest option - # would be to call `result.block_until_ready()`` here. However it - # also comes with ~15-20% perf penalty as we would be waiting for CPU->GPU - # transfer instead of loading data. In the future, if memory pressure - # becomes a problem, we can instead instrument bytelimiter to - # keep track of all in-flight tensors and only block_until_ready, if byte - # limiter hits the limit to get reduced memory usage, without losing - # performance in common use cases. - await byte_limiter.release_bytes(requested_bytes) - return result - - return await create_async_array_from_callback(tuple(shape), in_sharding, cb) - - -def run_deserialization(shardings: Sequence[sharding.Sharding | Layout], - tensorstore_specs: Sequence[dict[str, Any]], - global_shapes: Sequence[array.Shape] | None = None, - dtypes: Sequence[typing.DTypeLike] | None = None, - concurrent_gb: int = 32): - concurrent_bytes = concurrent_gb * 10**9 - - async def _run_deserializer(): - # Object should be created once per process. - byte_limiter = _LimitInFlightBytes(concurrent_bytes) - - future_arrays = jax.tree_util.tree_map( - partial(async_deserialize, byte_limiter=byte_limiter), - shardings, tensorstore_specs, - [None] * len(tensorstore_specs) if global_shapes is None else global_shapes, - [None] * len(tensorstore_specs) if dtypes is None else dtypes) - return await asyncio.gather(*future_arrays) - return asyncio.run(_run_deserializer()) - +# for compatibility with older zarr format +_get_metadata = functools.partial(ts_impl. get_tensorstore_metadata, + driver='zarr') +get_tensorstore_spec = functools.partial(ts_impl.get_tensorstore_spec, + driver='zarr') def _get_key(key: int): return f'tensorstore_checkpoint_{key}' @@ -536,21 +209,28 @@ def _thread_func(self): logger.info('Finished committing to storage layer by process: %s', current_process) + key_for_barrier = None if process_count > 1: # All processes will wait at the barrier. When all processes are at the # barrier, the barrier will be satisfied. If not, then it will timeout. key_for_barrier = _get_key(self._count) logger.info('Key used for barrier is %s for process %s', key_for_barrier, current_process) - self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms) + # pytype: disable=attribute-error + self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms, + process_ids=None) + # pytype: enable=attribute-error logger.info('Finished waiting at barrier for process %s', current_process) if current_process == 0: - self._on_commit_callback() - logger.info('on_commit_callback successfully ran!') + if self._on_commit_callback is not None: + self._on_commit_callback() + logger.info('on_commit_callback successfully ran!') if process_count > 1: + # pytype: disable=attribute-error self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS) + # pytype: enable=attribute-error logger.info('Process 0 successfully set key %s in the kv store', key_for_barrier) @@ -558,7 +238,7 @@ def _thread_func(self): '/jax/checkpoint/write/async/thread_duration_sec', time.time() - thread_start_time) - except Exception as e: + except Exception as e: # pylint: disable=broad-except self._exception = e def _start_async_commit(self, on_commit_callback): @@ -575,7 +255,7 @@ def check_for_errors(self): self._exception = None if (isinstance(exception, xe.XlaRuntimeError) and 'DEADLINE_EXCEEDED: Barrier timed out' in str(exception)): - raise BarrierTimeoutException( + raise BarrierTimeoutError( '\n'.join([str(exception), _BARRIER_TIMED_OUT_MSG])) raise exception # pylint: disable=raising-bad-type @@ -592,7 +272,9 @@ def wait_until_finished(self): # Block until process 0 writes success value to the key value store. # If it fails to write it, then `blocking_key_value_get` will time out. get_key = _get_key(self._count) + # pytype: disable=attribute-error self._client.blocking_key_value_get(get_key, self._timeout_in_ms) + # pytype: enable=attribute-error logger.info('blocking_key_value_get on key %s was successfully ' 'completed.', get_key) @@ -608,8 +290,8 @@ def serialize( arrays, tensorstore_specs, *, - on_commit_callback, - transaction: Optional[ts.Transaction] = None, + on_commit_callback: Callable[[], None], + transaction: ts.Transaction | None = None, ): """Serializes Arrays or Arrays via TensorStore asynchronously. @@ -642,7 +324,7 @@ def serialize( async def _run_serializer(): future_writer = jax.tree_util.tree_map( - lambda arr_inp, tensorstore_spec: async_serialize( + lambda arr_inp, tensorstore_spec: ts_impl.async_serialize( arr_inp, tensorstore_spec, commit_future=commit_futures, @@ -652,7 +334,6 @@ async def _run_serializer(): tensorstore_specs, ) return await asyncio.gather(*future_writer) - asyncio.run(_run_serializer()) self._add_futures(commit_futures) @@ -666,11 +347,13 @@ def serialize_with_paths( arrays: Sequence[jax.Array], paths: Sequence[str], *, - on_commit_callback, - transaction: Optional[ts.Transaction] = None, + on_commit_callback: Callable[[], None] | None = None, + transaction: ts.Transaction | None = None, ): tspecs = jax.tree.map(get_tensorstore_spec, paths) - self.serialize( + if on_commit_callback is None: + on_commit_callback = lambda: None + return self.serialize( arrays, tspecs, on_commit_callback=on_commit_callback, @@ -683,8 +366,8 @@ def deserialize(self, shardings: Sequence[sharding.Sharding | Layout], dtypes: Sequence[typing.DTypeLike] | None = None, concurrent_gb: int = 32): self.wait_until_finished() - return run_deserialization(shardings, tensorstore_specs, - global_shapes, dtypes, concurrent_gb) + return ts_impl.run_deserialization( + shardings, tensorstore_specs, global_shapes, dtypes, concurrent_gb) def deserialize_with_paths( self, shardings: Sequence[sharding.Sharding], diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 61993637912f..0ec7f9b628dc 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -13,34 +13,55 @@ # limitations under the License. """Tests for serialization and deserialization of GDA.""" +# pylint: disable=g-importing-member import asyncio import contextlib +from dataclasses import dataclass +import functools +import json +import logging import math -from functools import partial import os import pathlib +import pickle +import tempfile +import time import tracemalloc as tm +from typing import Any from absl.testing import absltest from absl.testing import parameterized import jax -import jax.numpy as jnp -from jax._src import test_util as jtu +from jax import random +from jax import tree from jax._src import array -from jax._src import xla_bridge as xb -from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding -from jax.sharding import PartitionSpec as P +from jax._src import test_util as jtu +from jax._src.layout import DeviceLocalLayout as DLL +from jax._src.layout import Layout +from jax.experimental.array_serialization import pytree_serialization from jax.experimental.array_serialization import serialization -from jax.experimental.layout import Layout, DeviceLocalLayout as DLL +from jax.experimental.array_serialization import tensorstore_impl as ts_impl +from jax.experimental.array_serialization.pytree_serialization_utils import ( + default_serialization_context) +import jax.numpy as jnp + +from jax.sharding import GSPMDSharding +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P +from jax.sharding import SingleDeviceSharding import numpy as np import tensorstore as ts +# pylint: enable=g-importing-member + jax.config.parse_flags_with_absl() _exit_stack = contextlib.ExitStack() + def setUpModule(): _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + def tearDownModule(): _exit_stack.close() @@ -62,18 +83,19 @@ def test_memory_consumption(self): inp_shape, sharding, lambda idx: src[idx]) ckpt_dir = pathlib.Path(self.create_tempdir('memprof').full_path) - tspec = serialization.get_tensorstore_spec(str(ckpt_dir)) + tspec = ts_impl.get_tensorstore_spec(str(ckpt_dir)) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [inp], [tspec], - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=functools.partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() async def deserialize_with_byte_limit(): r = await serialization.async_deserialize( - sharding, tspec, inp_shape, - byte_limiter=serialization._LimitInFlightBytes(4_200_000)) + sharding, tspec, inp_shape, + byte_limiter=serialization._LimitInFlightBytes(4_200_000)) r.block_until_ready() tm.start() @@ -107,24 +129,21 @@ def test_memory_consumption_for_save(self): inp_shape, sharding, lambda idx: src[idx] ) ckpt_dir = pathlib.Path(self.create_tempdir('memprofsave').full_path) - tspec = serialization.get_tensorstore_spec(str(ckpt_dir)) + tspec = ts_impl.get_tensorstore_spec(str(ckpt_dir)) tspec['metadata'] = { 'shape': inp.shape, - 'compressor': None, - 'chunks': inp.shape, + 'data_type': jnp.dtype(inp.dtype).name, + 'chunk_grid': { + 'name': 'regular', + 'configuration': {'chunk_shape': np.array(np.maximum(1, inp.shape))} + } } - is_cpu = jtu.test_device_matches(['cpu']) tm.start() try: manager = serialization.GlobalAsyncCheckpointManager() - manager.serialize( - [inp], - [tspec], - on_commit_callback=partial( - self._on_commit_callback, ckpt_dir, ckpt_dir - ), - ) + manager.serialize([inp], [tspec], on_commit_callback=functools.partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() unused_current, peak = tm.get_traced_memory() self.assertLess(peak, src.nbytes * (1 * (not is_cpu) + 0.5)) @@ -150,7 +169,8 @@ def test_checkpointing_with_path_variant(self): manager = serialization.GlobalAsyncCheckpointManager() manager.serialize_with_paths( [a1], ckpt_paths, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=functools.partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() m1, = manager.deserialize_with_paths( @@ -175,7 +195,8 @@ def test_checkpointing_jax_array(self): inp_shape, NamedSharding(global_mesh, pspec), lambda idx: global_input_data1[idx]) ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) - ckpt_path1 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) + ckpt_path1 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/first').full_path) # Second Array global_input_data2 = np.arange( @@ -183,7 +204,8 @@ def test_checkpointing_jax_array(self): a2 = array.make_array_from_callback( inp_shape, NamedSharding(global_mesh, pspec), lambda idx: global_input_data2[idx]) - ckpt_path2 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/second').full_path) + ckpt_path2 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/second').full_path) # Third Array def cb3(_): @@ -191,18 +213,20 @@ def cb3(_): global_mesh1d = jtu.create_mesh((8,), ('x',)) a3 = array.make_array_from_callback( (0,), NamedSharding(global_mesh1d, P(None)), cb3) - ckpt_path3 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/third').full_path) + ckpt_path3 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/third').full_path) ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree_util.tree_map(ts_impl.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [a1, a2, a3], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=functools.partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - m1, m2, m3 = serialization.run_deserialization( + m1, m2, m3 = ts_impl.run_deserialization( [NamedSharding(global_mesh, pspec), NamedSharding(global_mesh, P('x')), NamedSharding(global_mesh1d, P(None))], @@ -270,29 +294,23 @@ def cb3(_): ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)] tspecs = jax.tree_util.tree_map( - lambda p: serialization.get_tensorstore_spec(p, ocdbt=True), ckpt_paths - ) + lambda p: ts_impl.get_tensorstore_spec(p, ocdbt=True), ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() with ts.Transaction(atomic=True) as transaction: manager.serialize( [a1, a2, a3], tspecs, - on_commit_callback=partial( + on_commit_callback=functools.partial( self._on_commit_callback, ckpt_dir, ckpt_dir ), transaction=transaction, ) manager.wait_until_finished() - m1, m2, m3 = serialization.run_deserialization( - [ - NamedSharding(global_mesh, pspec), - NamedSharding(global_mesh, P('x')), - NamedSharding(global_mesh1d, P(None)), - ], - tspecs, - ) + m1, m2, m3 = ts_impl.run_deserialization( + [NamedSharding(global_mesh, pspec), NamedSharding(global_mesh, P('x')), + NamedSharding(global_mesh1d, P(None))], tspecs) self.assertIsInstance(m1, array.ArrayImpl) self.assertArraysEqual( @@ -341,19 +359,19 @@ def cb1(index): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree_util.tree_map(ts_impl.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=functools.partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), P('x', 'y')) - m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)], - [np.float32]) + m1, = ts_impl.run_deserialization([ds], tspecs, [(12, 2)], [np.float32]) expected_data = { 0: np.array([[0], [2], [4]], dtype=np.float32), @@ -370,7 +388,7 @@ def cb1(index): self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id]) new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat)) - m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32]) + m2, = ts_impl.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32]) for l in m2.addressable_shards: self.assertArraysEqual(l.data, global_input_data1.astype('float32')) @@ -390,20 +408,20 @@ def cb(index): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree_util.tree_map(ts_impl.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=functools.partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), P('x', 'y')) target_dtype = jnp.dtype('int4') - m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)], - [target_dtype]) + m1, = ts_impl.run_deserialization([ds], tspecs, [(12, 2)], [target_dtype]) # values bigger than 7 are converted properly. expected_data = { @@ -421,7 +439,8 @@ def cb(index): self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id]) new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat)) - m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [target_dtype]) + m2, = ts_impl.run_deserialization([new_ds], tspecs, [(8, 2)], + [target_dtype]) for l in m2.addressable_shards: self.assertArraysEqual(l.data, global_input_data.astype(target_dtype)) @@ -435,22 +454,19 @@ def test_checkpointing_scalar_jax_array(self): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree_util.tree_map( + ts_impl.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [array1], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=functools.partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((2,), ('x')), P(None)) - m1, = serialization.run_deserialization( - [ds], - tspecs, - [()], - [np.float32] - ) + m1, = ts_impl.run_deserialization([ds], tspecs, [()], [np.float32]) for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32)) @@ -459,10 +475,8 @@ def test_deserialize_tensorstore_array_jax_array(self): global_mesh = jtu.create_mesh((2,), ('x')) data = np.arange(1024) tspec = ts.array(data).spec() - m1, = serialization.run_deserialization( - [NamedSharding(global_mesh, P(None))], - [tspec] - ) + m1, = ts_impl.run_deserialization([NamedSharding(global_mesh, P(None))], + [tspec]) for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), data) @@ -479,9 +493,9 @@ def test_spec_has_metadata(self): }, 'f': 4 } - self.assertTrue(serialization._spec_has_metadata(spec)) + self.assertTrue(ts_impl._spec_has_metadata(spec)) self.assertTrue( - serialization._spec_has_metadata({ + ts_impl._spec_has_metadata({ 'driver': 'zarr', 'kvstore': 'gfile', 'metadata': { @@ -503,39 +517,40 @@ def test_spec_has_no_metadata(self): }, 'f': 4 } - self.assertFalse(serialization._spec_has_metadata(spec)) + self.assertFalse(ts_impl._spec_has_metadata(spec)) def test_empty_spec_has_no_metadata(self): spec = {} - self.assertFalse(serialization._spec_has_metadata(spec)) + self.assertFalse(ts_impl._spec_has_metadata(spec)) @parameterized.named_parameters( ('gcs', 'gs://my/ckpt/dir/path'), ('file', '/my/ckpt/dir/path') ) def test_get_tensorstore_spec_ocdbt(self, path): - spec = serialization.get_tensorstore_spec(path, ocdbt=True) + spec = ts_impl.get_tensorstore_spec(path, ocdbt=True) is_gcs_path = path.startswith('gs://') + # for OCDBT the last part of the path is the key in the kvstore + expected_path = os.path.split(path)[0] if is_gcs_path: - self.assertEqual(spec['kvstore']['base'], os.path.dirname(path)) + self.assertEqual(spec['kvstore']['base']['driver'], 'gcs') + self.assertTrue(expected_path.endswith(spec['kvstore']['base']['path'])) else: - self.assertEqual(spec['kvstore']['base'], - f'{serialization._DEFAULT_DRIVER}://{os.path.dirname(path)}') - self.assertEqual(spec['kvstore']['path'], 'path') + self.assertEqual(spec['kvstore']['base']['path'], expected_path) def test_get_tensorstore_spec_not_absolute_path(self): path = 'my/ckpt/path' with self.assertRaisesRegex(ValueError, - "Checkpoint path should be absolute"): - serialization.get_tensorstore_spec(path, ocdbt=True) + 'Checkpoint path should be absolute'): + ts_impl.get_tensorstore_spec(path, ocdbt=True) def test_maybe_cloud_storage(self): - gs_path = 'gs://some-buck/path' - gs_spec = serialization.get_tensorstore_spec(gs_path, ocdbt=True) + gs_path = 'gs://some-buck/path/array_name' + gs_spec = ts_impl.get_tensorstore_spec(gs_path, ocdbt=True) self.assertTrue(serialization.is_remote_storage(gs_spec)) - local_path = '/tmp/checkpoint' - local_spec = serialization.get_tensorstore_spec(local_path, ocdbt=True) + local_path = '/tmp/checkpoint/array_name' + local_spec = ts_impl.get_tensorstore_spec(local_path, ocdbt=True) self.assertFalse(serialization.is_remote_storage(local_spec)) nested_tspec = { @@ -543,7 +558,8 @@ def test_maybe_cloud_storage(self): 'dtype': 'int32', 'base': { 'driver': 'zarr', - 'kvstore': {'driver': 'ocdbt', 'base': 's3://some-bucket/path'}, + 'kvstore': {'driver': 'ocdbt', + 'base': 's3://some-bucket/path/array_name'}, }, } self.assertTrue(serialization.is_remote_storage(nested_tspec)) @@ -564,15 +580,16 @@ def test_load_with_layout(self): ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, [ckpt_path]) + tspecs = jax.tree_util.tree_map(ts_impl.get_tensorstore_spec, [ckpt_path]) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=functools.partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - out, = serialization.run_deserialization([out_layout], tspecs) + out, = ts_impl.run_deserialization([out_layout], tspecs) self.assertEqual(out.layout, out_layout) self.assertIsInstance(out, array.ArrayImpl) @@ -591,9 +608,7 @@ def test_deserialization_with_int4(self): # Run serialization. sharding = jax.sharding.GSPMDSharding.get_replicated(jax.devices()) - tspecs = jax.tree_util.tree_map( - serialization.get_tensorstore_spec, [ckpt_dir] - ) + tspecs = jax.tree_util.tree_map(ts_impl.get_tensorstore_spec, [ckpt_dir]) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], @@ -603,12 +618,9 @@ def test_deserialization_with_int4(self): manager.wait_until_finished() # Run deserialization. - deserialized_arr, = serialization.run_deserialization( - shardings=[sharding], - tensorstore_specs=tspecs, - global_shapes=[shape], - dtypes=[dtype], - ) + deserialized_arr, = ts_impl.run_deserialization( + shardings=[sharding], tensorstore_specs=tspecs, global_shapes=[shape], + dtypes=[dtype]) out = deserialized_arr.astype(jnp.int8) # doesn't crash self.assertEqual(out.dtype, jnp.int8) @@ -620,13 +632,288 @@ class TransferShardTest(jtu.JaxTestCase): @jtu.skip_on_devices('cpu') def test_transfer_shard_to_host(self): np_inp = np.arange(16).reshape((4, 4)) - sharding = SingleDeviceSharding(jax.devices()[0], memory_kind="device") + sharding = SingleDeviceSharding(jax.devices()[0], memory_kind='device') arr = jax.device_put(np_inp, sharding) shard = arr.addressable_shards[0] - np_out = asyncio.run(serialization.transfer_shard_to_host(shard)) + np_out = asyncio.run(ts_impl._transfer_shard_to_host(shard)) self.assertArraysEqual(np_out, np_inp) + +class UserAPITestCase(jtu.JaxTestCase): + name: str | None + path: pathlib.Path | None + + def setUp(self): + super().setUp() + tmpdir = tempfile.TemporaryDirectory() + self.enter_context(tmpdir) + self.name = tmpdir.name + self.path = pathlib.Path(self.name) + + def tearDown(self): + self.path = None + self.name = None + super().tearDown() + + def generate_random_fp32(self, shape, dtype=jnp.float32): + seed = round(time.time() * 1e6) % (2 ** 31) + key = random.key(seed) + return random.normal(key, shape=shape).astype(dtype) + + def generate_clean_tree(self, dtype=jnp.float32): + r1 = self.generate_random_fp32((), dtype=dtype) + r2 = self.generate_random_fp32((4,), dtype=dtype) + r3 = self.generate_random_fp32((2, 3), dtype=dtype) + return (r1, {'a': r2, 'rs': [r1, r2, r3], 'c': {'d': {'e': (r2,)}}}) + + def _is_equal(self, el1, el2): + if not isinstance(el1, type(el2)) or not isinstance(el2, type(el1)): + return False + if isinstance(el1, (np.ndarray, jax.Array)): + return (el1.dtype == el2.dtype and el1.shape == el2.shape + and jnp.allclose(el1, el2)) + else: + return el1 == el2 + + def assertPyTreeEqual(self, p1, p2): + leaves1, struct1 = tree.flatten(p1) + leaves2, struct2 = tree.flatten(p2) + self.assertEqual(struct1, struct2) + self.assertTrue(all(self._is_equal(el1, el2) + for (el1, el2) in zip(leaves1, leaves2))) + +_DTYPES_LIST = [ + jnp.uint8, + jnp.uint16, + jnp.uint32, + jnp.int8, + jnp.int16, + jnp.int32, + jnp.float8_e4m3fn, + jnp.float8_e4m3fnuz, + jnp.float8_e5m2, + jnp.float8_e5m2fnuz, + jnp.float8_e4m3b11fnuz, + jnp.bfloat16, + jnp.float16, + jnp.float32, + jnp.complex64, +] + +if jax.config.x64_enabled: + _DTYPES_LIST.extend([ + jnp.uint64, + jnp.int64, + jnp.float64, + jnp.complex128, + ]) + + +class CustomNode: + def __init__(self, a): + self.a = a + + def tree_flatten(self): + return (self.a,), None + + @classmethod + def tree_unflatten(cls, aux_data, children): + del aux_data + return cls(*children) + + +@dataclass +class CustomDataclass: + a: int + c: str + d: int + + +class CustomStatic: + def __init__(self, a): + self.a = a + + +class UserAPITest(UserAPITestCase): + @parameterized.product(tree=[{'a': 1}, [1, 2, 3], (1, 2, 3), + 'hello', 1, 2, 3]) + def test_save_then_load(self, tree): # pylint: disable=redefined-outer-name + pytree_serialization.save(tree, self.path) + tree2 = pytree_serialization.load(self.path) + self.assertPyTreeEqual(tree, tree2) + + @parameterized.product(dtype=_DTYPES_LIST) + def test_saving_dtype(self, dtype): + test_tree = self.generate_clean_tree(dtype=dtype) + print('Generated tree', flush=True) + pytree_serialization.save(test_tree, self.path) + new_tree = pytree_serialization.load(self.path) + self.assertPyTreeEqual(test_tree, new_tree) + + def test_do_not_overwrite_noncheckpoint_directories(self): + (self.path / 'hello.txt').write_text('Hello World') + with self.assertRaises(AssertionError): + pytree_serialization.save({'a': 1}, self.path) + + def test_checkpoint_exists(self): + pytree_serialization.save({'a': 1}, self.path) + with self.assertRaises(NotImplementedError): + pytree_serialization.save({'a': 1}, self.path, overwrite=False) + + @parameterized.product(use_node=[True, False], use_dataclass=[True, False], + use_static=[True, False], save_pickle=[True, False], + load_pickle=[True, False]) + def test_custom_types(self, use_node, use_dataclass, use_static, save_pickle, + load_pickle): + if not use_node and not use_dataclass and not use_static: + return + magic_value = 37 + n = CustomNode(magic_value) if use_node else None + d = (CustomDataclass(magic_value, 'hello', magic_value + 1) + if use_dataclass else None) + s = CustomStatic(magic_value - 1) + tree_to_save = [n, (d, s)] + + if save_pickle: + with default_serialization_context.with_fallback(pickle): + pytree_serialization.save(tree_to_save, self.path) + else: + with self.assertRaises(ValueError): + pytree_serialization.save(tree_to_save, self.path) + return + + if load_pickle: + with default_serialization_context.with_fallback(pickle): + default_serialization_context.register_custom_node( + CustomStatic, pickle.dumps, pickle.loads) + default_serialization_context.register_custom_node( + CustomNode, pickle.dumps, pickle.loads) + default_serialization_context.register_custom_node( + CustomDataclass, pickle.dumps, pickle.loads) + tree2 = pytree_serialization.load(self.path) + default_serialization_context.deregister_custom_node(CustomStatic) + default_serialization_context.deregister_custom_node(CustomNode) + default_serialization_context.deregister_custom_node(CustomDataclass) + else: + with self.assertRaises(ValueError): + _ = pytree_serialization.load(self.path) + return + + if use_node: + self.assertEqual(tree2[0].a, magic_value) + if use_dataclass: + self.assertEqual(tree2[1][0].a, magic_value) + self.assertEqual(tree2[1][0].c, 'hello') + self.assertEqual(tree2[1][0].d, magic_value + 1) + if use_static: + self.assertEqual(tree2[1][1].a, magic_value - 1) + + @parameterized.product(register=[True, False]) + def test_best_effort(self, register): + magic_value = 37 + n = CustomNode(magic_value) + d = CustomDataclass(magic_value, 'hello', magic_value + 1) + s = CustomStatic(magic_value - 1) + tree_to_save = [n, (d, s)] + + if register: + # jax.tree_util.register_dataclass( + # CustomDataclass, data_fields=["a", "d"], meta_fields=["c"]) + jax.tree_util.register_pytree_node_class(CustomNode) + jax.tree_util.register_static(CustomStatic) + + with default_serialization_context.with_fallback(pickle): + pytree_serialization.save(tree_to_save, self.path) + with self.assertRaises(ValueError): + _ = pytree_serialization.load(self.path) + _ = pytree_serialization.load(self.path, best_effort=True) + + def test_flax_frozen_dict(self): + try: + # pylint: disable=g-import-not-at-top + # pylint: disable=g-importing-member + from flax.core.frozen_dict import FrozenDict + # pylint: enable=g-importing-member + # pylint: enable=g-import-not-at-top + except ImportError: + logging.warning('Skipping Flax FrozenDict tests as flax is not installed') + return + + try: + default_serialization_context.register_custom_node(FrozenDict, + pickle.dumps, + pickle.loads) + pytree_serialization.save(FrozenDict(a=1, b=self.generate_clean_tree()), + self.path) + pytree_serialization.load(self.path) + finally: + default_serialization_context.deregister_custom_node(FrozenDict) + + def test_incremental_writes(self): + incremental_tree = [None, None, None] + pytree_serialization.save(incremental_tree, self.path, partial_write=True) + incremental_tree[0] = 1 + pytree_serialization.save(incremental_tree, self.path, partial_write=True) + ret = pytree_serialization.load(self.path) + assert ret[0] == 1 and ret[1] is None and ret[2] is None + incremental_tree[0], incremental_tree[2] = None, jnp.ones(4) + pytree_serialization.save(incremental_tree, self.path, partial_write=True) + ret = pytree_serialization.load(self.path) + assert (ret[0] == 1 and ret[1] is None + and (np.testing.assert_allclose(ret[2], jnp.ones(4)) is None)) + + def test_custom_node_leaf_registration(self): + @dataclass + class P: + a: int = 2 + + @functools.partial(jax.tree_util.register_dataclass, data_fields=['a', 'b'], + meta_fields=['op']) + @dataclass + class D: + a: Any + b: Any + op: str + + def serialize_D(data): + return json.dumps(data) + + def deserialize_D(data): + return json.loads(data) + + data = ['hello', {'world': ['!', (1, 2)]}, None, P()] + + serialize_fn = lambda p: json.dumps(p.a) + deserialize_fn = lambda data: P(json.loads(data)) + + with self.assertRaises(ValueError): + pytree_serialization.save(data, self.path) + + default_serialization_context.register_custom_leaf(P, serialize_fn, + deserialize_fn) + magic_value = -171 + data[-1].a = magic_value + pytree_serialization.save(data, self.path) + ret = pytree_serialization.load(self.path) + self.assertLen(ret, len(data)) + self.assertEqual(ret[-1].a, magic_value) + + magic_string = str(hash('hello')) + data.append(D(1, jax.numpy.zeros(2), magic_string)) + with self.assertRaises(ValueError): + pytree_serialization.save(data, self.path) + + default_serialization_context.register_custom_node(D, serialize_D, + deserialize_D) + pytree_serialization.save(data, self.path) + ret = pytree_serialization.load(self.path) + self.assertLen(ret, len(data)) + self.assertEqual(ret[-1].op, magic_string) + + jax.tree.flatten(data) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/array_serialization/tensorstore_impl.py b/jax/experimental/array_serialization/tensorstore_impl.py new file mode 100644 index 000000000000..785d4681afef --- /dev/null +++ b/jax/experimental/array_serialization/tensorstore_impl.py @@ -0,0 +1,529 @@ +# Copyright 2021 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from functools import partial +import functools +import os +from os import PathLike +import re +from typing import Any, Awaitable, Callable, Sequence +import math +import operator + +import jax +from jax import numpy as jnp +from jax._src import array +from jax._src.layout import Layout +from jax._src import typing +from jax.experimental.array_serialization import asyncio_utils +import numpy as np +import tensorstore as ts + +TS_ARRAY_DRIVER = "zarr3" + +TS_CONTEXT = ts.Context({ + 'file_io_concurrency': {'limit': 128}, + 'cache_pool': {'total_bytes_limit': 10_000_000_000}, # 10 GB RAM limit + 'cache_pool#remote': {'total_bytes_limit': 10_000_000_000}, + 'data_copy_concurrency': {'limit': 128} +}) +TS_CHUNK_LAYOUT = ts.ChunkLayout({ + "chunk": {"elements": 100_000_000}, # 100M (800MB for float64) file size +}) + +_DEFAULT_BASE_DRIVER = 'file' +_PROCESS_DIR_FORMAT = "process_{}" +_FILE_SIZE_TARGET = 2 * 1024 ** 3 # 2 GB + +def _prod(x: Sequence[int]) -> int: + return functools.reduce(operator.mul, x, 1) + +def _maximum(x: Sequence[int], default: int = 1) -> list[int]: + return [max(z, default) for z in x] + +def is_tensorstore_spec_leaf(leaf: Any): + # TODO(rdyro): think of a better way to detect which leaf is a ts config + return isinstance(leaf, dict) and "driver" in leaf or "kvstore" in leaf + +def _prime_factors(x: int) -> list[int]: + factors = [] + while x % 2 == 0: + factors.append(2) + x //= 2 + for i in range(3, int(math.sqrt(x)) + 1, 2): + while x % i == 0: + factors.append(i) + x //= i + if x > 1: + factors.append(x) + return sorted(factors) + +def _compute_chunk_shape( + arr: jax.Array, file_size_target: int = _FILE_SIZE_TARGET) -> list[int]: + """Compute a chunk such that divides local shape and is less than target file size.""" + local_shape = _maximum(list(arr.addressable_data(0).shape), default=1) + total_size = (_prod(_maximum(local_shape, default=1)) + * jnp.dtype(arr.dtype).itemsize) + if len(local_shape) == 0: + return local_shape + axis_prime_factors = [_prime_factors(z) if z > 1 else [] for z in local_shape] + chunk_size = total_size + while chunk_size > 1.1 * file_size_target: # 10% buffer + chosen_axis_idx, chosen_divisor = None, 1 + for axis_idx in range(len(local_shape)): + if len(axis_prime_factors[axis_idx]) == 1: + continue + if (chosen_axis_idx is None + or chosen_divisor > axis_prime_factors[axis_idx][0]): + chosen_axis_idx = axis_idx + chosen_divisor = axis_prime_factors[axis_idx][0] + if chosen_axis_idx is None: + break + if len(axis_prime_factors[chosen_axis_idx]) == 0: + return local_shape + axis_prime_factors[chosen_axis_idx].pop(0) + local_shape[chosen_axis_idx] //= chosen_divisor + chunk_size //= chosen_divisor + return local_shape + +def get_tensorstore_metadata(arr, is_remote: bool = False, + file_size_target: int = _FILE_SIZE_TARGET, + driver: str = TS_ARRAY_DRIVER, + ) -> dict[str, Any]: + if driver == TS_ARRAY_DRIVER: + codecs = ([{"name": "zstd"}] if is_remote else []) + return { + 'codecs': codecs, + 'shape': arr.shape, + 'data_type': jnp.dtype(arr.dtype).name, + 'chunk_grid': { + 'name': 'regular', + 'configuration': {'chunk_shape': _compute_chunk_shape( + arr, file_size_target=file_size_target)} + } + } + elif driver == "zarr": # in zarr dtype goes in the base spec + local_shape = arr.addressable_data(0).shape + return {'compressor': {'id': 'zstd'}, 'shape': arr.shape, + 'chunks': np.array(np.maximum(1, local_shape))} + else: + raise ValueError(f"Unsupported driver: {driver}") + +_divides = lambda x, y: np.all((np.array(x) % np.array(y)) == 0) + +def merge_nested_specs(dict1: dict[Any, Any], dict2: dict[Any, Any]): + """Merge two specs as nested dictionaries, dict2 takes precedence.""" + if dict2 is None: + return dict1 + exclusive_dict1_keys = set(dict1.keys()) - set(dict2.keys()) + exclusive_dict2_keys = set(dict2.keys()) - set(dict1.keys()) + shared_keys = set(dict1.keys()) & set(dict2.keys()) + out_dict = {k: dict1[k] for k in exclusive_dict1_keys} + out_dict.update({k: dict2[k] for k in exclusive_dict2_keys}) + for k in shared_keys: + v1, v2 = dict1[k], dict2[k] + if isinstance(v1, dict): + out_dict[k] = merge_nested_specs(v1, v2) + else: + out_dict[k] = v2 + return out_dict + +def verify_tensorstore_spec(spec: dict[str, Any], arr: jax.Array | None, + path: str | os.PathLike[str], + check_metadata: bool = True) -> None: + """Verify the minimum requirements for a tensorstore spec.""" + if check_metadata: + assert arr is not None, "Array is required for metadata verification." + metadata = spec['metadata'] + msg = f"Provided dtype {metadata['data_type']} != array dtype: {arr.dtype}" + assert metadata['data_type'] == jnp.dtype(arr.dtype).name, msg + msg = f"Provided shape {metadata['shape']} != array shape: {arr.shape}" + assert metadata['shape'] == arr.shape, msg + local_shape = arr.addressable_data(0).shape + chunk_shape = metadata['chunk_grid']['configuration']['chunk_shape'] + msg = (f"Provided chunk shape {chunk_shape} does not divide the local shape" + f" of the array {local_shape}") + assert _divides(local_shape, chunk_shape), msg + # we don't support altering the path of the tensorstore + msg = (f"Provided { path = } does not match the path in the spec:" + f" {spec['kvstore']}") + assert spec["kvstore"]['base']['path'] == str(path), msg + +def _deprecated_get_metadata(arr): + local_shape = arr.addressable_data(0).shape + return { + 'compressor': {'id': 'zstd'}, + 'shape': arr.shape, + 'chunks': np.array(np.maximum(1, local_shape)), + } + +def _spec_has_metadata(tree): + if not isinstance(tree, dict): + return False + return 'metadata' in tree or any( + _spec_has_metadata(subtree) for _, subtree in tree.items()) + +def _get_kvstore_for_gcs(ckpt_path: str): + m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL) + if m is None: + raise ValueError('The ckpt_path should contain the bucket name and the ' + f'file path inside the bucket. Got: {ckpt_path}') + bucket = m.group(1) + path_without_bucket = m.group(2) + return {'driver': 'gcs', 'bucket': bucket, 'path': path_without_bucket} + +def _get_kvstore_for_s3(ckpt_path: str): + m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path, re.DOTALL) + if m is None: + raise ValueError('The ckpt_path should contain the bucket name and the ' + f'file path inside the bucket. Got: {ckpt_path}') + bucket = m.group(1) + path_without_bucket = m.group(2) + return {'driver': 's3', 'bucket': bucket, 'path': path_without_bucket} + +def get_tensorstore_spec( + ckpt_path: str | PathLike[str], ocdbt: bool = False, + process_num: int | None = None, arr: jax.Array | None = None, + driver: str = TS_ARRAY_DRIVER) -> dict[str, Any]: + + # Normalize path to exclude trailing '/'. In GCS path case, we will need to + # fix the path prefix to add back the stripped '/'. + ckpt_path = str(ckpt_path) + ckpt_path = re.sub(r"^gs:/", r"gs://", os.path.normpath(ckpt_path)) + ckpt_path = re.sub(r"^s3:/", r"s3://", ckpt_path) + + # in cases of multi-process writes, we need to write to a different location + # for each process and finally created a combined symlink to the final + # location, tensorstore can do this via ts.KvStore.experimental_copy_range_to + if process_num is not None: + _parent, _name = os.path.split(ckpt_path) + ckpt_path = os.path.join(_parent, _PROCESS_DIR_FORMAT.format(process_num), + _name) + + is_gcs_path = ckpt_path.startswith('gs://') + is_s3_path = ckpt_path.startswith('s3://') + spec = {'driver': driver, 'kvstore': {}} + + # use a combined OCDBT store, the actual path is the parent path + # the name (filename/last part of the path) is the key in the ocdbt kvstore + entry_key = None + if ocdbt: + (ckpt_path, entry_key), org_ckpt_path = os.path.split(ckpt_path), ckpt_path + if is_gcs_path: + m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL) + elif is_s3_path: + m = re.fullmatch('^s4://([^/]*)/(.*)$', ckpt_path, re.DOTALL) + else: + m = re.match("a", "a") # make it True + if m is None: + raise ValueError('Using OCDBT requires the bucket name, the directory' + ' name and the array name, your path is: ' + f'{org_ckpt_path}') + + if is_gcs_path: + base_kvstore = _get_kvstore_for_gcs(ckpt_path) + elif is_s3_path: + base_kvstore = _get_kvstore_for_s3(ckpt_path) + else: + base_kvstore = {'driver': _DEFAULT_BASE_DRIVER, 'path': ckpt_path} + + if ocdbt: + if not is_gcs_path and not os.path.isabs(ckpt_path): + raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}') + spec['kvstore'] = {'driver': 'ocdbt', 'base': base_kvstore, + 'path': entry_key} + else: + spec['kvstore'] = base_kvstore + if arr is not None: + spec["metadata"] = get_tensorstore_metadata(arr, driver=str(spec["driver"])) + return spec + +async def _create_async_array_from_callback( + global_shape: array.Shape, + inp_sharding: jax.sharding.Sharding, + data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]], +): + device_to_index_map = inp_sharding.devices_indices_map(global_shape) + addressable_da = inp_sharding._addressable_device_assignment + future_arrays = [data_callback(device_to_index_map[d], d) + for d in addressable_da] + dbs = await asyncio.gather(*future_arrays) + return array.make_array_from_single_device_arrays( + global_shape, inp_sharding, dbs) + +async def _transfer_shard_to_host(shard: array.Shard) -> np.ndarray: + data = shard.data + has_pinned_host = any( + m.kind == "pinned_host" for m in shard.device.addressable_memories()) + if has_pinned_host: + # If available, transfer to pinned host memory + sharding = jax.sharding.SingleDeviceSharding(shard.device, + memory_kind="pinned_host") + data = jax.device_put(data, sharding) + else: + data.copy_to_host_async() + # Allow other transfers to be scheduled simultaneously + await asyncio.sleep(0) + # Ensure that jax.Array's internal numpy array can be zero-copied. Tensorstore + # implicitly converts the written data to a numpy array, and would otherwise + # silently copy host-to-host. + return np.array(data, copy=False) + +async def combine_kvstores(combined_kvstore: dict[str, Any], + kvstores: list[dict[str, Any]], + context: ts.Context | dict[str, Any] = TS_CONTEXT + ) -> None: + """Merge a list of kvstores into a single kvstore. NOT multi-process safe.""" + combined_fut = ts.KvStore.open(combined_kvstore, context=context) + kvstores_futs = [ts.KvStore.open(kvstore, context=context) + for kvstore in kvstores] + combined, kvstores = await asyncio.gather(combined_fut, + asyncio.gather(*kvstores_futs)) + tx = ts.Transaction() + await asyncio.gather(*[kvstore.experimental_copy_range_to( + combined.with_transaction(tx)) for kvstore in kvstores]) + await tx.commit_async() + +async def async_serialize( + arr_inp, + tensorstore_spec, + commit_future=None, + context=TS_CONTEXT, + chunk_layout=TS_CHUNK_LAYOUT, + primary_host: int | None = None, + replica_id: int = 0, + transaction: ts.Transaction | None = None, +): + """Serialize an array using TensorStore. + + Args: + arr_inp: The array to serialize. + tensorstore_spec: The tensorstore spec to use. + commit_future: A list of futures that will be appended to. The futures can + be awaited asynchronously. If None, the futures will be awaited + synchronously by this method. + context: ts.Context instance. + primary_host: Primary host, which indicates the host that will be treated as + the "leader". If None, all hosts are treated as the primary. DO NOT USE + unless you are sure you know what you are doing. + replica_id: Allows overriding the shard replica id that will be saved. DO + NOT USE unless you are sure you know what you are doing. + transaction: TensorStore transaction to use for opening and writing the + array. If not specified, a non-transactional write will be used. + """ + if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and + arr_inp.is_fully_addressable): + raise ValueError( + f'Passing fully addressable arrays to a multiprocess ' + f'serialization is not allowed, as this may lead to a race condition ' + f'between processes. Serialization have failed for the array with ' + f'the path from kvstore: "{tensorstore_spec["kvstore"]}".') + + # 'metadata' may not be present at the top level (for example, if we are using + # a 'cast' driver). + if not _spec_has_metadata(tensorstore_spec): + tensorstore_spec['metadata'] = get_tensorstore_metadata( + arr_inp, driver=tensorstore_spec['driver']) + ## zarr driver requires specifying the dtype in the spec base + if tensorstore_spec['driver'] == 'zarr' and 'dtype' not in tensorstore_spec: + tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name + + # If primary_host is None, all hosts will checkpoint. This is used + # for checkpointing to local filesystem. + if primary_host is None or jax.process_index() == primary_host: + open_future = ts.open( + ts.Spec(tensorstore_spec), + create=True, + open=True, + context=context, + chunk_layout=chunk_layout, + transaction=transaction, + ) + # Asynchronous case. + if commit_future is not None: + assert isinstance(commit_future, list) + commit_future.append(open_future) + else: + await open_future + + # `ts.open` runs twice for process `primary_host` because for the first time, + # we just get the future to be awaited upon in the background thread. The + # second one runs with `assume_metadata=True` which does no I/O operation and + # returns the tensorstore object. + # For every process other than `primary_host`, we open with + # `assume_metadata=True`. + t = await ts.open( + ts.Spec(tensorstore_spec), + open=True, + assume_metadata=True, + context=context, + chunk_layout=chunk_layout, + transaction=transaction, + ) + + async def _write_array(shard): + if shard.replica_id == replica_id: + data = await _transfer_shard_to_host(shard) + write_future = t[shard.index].write( + data, + # Avoid additional copy of input array into the TensorStore chunk + # cache. If `arr_inp` is a jax.Array, the result of converting + # it to a NumPy array, as is done internally by TensorStore, is + # guaranteed to be immutable and therefore it is safe to retain a + # reference indefinitely. + can_reference_source_data_indefinitely=isinstance( + arr_inp, array.ArrayImpl + ), + ) + if commit_future is not None: + assert isinstance(commit_future, list) + commit_future.append(write_future.commit) + await write_future.copy + else: + await write_future.commit + + local_shards = arr_inp.addressable_shards + future_write_state = jax.tree_util.tree_map(_write_array, local_shards) + return await asyncio.gather(*future_write_state) + + +def run_serialization(arrays, tensorstore_specs): + async def _run_serializer(): + future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs) + return await asyncio.gather(*future_writer) + asyncio.run(_run_serializer()) + + +def estimate_read_memory_footprint(t: ts.TensorStore, + domain: ts.IndexDomain) -> int: + rank = t.rank + num_bytes = t.dtype.numpy_dtype.itemsize + chunk_template = t.chunk_layout.read_chunk_template + if domain is None: + domain = t.domain + origin = domain.origin + shape = domain.shape + chunk_origin = chunk_template.origin + chunk_shape = chunk_template.shape + + # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver. + # For those, instead of returning a near-infinite memory footprint, estimate + # the footprint as the entire shape. + for i in range(rank): + if not chunk_template[i].finite: + return domain.size * num_bytes + + # Otherwise, if we have a chunked driver, estimate based on chunk size. + for i in range(rank): + origin_value = origin[i] + chunk_origin_value = chunk_origin[i] + chunk_size = chunk_shape[i] + lower = origin_value - chunk_origin_value + upper = origin_value + shape[i] - chunk_origin_value + lower_aligned = lower // chunk_size * chunk_size + upper_aligned = -(-upper // chunk_size) * chunk_size + num_bytes *= (upper_aligned - lower_aligned) + + return num_bytes + + +async def async_deserialize( + user_in_sharding: jax.sharding.Sharding | Layout, + tensorstore_spec: ts.Spec | dict[str, Any], + global_shape: Sequence[int] | None = None, + dtype=None, + byte_limiter: asyncio_utils._LimitInFlightBytes | None = None, + context=TS_CONTEXT, + chunk_layout=TS_CHUNK_LAYOUT, + assume_metadata: bool = False, +): + in_sharding = (user_in_sharding.sharding + if isinstance(user_in_sharding, Layout) else user_in_sharding) + if not isinstance(in_sharding, jax.sharding.Sharding): + raise ValueError( + 'sharding passed to deserialization should be specified, concrete and' + f' an instance of `jax.sharding.Sharding`. Got {in_sharding}') + dll = (user_in_sharding.device_local_layout + if isinstance(user_in_sharding, Layout) else None) + t = await ts.open( + tensorstore_spec, + open=True, + assume_metadata=assume_metadata, + context=context, + chunk_layout=chunk_layout, + ) + shape = t.shape if global_shape is None else global_shape + new_shard_shape = in_sharding.shard_shape(tuple(shape)) + + async def cb(index: array.Index, device: jax.Device): + requested_domain = ts.IndexTransform(input_shape=shape)[index].domain + restricted_domain = t.domain.intersect(requested_domain) + requested_bytes = estimate_read_memory_footprint(t, restricted_domain) + # Limit the bytes read for every shard. + if byte_limiter is not None: + await byte_limiter.wait_for_bytes(requested_bytes) + # This maybe needed because the shape the array was saved with is smaller + # than the requested shape of the array in which it will be reloaded. So + # the extra values will be filled with 0s. + out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype) + await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][ + restricted_domain].write(t[restricted_domain]) + if dtype is not None: + # Cast while reloading on process to avoid 2 copies on device if the + # casting is done on device. + out = out.astype(dtype) + # Convert to jnp array so that layouts are initialized properly for + # sub-byte dtypes. + # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to + # make this work. + if out.dtype == jnp.int4: + out = jnp.asarray(out) # type: ignore + result = jax.device_put( + out, Layout(dll, jax.sharding.SingleDeviceSharding(device))) + if byte_limiter is not None: + # NB: `out` actually might not be ready for garbage collection by the + # time we call release_bytes . Thus peak memory usage still might grow + # beyond what byte_limiter limit suggests it should. The simplest option + # would be to call `result.block_until_ready()`` here. However it + # also comes with ~15-20% perf penalty as we would be waiting for CPU->GPU + # transfer instead of loading data. In the future, if memory pressure + # becomes a problem, we can instead instrument bytelimiter to + # keep track of all in-flight tensors and only block_until_ready, if byte + # limiter hits the limit to get reduced memory usage, without losing + # performance in common use cases. + await byte_limiter.release_bytes(requested_bytes) + return result + + return await _create_async_array_from_callback(tuple(shape), in_sharding, cb) + + +def run_deserialization(shardings: Sequence[jax.sharding.Sharding | Layout], + tensorstore_specs: Sequence[dict[str, Any]], + global_shapes: Sequence[array.Shape] | None = None, + dtypes: Sequence[typing.DTypeLike] | None = None, + concurrent_gb: int = 32): + concurrent_bytes = concurrent_gb * 10**9 + + async def _run_deserializer(): + # Object should be created once per process. + byte_limiter = asyncio_utils._LimitInFlightBytes(concurrent_bytes) + + future_arrays = jax.tree_util.tree_map( + partial(async_deserialize, byte_limiter=byte_limiter), + shardings, tensorstore_specs, + [None] * len(tensorstore_specs) if global_shapes is None else global_shapes, + [None] * len(tensorstore_specs) if dtypes is None else dtypes) + return await asyncio.gather(*future_arrays) + return asyncio.run(_run_deserializer())