# Source code for orbax.checkpoint.experimental.v1._src.partial.saving

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines free-function interface for partial saving and finalizing."""

import ast
import asyncio
import dataclasses
import itertools
import json
from typing import Any, Awaitable, Callable
from etils import epath

from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import format_utils
from orbax.checkpoint._src.path import utils as ocp_path_utils
from orbax.checkpoint._src.path.snapshot import snapshot
from orbax.checkpoint._src.tree import structure_utils as tree_structure_utils
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.handlers import global_registration  # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler
from orbax.checkpoint.experimental.v1._src.handlers import stateful_checkpointable_handler
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.layout import orbax_layout
from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization
from orbax.checkpoint.experimental.v1._src.partial import path as partial_path_lib
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.saving import execution
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
from orbax.checkpoint.experimental.v1._src.synchronization import synchronization
from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types


# Key under which `save_async` registers the wrapped PyTree in the
# checkpointables mapping it hands to the save implementation.
STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY
# Name of the checkpoint indicator file; `_merge_all` routes entries with this
# name to `_merge_indicator_file`.
ORBAX_CHECKPOINT_INDICATOR_FILE = orbax_layout.ORBAX_CHECKPOINT_INDICATOR_FILE
# Name of the checkpoint-level metadata file; `_merge_all` routes entries with
# this name to `_merge_checkpoint_metadata`.
CHECKPOINT_METADATA_FILENAME = metadata_serialization._CHECKPOINT_METADATA_FILENAME  # pylint: disable=protected-access
# Name of the per-PyTree metadata file; its presence inside a directory is what
# `_is_pytree_dir` uses to recognize a PyTree directory.
PYTREE_METADATA_FILE = format_utils.PYTREE_METADATA_FILE


# Short local aliases for handler classes defined in other modules.
StatefulCheckpointableHandler = (
    stateful_checkpointable_handler.StatefulCheckpointableHandler
)
# Instantiated by `_merge_all` to run `finalize` on merged PyTree directories.
BasePyTreeCheckpointHandler = (
    base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler
)


@dataclasses.dataclass
class _PartialSavePyTree(handler_types.StatefulCheckpointable):
  """Checkpointable wrapper that routes a PyTree through partial-save mode.

  Attributes:
    state: The PyTree holding the updates to be written.
  """

  state: tree_types.PyTree

  def __post_init__(self):
    # `handler` is a plain instance attribute rather than a dataclass field; it
    # is created for every wrapper with partial-save mode switched on.
    self.handler = pytree_handler.PyTreeHandler(partial_save_mode=True)

  async def save(
      self, directory: path_types.PathAwaitingCreation
  ) -> Awaitable[None]:
    """Forwards `state` to the partial-mode handler and returns its result."""
    handler_result = await self.handler.save(directory, self.state)
    return handler_result

  async def load(self, directory: path_types.Path) -> Awaitable[None]:
    """Unsupported: this wrapper exists only for saving."""
    raise NotImplementedError('Partial load is not supported via this wrapper.')

def save(
    path: path_types.PathLike,
    state: tree_types.PyTreeOf[tree_types.Leaf],
    *,
    custom_metadata: tree_types.JsonType | None = None,
):
  """Partially saves a PyTree.

  This function allows for incrementally updating a checkpoint. It is designed
  to be called multiple times. The first call initiates a new partial save
  "session" in a temporary location. Subsequent calls will update this session
  by modifying the checkpoint in place. The operation is atomic; if it is
  interrupted, the previous version of the partial save will be preserved.

  This is the blocking counterpart of :py:func:`.save_async`: it schedules the
  same operation and waits for it to complete before returning.

  IMPORTANT: The checkpoint is not finalized at the target `path` until
  :py:func:`.finalize` is called. The intermediate checkpoints are temporary
  and should not be used directly.

  ### Workflow

  A typical partial save workflow involves one or more calls to
  :py:func:`.save` followed by a single call to :py:func:`~.finalize`::

    path = '/path/to/my/checkpoint'

    # The first call creates a temporary directory:
    # '/path/to/my/checkpoint.partial_save'
    # Note: the exact temporary directory name is an implementation detail that
    # depends on the file system and should not be relied on.
    ocp.partial.save(path, {'layer1': ..., 'step': 1})

    # A subsequent call reads the previous version and applies new updates
    # to the temporary directory:
    # '/path/to/my/checkpoint.partial_save'
    ocp.partial.save(path, {'layer2': ..., 'metrics': ...})

    # This call commits the latest version to the final destination at
    # '/path/to/my/checkpoint'.
    ocp.partial.finalize(path)

  ### Additions vs. Replacements

  The provided `state` represents a set of updates.

  - If a key in `state` (e.g., 'metrics') does not exist in the on-disk
    checkpoint, it is treated as an **addition**. In other words, the sets of
    keys of the on-disk PyTree and the provided `state` are disjoint.
  - If a key (e.g., 'step') already exists, its value is **replaced**. In other
    words, the sets of keys of the on-disk PyTree and the provided `state`
    overlap. Replacements are currently NOT supported. Please reach out to the
    Orbax team if you need this functionality.

  See :py:func:`~.v1.save` for general PyTree saving documentation.

  Args:
    path: The path to save the checkpoint to.
    state: A PyTree representing the additions to be applied to the on-disk
      checkpoint.
    custom_metadata: User-provided custom metadata. This will be merged with
      any existing custom metadata. Values from this dictionary will overwrite
      existing values if keys conflict.

  Raises:
    FileExistsError: If a finalized checkpoint already exists at `path`.
  """
  response = save_async(path, state, custom_metadata=custom_metadata)
  # Block until the scheduled partial save has finished.
  response.result()
def save_async(
    path: path_types.PathLike,
    state: tree_types.PyTreeOf[tree_types.Leaf],
    *,
    custom_metadata: tree_types.JsonType | None = None,
) -> async_types.AsyncResponse[None]:
  """Partially saves a PyTree asynchronously.

  Unlike :py:func:`.save`, this function returns an
  :py:class:`.AsyncResponse` immediately after scheduling the save operation.
  The actual writing to disk happens in a background thread. You can use
  `response.result()` to block until the operation is complete.

  This function allows for incrementally updating a checkpoint. It is designed
  to be called multiple times. The first call initiates a new partial save
  "session" in a temporary location. Subsequent calls will update this session
  by creating a new version that includes all previous changes plus the new
  ones. The operation is atomic; if it is interrupted, the previous version of
  the partial save will be preserved.

  IMPORTANT: The checkpoint is not finalized at the target `path` until
  :py:func:`.finalize` is called. The intermediate checkpoints are temporary
  and may be garbage collected in certain environments.

  ### Workflow

  A typical partial save workflow involves one or more calls to
  :py:func:`.save_async` followed by a single call to :py:func:`.finalize`::

    path = '/path/to/my/checkpoint'

    # The first call creates a temporary directory and returns immediately.
    response1 = ocp.partial.save_async(path, {'layer1': ..., 'step': 1})

    # A subsequent call also returns immediately. Orbax ensures that this
    # operation waits for the first one to complete before starting.
    response2 = ocp.partial.save_async(
        path, {'layer2': ..., 'metrics': ...}
    )

    # Wait for all async partial saves to complete before finalizing.
    response1.result()
    response2.result()

    # This call commits the latest version to the final destination at
    # '/path/to/my/checkpoint'.
    ocp.partial.finalize(path)

  ### Additions vs. Replacements

  The provided `state` represents a set of updates.

  - If a key in `state` (e.g., 'metrics') does not exist in the on-disk
    checkpoint, it is treated as an **addition**.
  - If a key (e.g., 'step') already exists, its value is **replaced**.
    Replacements are currently NOT supported. Please reach out to the Orbax
    team if you need this functionality.

  See :py:func:`~.v1.save_async` for general PyTree saving documentation.

  Args:
    path: The path to save the checkpoint to.
    state: The PyTree to save. This may be any JAX PyTree consisting of
      supported leaf types (see :py:class:`~.v1.tree.Leaf`). Default supported
      leaf types include `jax.Array`, `np.ndarray`, simple types like `int`,
      `float`, `str`, and empty nodes.
    custom_metadata: User-provided custom metadata. An arbitrary
      JSON-serializable dictionary the user can use to store additional
      information. The field is treated as opaque by Orbax.

  Returns:
    An :py:class:`.AsyncResponse` that can be used to block until the save is
    complete. Blocking can be done using `response.result()`, which returns
    `None`.

  Raises:
    FileExistsError: If a finalized checkpoint already exists at `path`. To
      overwrite, it must be deleted first.
  """
  path = context_lib.get_context().file_options.path_class(path)
  # A finalized checkpoint at the target is never modified by a partial save.
  if path.exists():
    raise FileExistsError(f'Finalized checkpoint already exists at {path}.')

  # All writes go to the suffixed partial-save location, not to `path` itself.
  partial_save_path = partial_path_lib.add_partial_save_suffix(path)
  checkpointables = {STATE_CHECKPOINTABLE_KEY: _PartialSavePyTree(state)}
  return execution.save_checkpointables_impl(
      partial_save_path,
      checkpointables,
      overwrite=False,
      custom_metadata=custom_metadata,
      async_origin=True,
      partial_save=True,
  )
async def _read_first_metadata(
    pending_dirs: list[epath.Path],
) -> tree_metadata.InternalTreeMetadata | None:
  """Reads PyTree metadata from the first pending directory.

  Only `pending_dirs[0]` is inspected. Its sub-directories are scanned and the
  first one that contains a `PYTREE_METADATA_FILE` is parsed.

  Args:
    pending_dirs: Pending partial-save directories, in merge order.

  Returns:
    The parsed metadata, or None if `pending_dirs` is empty or no sub-directory
    of the first pending directory contains a metadata file.

  Raises:
    ValueError: If the metadata file exists but is not valid JSON.
  """
  if not pending_dirs:
    return None
  for item in await async_path.iterdir(pending_dirs[0]):
    if not await async_path.is_dir(item):
      continue
    first_meta_path = item / PYTREE_METADATA_FILE
    if await async_path.exists(first_meta_path):
      try:
        return tree_metadata.InternalTreeMetadata.from_json(
            json.loads(await async_path.read_text(first_meta_path)),
            pytree_metadata_options=tree_metadata.PYTREE_METADATA_OPTIONS,
        )
      except json.JSONDecodeError as e:
        raise ValueError(
            'Failed to read metadata from first metadata file'
            f' {first_meta_path}.'
        ) from e
  return None


def _is_prefix(t1: tuple[str, ...], t2: tuple[str, ...]) -> bool:
  """Returns True if `t1` is a strict (shorter) prefix of `t2`."""
  return len(t1) < len(t2) and t2[: len(t1)] == t1


def _filter_conflicting_keys(d: dict[str, Any]) -> dict[str, Any]:
  """Filters metadata keys that conflict due to parent-child relationships.

  When merging metadata from multiple partial saves, we might encounter
  conflicting entries. For example, one partial save might save 'a/b' as a
  leaf, while another saves 'a/b/c' as a leaf. This is a conflict because
  'a/b' cannot be both a leaf and an intermediate node containing 'c'.

  This function resolves the conflict by removing metadata for 'a/b', keeping
  'a/b/c', and implicitly treating 'a/b' as an intermediate node.

  Keys that are Python literals evaluating to tuples are compared element-wise.
  Any other pair of keys is compared as raw strings, where a key is a parent of
  another if the other starts with the key followed by '.' or '/'.

  Args:
    d: A dictionary of metadata. It is modified in place.

  Returns:
    The same dictionary `d`, with conflicting parent keys removed.
  """
  keys = list(d.keys())
  to_remove = set()

  parsed_keys = {}
  for k in keys:
    try:
      parsed_keys[k] = ast.literal_eval(k)
    except (
        ValueError,
        TypeError,
        SyntaxError,
        MemoryError,
        RecursionError,
    ):
      # `ast.literal_eval` is documented to raise any of these on input that
      # is not a well-formed literal (e.g. TypeError for "{[]: 1}"). Such keys
      # are simply not tuple keypaths, so fall back to the raw string.
      parsed_keys[k] = k

  for k1, k2 in itertools.permutations(keys, 2):
    t1, t2 = parsed_keys[k1], parsed_keys[k2]
    if isinstance(t1, tuple) and isinstance(t2, tuple):
      if _is_prefix(t1, t2):
        to_remove.add(k1)
    elif isinstance(k1, str) and isinstance(k2, str):
      if k2.startswith((k1 + '.', k1 + '/')):
        to_remove.add(k1)

  for k in to_remove:
    del d[k]
  return d


async def _rename_or_merge_json(
    src: epath.Path, dst: epath.Path, merge_fn: Callable[[Any, Any], Any]
):
  """Tries to rename src to dst, otherwise merges them as JSONs using merge_fn.

  Args:
    src: JSON file to move or merge. It no longer exists when this returns.
    dst: Destination JSON file.
    merge_fn: Called as `merge_fn(src_json, dst_json)` when the rename raises
      FileExistsError; its return value is serialized and written to `dst`.
  """
  try:
    await async_path.rename(src, dst)
  except FileExistsError:
    pass
  else:
    return

  src_meta = json.loads(await async_path.read_text(src))
  dst_meta = json.loads(await async_path.read_text(dst))
  merged_meta = merge_fn(src_meta, dst_meta)
  await async_path.write_text(dst, json.dumps(merged_meta))
  await async_path.unlink(src)


async def _merge_pytree_metadata(src_item: epath.Path, dst_item: epath.Path):
  """Merges PyTree metadata files (_METADATA or _sharding)."""

  def _merge_fn(src_meta, dst_meta):
    # NOTE(review): with `overwrite=True`, values from `src_meta` presumably
    # win over those already in `dst_meta` -- confirm against `merge_trees`.
    merged = tree_structure_utils.merge_trees(
        dst_meta, src_meta, overwrite=True
    )
    if 'tree_metadata' in merged:
      # Drop leaf entries that became intermediate nodes after the merge.
      merged['tree_metadata'] = _filter_conflicting_keys(
          merged['tree_metadata']
      )
    return merged

  await _rename_or_merge_json(src_item, dst_item, _merge_fn)


async def _rename_ocdbt_process_dir(
    item: epath.Path, pytree_dst: epath.Path, uuid_str: str
):
  """Renames an ocdbt.process_ directory to avoid collisions across partial saves.

  Args:
    item: The `ocdbt.process_*` directory inside a pending directory.
    pytree_dst: Destination PyTree directory the renamed directory moves into.
    uuid_str: UUID of the pending directory `item` belongs to.
  """
  # To avoid collisions across different partial save pending directories,
  # we append the pending dir's UUID to the original process ID.
  # We must avoid using '_' in the new ID because `ocdbt_utils.py` splits
  # the directory name by '_' to extract the process ID.
  new_name = f'{item.name}{uuid_str.replace("-", "")}'
  await async_path.rename(item, pytree_dst / new_name)


async def _merge_array_metadatas(src_dir: epath.Path, dst_dir: epath.Path):
  """Merges array_metadatas JSON files (process_0, process_1, etc.).

  Each file in `src_dir` is moved into `dst_dir`. If a file with the same name
  already exists there, the source 'array_metadatas' list is appended to the
  destination one.

  Args:
    src_dir: Source `array_metadatas` directory.
    dst_dir: Destination `array_metadatas` directory; created if missing.
  """
  await async_path.mkdir(dst_dir, parents=True, exist_ok=True)

  async def _process_child(src_child: epath.Path):
    dst_child = dst_dir / src_child.name

    def _merge_fn(src_meta, dst_meta):
      src_arr_meta = src_meta.get('array_metadatas', [])
      dst_arr_meta = dst_meta.get('array_metadatas', [])
      dst_arr_meta.extend(src_arr_meta)
      dst_meta['array_metadatas'] = dst_arr_meta
      return dst_meta

    await _rename_or_merge_json(src_child, dst_child, _merge_fn)

  await asyncio.gather(*[
      _process_child(src_child)
      for src_child in await async_path.iterdir(src_dir)
  ])


async def _recursive_merge(src: epath.Path, dst: epath.Path):
  """Recursively merges src into dst.

  A rename is attempted first. If it raises FileExistsError and `src` is a
  directory, its children are merged one by one and `src` is then removed.

  Args:
    src: Source file or directory. Nothing happens if it does not exist.
    dst: Destination path.

  Raises:
    FileExistsError: If `src` is not a directory and the rename reported that
      `dst` already exists; destination files are never overwritten.
  """
  if not await async_path.exists(src):
    return

  try:
    await async_path.rename(src, dst)
  except FileExistsError:
    pass
  else:
    return

  if await async_path.is_dir(src):
    items = await async_path.iterdir(src)
    await asyncio.gather(
        *[_recursive_merge(item, dst / item.name) for item in items]
    )
    await async_path.rmtree(src)
    return

  raise FileExistsError(
      f'File collision on {src.name} during finalize. Overwriting destination '
      'file is not allowed.'
  )


async def _merge_pytree_directory(
    pytree_src: epath.Path,
    partial_path: epath.Path,
    uuid_str: str,
):
  """Merges a single pending pytree directory into the destination.

  Args:
    pytree_src: PyTree directory inside a pending directory. It is removed once
      its contents have been merged. Nothing happens if it does not exist.
    partial_path: Partial-save root; the destination is
      `partial_path / pytree_src.name`.
    uuid_str: UUID of the pending directory, used to rename `ocdbt.process_*`
      directories.
  """
  if not await async_path.exists(pytree_src):
    return

  pytree_dst = partial_path / pytree_src.name
  await async_path.mkdir(pytree_dst, parents=True, exist_ok=True)

  async def _merge_item(item_path: epath.Path):
    # Dispatch on the entry name; anything unrecognized is merged generically.
    if item_path.name in [PYTREE_METADATA_FILE, '_sharding']:
      await _merge_pytree_metadata(item_path, pytree_dst / item_path.name)
    elif item_path.name.startswith('ocdbt.process_'):
      await _rename_ocdbt_process_dir(item_path, pytree_dst, uuid_str)
    elif item_path.name == 'array_metadatas':
      await _merge_array_metadatas(item_path, pytree_dst / item_path.name)
    else:
      await _recursive_merge(item_path, pytree_dst / item_path.name)

  await asyncio.gather(
      *[_merge_item(item) for item in await async_path.iterdir(pytree_src)]
  )
  await async_path.rmtree(pytree_src)


async def _merge_checkpoint_metadata(src: epath.Path, dst: epath.Path):
  """Merges checkpoint metadata."""

  def _merge_fn(src_meta, dst_meta):
    return tree_structure_utils.merge_trees(dst_meta, src_meta, overwrite=True)

  await _rename_or_merge_json(src, dst, _merge_fn)


async def _merge_indicator_file(src: epath.Path, dst: epath.Path):
  """Merges the Orbax checkpoint indicator file.

  The first indicator file is moved into place; any later one is discarded.
  """
  try:
    await async_path.rename(src, dst)
  except FileExistsError:
    await async_path.unlink(src)


async def _is_pytree_dir(item: epath.Path) -> bool:
  """Returns True if the item is a PyTree directory."""
  return await async_path.is_dir(item) and await async_path.exists(
      item / PYTREE_METADATA_FILE
  )


async def _merge_all(partial_path: epath.Path):
  """Merges all pending directories into the partial path.

  Each partial save call results in a new pending directory containing unique
  PyTree keypaths and corresponding data. During finalization, all pending
  directories are merged to form the final checkpoint state. Pending
  directories are processed one after another and removed once merged.

  Args:
    partial_path: The partial-save directory that holds the pending
      directories and receives the merged result.
  """
  # Ensure deterministic merge order (alphabetical glob + timestamp).
  pending_dirs = sorted(await snapshot.list_pending_dirs(partial_path))

  first_metadata = await _read_first_metadata(pending_dirs)
  use_zarr3 = first_metadata.use_zarr3 if first_metadata is not None else False

  # Insertion-ordered set of PyTree directory names. The same name normally
  # shows up in several pending directories, but the merged destination only
  # needs to be finalized once.
  pytree_directories: dict[str, None] = {}

  async def _process_item(item: epath.Path, uuid_str: str):
    if item.name == CHECKPOINT_METADATA_FILENAME:
      await _merge_checkpoint_metadata(item, partial_path / item.name)
    elif item.name == ORBAX_CHECKPOINT_INDICATOR_FILE:
      await _merge_indicator_file(item, partial_path / item.name)
    elif await _is_pytree_dir(item):
      pytree_directories[item.name] = None
      await _merge_pytree_directory(item, partial_path, uuid_str)
    else:
      await _recursive_merge(item, partial_path / item.name)

  for p_dir in pending_dirs:
    uuid_str = snapshot.get_uuid_from_pending_dir_name(p_dir.name)
    await asyncio.gather(*[
        _process_item(item, uuid_str)
        for item in await async_path.iterdir(p_dir)
    ])
    await async_path.rmtree(p_dir)

  # Call the PyTree handler's finalize to perform the OCDBT merge.
  # This merges the individual ocdbt.process_xxx directories into a single
  # valid manifest for the final partial state.
  handler = BasePyTreeCheckpointHandler(use_zarr3=use_zarr3)
  for pytree_dir_name in pytree_directories:
    await asyncio.to_thread(handler.finalize, partial_path / pytree_dir_name)
def finalize(path: path_types.PathLike) -> None:
  """Finalizes a partially-saved checkpoint, making it permanent and readable.

  This function commits all changes made during a partial save session,
  concluding the transaction. It should be called once after all desired
  :py:func:`.save` operations are complete.

  The finalization process is atomic. It renames the temporary, versioned
  partial save directory to the final target `path`, making the updated
  checkpoint "live".

  IMPORTANT: Until `finalize` is called, the checkpoint at the target `path` is
  not created or modified. All changes are buffered in a temporary location.
  This function is what makes those changes permanent.

  The merge and rename are performed by the primary host only. The outcome is
  broadcast to all hosts, so every host either returns normally or raises.

  ### Example::

    path = '/path/to/my/checkpoint'

    # These calls write to a temporary, versioned directory, not the final path.
    ocp.partial.save(path, {'step': 1})
    ocp.partial.save_checkpointables(path, {'metrics': ...})

    # This call performs the atomic rename, making the checkpoint available at
    # '/path/to/my/checkpoint'.
    ocp.partial.finalize(path)

  Args:
    path: The final, target path of the checkpoint to be finalized. This should
      be the same path that was passed to :py:func:`~.save` calls. A path that
      already carries the partial-save suffix is accepted as well.

  Raises:
    FileExistsError: If a finalized checkpoint already exists at `path`. To
      overwrite, it must be deleted first.
    FileNotFoundError: If no partial save session is found for the given
      `path`. This can happen if :py:func:`.save` was not called first.
    OSError: On hosts other than the primary host, if merging or renaming
      failed on the primary host. The primary host re-raises the original
      error instead.
  """
  context = context_lib.get_context()
  path = context.file_options.path_class(path)

  if partial_path_lib.is_partial_save_path(path):
    final_path = partial_path_lib.remove_partial_save_suffix(path)
    partial_path = path
  else:
    final_path = path
    partial_path = partial_path_lib.add_partial_save_suffix(path)

  async def _barrier(name: str) -> None:
    """Synchronizes all active processes on a uniquely-keyed barrier."""
    await multihost.sync_global_processes(
        multihost.unique_barrier_key(
            name,
            prefix=context.multiprocessing_options.barrier_sync_key_prefix,
        ),
        operation_id=synchronization.get_operation_id(),
        processes=context.multiprocessing_options.active_processes,
    )

  async def _finalize_impl():
    await _barrier('OcpPartialSaving:finalize_path_existence_start')

    if await async_path.exists(final_path):
      raise FileExistsError(
          f'Finalized checkpoint already exists at {final_path}.'
      )
    elif not await async_path.exists(partial_path):
      raise FileNotFoundError(
          f'Partial save path {partial_path} does not exist.'
      )

    # NOTE(review): presumably this barrier keeps the primary host from
    # renaming the directory while other hosts are still running the existence
    # checks above -- confirm.
    await _barrier('OcpPartialSaving:finalize_path_rename_start')

    is_primary = multihost.is_primary_host(
        context.multiprocessing_options.primary_host
    )
    finalize_failed = False
    finalize_error = None
    if is_primary:
      try:
        await _merge_all(partial_path)
        await async_path.rename(partial_path, final_path)
      except Exception as e:  # pylint: disable=broad-exception-caught
        # Capture *every* failure rather than only ValueError/OSError: an
        # exception escaping here would skip the broadcast and barrier below
        # and leave all other hosts blocked waiting for this one. The original
        # error is re-raised unchanged once every host has been notified.
        finalize_failed = True
        finalize_error = e

    finalize_failed = multihost.broadcast_one_to_all(
        finalize_failed,
        is_source=is_primary,
    )

    await _barrier('OcpPartialSaving:finalize_rename_complete')

    if finalize_failed:
      # Only the primary host holds the original exception.
      raise finalize_error or OSError('Partial checkpoint finalization failed.')

  asyncio_utils.run_sync(_finalize_impl())