# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of :py:class:`~.v1.handlers.CheckpointableHandler` for PyTrees."""
from __future__ import annotations
import asyncio
import dataclasses
import time
import typing
from typing import Any, Awaitable, Sequence, get_args
from absl import logging
import jax
import numpy as np
from orbax.checkpoint import options as v0_options_lib
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
from orbax.checkpoint._src.serialization import types as v0_serialization_types
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import compatibility
from orbax.checkpoint.experimental.v1._src.serialization import options_resolution
from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils
from orbax.checkpoint.experimental.v1._src.serialization import registry
from orbax.checkpoint.experimental.v1._src.serialization import scalar_leaf_handler
from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
# Short module-level aliases for types and constants defined elsewhere.
Path = path_types.Path
PyTree = tree_types.PyTree
CheckpointableHandler = handler_types.CheckpointableHandler
STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY

# Partial-save error types, re-exported from the V0 PyTree handler module.
PartialSaveError = base_pytree_checkpoint_handler.PartialSaveError
PartialSaveReplacementError = (
    base_pytree_checkpoint_handler.PartialSaveReplacementError
)
def _get_remaining_timeout(
    start_time: float,
    timeout_secs: float,
    error_message: str,
) -> float:
  """Computes how much of a timeout budget is left.

  Args:
    start_time: Wall-clock timestamp (from `time.time()`) at which the budget
      started.
    timeout_secs: Total budget in seconds.
    error_message: Message for the `TimeoutError` raised once the budget is
      spent.

  Returns:
    The remaining budget in seconds; always strictly positive.

  Raises:
    TimeoutError: If no budget remains (remaining time is zero or negative).
  """
  elapsed = time.time() - start_time
  remaining = timeout_secs - elapsed
  if remaining <= 0:
    raise TimeoutError(error_message)
  return remaining
def _get_v0_save_args(
    checkpointable: PyTree,
    array_saving_options: options_lib.ArrayOptions.Saving,
) -> PyTree:
  """Builds a PyTree of V0 `SaveArgs` mirroring `checkpointable`.

  Storage options are resolved per leaf (using the leaf's key path) via
  `options_resolution.resolve_storage_options`, then translated into the
  V0 `type_handlers.SaveArgs` structure.

  Args:
    checkpointable: The PyTree being saved.
    array_saving_options: V1 array saving options used for resolution.

  Returns:
    A PyTree with the same structure whose leaves are V0 `SaveArgs`.
  """

  def _to_v0_save_args(keypath, leaf):
    storage = options_resolution.resolve_storage_options(
        keypath, leaf, array_saving_options
    )
    # V0 expects a concrete `np.dtype` (or None to keep the leaf's dtype).
    dtype = None
    if storage.dtype is not None:
      dtype = np.dtype(storage.dtype)
    return type_handlers_v0.SaveArgs(
        dtype=dtype,
        chunk_byte_size=storage.chunk_byte_size,
        shard_axes=storage.shard_axes,
    )

  return jax.tree.map_with_path(_to_v0_save_args, checkpointable)
def _create_v0_handler(
    context: context_lib.Context,
    *,
    type_handler_registry: v0_serialization_types.TypeHandlerRegistry,
    array_metadata_validator: array_metadata_store_lib.Validator = array_metadata_store_lib.Validator(),
) -> base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler:
  """Builds the V0 `BasePyTreeCheckpointHandler` that backs a V1 handler.

  Every knob is copied out of the V1 `context` (memory, array-saving,
  multiprocessing and PyTree options).

  Args:
    context: The V1 context supplying configuration.
    type_handler_registry: V0 type-handler registry to use.
    array_metadata_validator: Validator forwarded to the V0 handler.

  Returns:
    A configured `BasePyTreeCheckpointHandler`.
  """
  memory = context.memory_options
  saving = context.array_options.saving
  multiprocessing = context.multiprocessing_options

  v0_multiprocessing = v0_options_lib.MultiprocessingOptions(
      primary_host=multiprocessing.primary_host,
      active_processes=multiprocessing.active_processes,
      barrier_sync_key_prefix=multiprocessing.barrier_sync_key_prefix,
  )
  return base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler(
      save_concurrent_bytes=memory.write_concurrent_bytes,
      restore_concurrent_bytes=memory.read_concurrent_bytes,
      save_device_host_concurrent_bytes=memory.transfer_concurrent_bytes,
      use_ocdbt=saving.use_ocdbt,
      use_zarr3=saving.use_zarr3,
      use_compression=saving.use_compression,
      multiprocessing_options=v0_multiprocessing,
      type_handler_registry=type_handler_registry,
      enable_post_merge_validation=saving.enable_post_merge_validation,
      pytree_metadata_options=context.pytree_options.saving.pytree_metadata_options,
      array_metadata_validator=array_metadata_validator,
      enable_pinned_host_transfer=saving.enable_pinned_host_transfer,
      is_prioritized_key_fn=memory.is_prioritized_key_fn,
  )
def create_v0_save_args(
    context: context_lib.Context,
    checkpointable: PyTree,
) -> base_pytree_checkpoint_handler.BasePyTreeSaveArgs:
  """Translates a V1 context and PyTree into V0 `BasePyTreeSaveArgs`.

  Args:
    context: The V1 context; its array saving options drive per-leaf save args
      and the OCDBT target data file size.
    checkpointable: The PyTree to save.

  Returns:
    The V0 save args for `BasePyTreeCheckpointHandler`.
  """
  saving_options = context.array_options.saving
  per_leaf_save_args = _get_v0_save_args(checkpointable, saving_options)
  return base_pytree_checkpoint_handler.BasePyTreeSaveArgs(
      item=checkpointable,
      save_args=per_leaf_save_args,
      ocdbt_target_data_file_size=saving_options.ocdbt_target_data_file_size,
  )
def _restore_type_by_abstract_type(
    abstract_checkpointable: Any,
) -> Any:
  """Maps a user-supplied abstract leaf to the V0 `restore_type`.

  This lets users override the type a leaf is restored as. The returned value
  is used as the `restore_type` of `V0RestoreArgs`, which determines which
  `LeafHandler` eventually handles the leaf. Only the standard types supported
  by Orbax are mapped; anything else falls back to the handler derived from
  the checkpoint metadata.

  `abstract_checkpointable` may be either an instance or a type. Mapping:

  - `None` -> `None` (restore using checkpoint metadata).
  - A placeholder -> `serialization_types.PLACEHOLDER`.
  - Anything satisfying the `AbstractShardedArray` protocol (e.g.
    `jax.ShapeDtypeStruct`) -> `jax.Array`.
  - Anything satisfying `AbstractArray` but not `AbstractShardedArray` (e.g.
    a NumPy-style abstract array) -> `np.ndarray`.
  - A scalar type or scalar value whose type is a subclass of one of the
    types in `scalar_leaf_handler.Scalar` (e.g. `int`, `float`, `1`, `2.0`)
    -> that type.
  - Anything else -> `None` (restore using checkpoint metadata).

  Args:
    abstract_checkpointable: The abstract leaf passed in by the user.

  Returns:
    The `restore_type` parameter for `V0RestoreArgs`.
  """
  if abstract_checkpointable is None:
    ret = None
  elif serialization_types.is_placeholder(abstract_checkpointable):
    ret = serialization_types.PLACEHOLDER
  else:
    if isinstance(abstract_checkpointable, type):
      abstract_type = abstract_checkpointable
    else:
      abstract_type = type(abstract_checkpointable)
    # Check AbstractShardedArray before AbstractArray: a sharded abstract array
    # also satisfies the weaker AbstractArray protocol, so testing that first
    # would wrongly map it to np.ndarray.
    if protocol_utils.is_subclass_protocol(
        abstract_type, serialization_types.AbstractShardedArray
    ):
      ret = jax.Array
    elif protocol_utils.is_subclass_protocol(
        abstract_type, serialization_types.AbstractArray
    ):
      ret = np.ndarray
    elif issubclass(abstract_type, get_args(scalar_leaf_handler.Scalar)):
      ret = abstract_type
    else:
      # Fall back to the registered handler derived from checkpoint metadata.
      ret = None
  logging.vlog(
      1,
      'abstract_checkpointable: %s, restore_type: %s',
      abstract_checkpointable,
      ret,
  )
  return ret
def create_v0_restore_args(
    context: context_lib.Context,
    abstract_checkpointable: PyTree | None,
) -> base_pytree_checkpoint_handler.BasePyTreeRestoreArgs:
  """Translates a V1 context and abstract PyTree into V0 restore args.

  When `abstract_checkpointable` is truthy, each of its leaves is wrapped in a
  `compatibility.V0RestoreArgs` carrying the leaf itself and the restore type
  chosen by `_restore_type_by_abstract_type`. Otherwise (None or an empty
  container) no per-leaf restore args are produced.

  Args:
    context: The V1 context; its PyTree loading options control
      `partial_restore`.
    abstract_checkpointable: Optional abstract PyTree describing the target.

  Returns:
    The V0 `BasePyTreeRestoreArgs`.
  """
  restore_args = None
  if abstract_checkpointable:

    def _to_v0_restore_args(abstract_leaf):
      return compatibility.V0RestoreArgs(
          restore_type=_restore_type_by_abstract_type(abstract_leaf),
          abstract_leaf=abstract_leaf,
      )

    restore_args = jax.tree.map(_to_v0_restore_args, abstract_checkpointable)

  logging.vlog(1, 'restore_args: %s', restore_args)
  partial_restore = context.pytree_options.loading.partial_load
  return base_pytree_checkpoint_handler.BasePyTreeRestoreArgs(
      item=abstract_checkpointable,
      restore_args=restore_args,
      partial_restore=partial_restore,
  )
async def _async_futures(
    commit_futures: Sequence[future.Future],
    timeout_secs: float | None = None,
    start_time: float | None = None,
):
  """Awaits every commit future, optionally within an overall deadline.

  Each future's blocking `result()` call runs in a worker thread via
  `asyncio.to_thread`, and all of them are gathered concurrently. The deadline
  only applies when both `timeout_secs` and `start_time` are given.

  Args:
    commit_futures: Futures whose completion marks the commit as done.
    timeout_secs: Optional total budget in seconds, measured from
      `start_time`.
    start_time: Optional wall-clock (`time.time()`) start of the budget.

  Raises:
    TimeoutError: If the deadline has already passed when a future is about to
      be waited on. Errors raised by a future's `result()` also propagate.
  """
  if timeout_secs is None or start_time is None:
    deadline = None
  else:
    deadline = start_time + timeout_secs

  def _block_on(commit_future: future.Future):
    if deadline is None:
      return commit_future.result()
    remaining = deadline - time.time()
    if remaining <= 0:
      raise TimeoutError('Overall save timeout exceeded.')
    return commit_future.result(timeout=remaining)

  waiters = [asyncio.to_thread(_block_on, f) for f in commit_futures]
  await asyncio.gather(*waiters)
@typing.final
class PyTreeHandler(CheckpointableHandler[PyTree, PyTree]):
  """An implementation of :py:class:`.CheckpointableHandler` for PyTrees.

  PyTreeHandler manages the decomposition of JAX PyTree structures into leaf-
  level parameters for persistence. It utilizes an asynchronous two-tier
  execution model to allow for background I/O, ensuring that heavy array
  serialization does not block the main training process.

  **Note: Users should never instantiate or use this handler directly.**
  Always use the top-level APIs like `ocp.save_checkpointables` and
  `ocp.load_checkpointables`. Orbax uses this handler by default for standard
  JAX PyTrees (like nested dictionaries of arrays).

  To force Orbax to use a customized PyTreeHandler for a particular PyTree,
  the recommended approach is to use `ocp.Context` with
  `CheckpointablesOptions`. This allows you to bind the handler to a specific
  dictionary key within the Context scope. See
  :py:class:`~orbax.checkpoint.options.CheckpointablesOptions` for more
  details on handler registration.

  Usage Example:
    Save a state dictionary configuration::

      import orbax.checkpoint as ocp

      state_pytree = {'weights': [1.0, 2.0], 'bias': 0.0}
      registry = ocp.handlers.local_registry()
      registry.add(
          ocp.handlers.PyTreeHandler, checkpointable_name='model_state'
      )
      ctx = ocp.Context()
      ctx.checkpointables.registry = registry
      with ctx:
        ocp.save_checkpointables(path, dict(model_state=state_pytree))
  """

  def __init__(
      self,
      *,
      context: context_lib.Context | None = None,
      array_metadata_validator: array_metadata_store_lib.Validator = (
          array_metadata_store_lib.Validator()
      ),
      leaf_handler_registry: (
          serialization_types.LeafHandlerRegistry | None
      ) = None,
      partial_save_mode: bool = False,
  ):
    """Initializes the handler.

    Args:
      context: Optional V1 Context providing configuration for serialization,
        array options, and multiprocessing coordination. Defaults to
        `context_lib.get_context()`.
      array_metadata_validator: Validator forwarded to the underlying V0
        `BasePyTreeCheckpointHandler` for array metadata validation.
      leaf_handler_registry: Registry deciding which leaf types can be
        saved/loaded. Defaults to `registry.StandardLeafHandlerRegistry()`.
      partial_save_mode: If True, saves are marked as partial and the
        checkpoint is left un-finalized so it can be merged later.
    """
    context = context if context is not None else context_lib.get_context()
    self._context = context
    self._multiprocessing_options = context.multiprocessing_options
    self._partial_save_mode = partial_save_mode
    # Use an explicit None check (like `context` above) so a user-supplied
    # registry is never discarded merely because it evaluates as falsy.
    self._leaf_handler_registry = (
        leaf_handler_registry
        if leaf_handler_registry is not None
        else registry.StandardLeafHandlerRegistry()
    )
    type_handler_registry = compatibility.get_v0_type_handler_registry(
        self._leaf_handler_registry, self._context
    )
    self._handler_impl = _create_v0_handler(
        context,
        type_handler_registry=type_handler_registry,
        array_metadata_validator=array_metadata_validator,
    )

  async def _finalize(self, directory: path_types.Path):
    """Finalizes the checkpoint on the primary host (skipped for partial saves)."""
    # Keep non-finalized checkpoint state during partial saves to be merged
    # later during partial save finalization.
    if self._partial_save_mode:
      return
    if multihost.is_primary_host(self._multiprocessing_options.primary_host):
      await self._handler_impl._finalize_async(directory)  # pylint: disable=protected-access

  @staticmethod
  async def _sync_processes(
      barrier_name: str,
      *,
      operation_id: str,
      processes: Any,
      start_time: float,
      timeout_secs: float | None,
      timeout_message: str,
  ):
    """Runs a global barrier, bounded by the remaining save budget if any.

    Args:
      barrier_name: Name of the barrier.
      operation_id: Operation ID passed through to the barrier.
      processes: Participating process indices.
      start_time: Wall-clock start of the overall save budget.
      timeout_secs: Overall save budget in seconds, or None for no timeout.
      timeout_message: Message for the `TimeoutError` raised if the budget is
        already spent before the barrier starts.

    Raises:
      TimeoutError: If the budget is exhausted before entering the barrier.
    """
    if timeout_secs is None:
      await multihost.sync_global_processes(
          barrier_name,
          operation_id=operation_id,
          processes=processes,
      )
      return
    remaining = _get_remaining_timeout(start_time, timeout_secs, timeout_message)
    # `_get_remaining_timeout` guarantees `remaining > 0`, but a bare `int()`
    # truncates any sub-second budget to 0. Clamp to at least 1s so the barrier
    # is never handed a zero timeout while time is still left.
    await multihost.sync_global_processes(
        barrier_name,
        operation_id=operation_id,
        processes=processes,
        timeout=max(1, int(remaining)),
    )

  async def _background_save(
      self,
      directory: path_types.PathAwaitingCreation,
      *,
      commit_futures: Sequence[future.Future],
      operation_id: str,
      start_time: float,
  ):
    """Completes a save: await commits, barrier, finalize, barrier.

    The optional `async_options.timeout_secs` budget is measured from
    `start_time` and shared across all steps.
    """
    timeout_secs = self._context.async_options.timeout_secs
    directory = await directory.await_creation()
    active_processes = self._multiprocessing_options.active_processes or set(
        range(multihost.process_count())
    )
    await _async_futures(
        commit_futures, timeout_secs=timeout_secs, start_time=start_time
    )
    # Global sync to ensure all participating processes have completed their
    # save operations before proceeding to finalize.
    await self._sync_processes(
        f'save_and_finalize_{operation_id}_commit_complete',
        operation_id=operation_id,
        processes=active_processes,
        start_time=start_time,
        timeout_secs=timeout_secs,
        timeout_message='Timed out while waiting for commit to complete.',
    )
    await self._finalize(directory)
    # Global sync to ensure all hosts are aware that the finalize operation
    # has completed before returning to the user.
    await self._sync_processes(
        f'save_and_finalize_{operation_id}_finalize_complete',
        operation_id=operation_id,
        processes=active_processes,
        start_time=start_time,
        timeout_secs=timeout_secs,
        timeout_message='Timed out while waiting for finalize to complete.',
    )

  async def save(
      self, directory: path_types.PathAwaitingCreation, checkpointable: PyTree
  ) -> Awaitable[None]:
    """Starts saving `checkpointable` and returns the background completion.

    Leaf validation and the V0 `async_save` call happen before this coroutine
    returns. The returned awaitable waits for the commit futures, synchronizes
    the participating processes, and finalizes the checkpoint.

    Args:
      directory: Target directory, possibly still awaiting creation.
      checkpointable: The PyTree to save.

    Returns:
      An awaitable that completes once the save is committed and finalized.

    Raises:
      registry.UnregisteredTypeError: If a leaf type has no registered
        `LeafHandler`.
    """
    # The overall timeout budget starts before validation and serialization.
    start_time = time.time()
    self.validate_leaves_handleable(checkpointable)
    save_args = create_v0_save_args(self._context, checkpointable)
    save_args = dataclasses.replace(
        save_args, partial_save_mode=self._partial_save_mode
    )
    commit_futures = await self._handler_impl.async_save(
        directory.path,
        args=save_args,
    )
    assert commit_futures
    # TODO(b/398310070): Move operation ID generation to `Context`.
    operation_id = (
        synchronization.OperationIdGenerator.get_current_operation_id()
    )
    # Needed to differentiate between different handlers when we have multiple
    # PyTreeHandlers performing a save.
    operation_id = f'{operation_id}.{directory.path.name}'
    return self._background_save(
        directory,
        commit_futures=commit_futures,
        operation_id=operation_id,
        start_time=start_time,
    )

  async def _background_load(
      self,
      directory: path_types.Path,
      abstract_checkpointable: PyTree | None = None,
  ) -> PyTree:
    """Restores the PyTree via the V0 handler."""
    # NOTE(review): `restore` is called (not awaited) inside this coroutine, so
    # it appears to run synchronously on the event loop thread — confirm.
    return self._handler_impl.restore(
        directory,
        args=create_v0_restore_args(self._context, abstract_checkpointable),
    )

  async def load(
      self,
      directory: path_types.Path,
      abstract_checkpointable: PyTree | None = None,
  ) -> Awaitable[PyTree]:
    """Loads a PyTree from a checkpoint directory.

    Args:
      directory: The directory to load from.
      abstract_checkpointable: The abstract checkpointable to load into. If
        None, the handler will attempt to load the entire checkpoint using the
        recorded metadata. Otherwise, the `abstract_checkpointable` is expected
        to be a PyTree of abstract leaves. The abstract leaf may be a value of
        type :py:class:`~.v1.tree.AbstractLeaf`, `Type[AbstractLeaf]`, or
        `None`. Passing the latter two indicates that the metadata should be
        used to restore the leaf.

    Returns:
      An awaitable which can be awaited to complete the load operation and
      obtain a PyTree.

    Raises:
      registry.UnregisteredTypeError: If an abstract leaf type has no
        registered `LeafHandler`.
    """
    self.validate_abstract_leaves_handleable(abstract_checkpointable)
    return self._background_load(directory, abstract_checkpointable)

  def validate_leaves_handleable(self, checkpointable: PyTree):
    """Raises if any non-placeholder leaf type has no registered handler.

    Args:
      checkpointable: The PyTree whose leaves are checked.

    Raises:
      registry.UnregisteredTypeError: Listing every unhandled leaf type.
    """
    missing_leaf_types = set()
    # Iterate leaves directly instead of `jax.tree.map`, which would also
    # rebuild a throwaway tree of Nones.
    for leaf in jax.tree.leaves(checkpointable):
      if serialization_types.is_placeholder(leaf):
        continue
      leaf_type = type(leaf)
      if not self._leaf_handler_registry.is_handleable(leaf_type):
        missing_leaf_types.add(leaf_type)
    if missing_leaf_types:
      raise registry.UnregisteredTypeError(
          'The following leaf types are not registered in the'
          f' `LeafHandlerRegistry`: [{missing_leaf_types}]. Please register a'
          ' `LeafHandler` for each type in the `LeafHandlerRegistry` and'
          ' assign it into the `PyTreeOptions` in the `Context`.'
      )

  def validate_abstract_leaves_handleable(
      self, abstract_checkpointable: PyTree
  ):
    """Raises if any non-placeholder abstract leaf type has no handler.

    Abstract leaves may be instances or types; for instances, their type is
    checked.

    Args:
      abstract_checkpointable: The abstract PyTree whose leaves are checked.

    Raises:
      registry.UnregisteredTypeError: Listing every unhandled abstract type.
    """
    missing_abstract_leaf_types = set()
    for leaf in jax.tree.leaves(abstract_checkpointable):
      if serialization_types.is_placeholder(leaf):
        continue
      leaf_type = leaf if isinstance(leaf, type) else type(leaf)
      if not self._leaf_handler_registry.is_abstract_handleable(leaf_type):
        missing_abstract_leaf_types.add(leaf_type)
    if missing_abstract_leaf_types:
      raise registry.UnregisteredTypeError(
          'The following abstract leaf types are not registered in the'
          f' `LeafHandlerRegistry`: [{missing_abstract_leaf_types}]. Please'
          ' register a `LeafHandler` for each type in the'
          ' `LeafHandlerRegistry` and assign it into the `PyTreeOptions` in'
          ' the `Context`.'
      )

  def is_handleable(self, checkpointable: Any) -> bool:
    """Returns whether `checkpointable` is a PyTree this handler accepts.

    A value whose structure is a single leaf node holding exactly one leaf
    (i.e. a bare leaf) is rejected. Anything else that `jax.tree.structure`
    accepts is treated as handleable; if it raises, returns False.
    """
    try:
      tree_structure = jax.tree.structure(checkpointable)
      return not (
          jax.tree_util.treedef_is_leaf(tree_structure)
          and tree_structure.num_leaves == 1
      )
    except Exception:  # pylint: disable=broad-exception-caught
      return False

  def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool:
    """Same check as `is_handleable`, applied to an abstract PyTree."""
    return self.is_handleable(abstract_checkpointable)