# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""StandardCheckpointHandler class."""
from __future__ import annotations
import dataclasses
import functools
import numbers
from typing import Any, Callable, List, Optional
from etils import epath
import jax
import numpy as np
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import checkpoint_utils
from orbax.checkpoint import options as options_lib
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.arrays import sharding as arrays_sharding_lib
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.metadata import value as value_metadata
from orbax.checkpoint._src.path import types as path_types
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
PyTree = Any
CheckpointArgs = checkpoint_args.CheckpointArgs
PyTreeMetadataOptions = pytree_metadata_options_lib.PyTreeMetadataOptions
register_with_handler = checkpoint_args.register_with_handler
class StandardCheckpointHandler(
    async_checkpoint_handler.DeferredPathAsyncCheckpointHandler
):
  """A CheckpointHandler implementation for any PyTree structure.

  See JAX documentation for more information on what constitutes a "PyTree".
  This handler is capable of saving and restoring PyTrees with leaves of type
  Python scalar, np.ndarray, and jax.Array.

  As with all :py:class:`.CheckpointHandler` subclasses,
  `StandardCheckpointHandler` should only be used in conjunction with a
  :py:class:`.Checkpointer` (or subclass). By itself, the `CheckpointHandler` is
  non-atomic.

  Example::

    ckptr = Checkpointer(StandardCheckpointHandler())
    # OR
    ckptr = StandardCheckpointer()

  If you find that your use case is not covered by `StandardCheckpointHandler`,
  consider using the parent class directly, or explore a custom implementation
  of `CheckpointHandler`.
  """

  def __init__(
      self,
      *,
      save_concurrent_gb: int = 96,
      restore_concurrent_gb: int = 96,
      multiprocessing_options: options_lib.MultiprocessingOptions = (
          options_lib.MultiprocessingOptions()
      ),
      pytree_metadata_options: PyTreeMetadataOptions = (
          pytree_metadata_options_lib.PYTREE_METADATA_OPTIONS
      ),
      use_ocdbt: bool = True,
  ):
    """Creates StandardCheckpointHandler.

    All options are forwarded unchanged to the wrapped
    `PyTreeCheckpointHandler`, which performs the actual I/O.

    Args:
      save_concurrent_gb: max concurrent GB that are allowed to be writing to
        disk at any given time. This limits the amount of data currently being
        written to disk, which can help to reduce the possibility of OOM's when
        large checkpoints are saved. Note that this does NOT limit
        device-to-host transfer, meaning that the limit specified here may
        still be exceeded by the total memory usage of the process.
      restore_concurrent_gb: max concurrent GB that are allowed to be restored.
        Can help to reduce the possibility of OOM's when large checkpoints are
        restored.
      multiprocessing_options: See orbax.checkpoint.options.
      pytree_metadata_options: Options to control types like tuple and
        namedtuple in pytree metadata.
      use_ocdbt: Whether to enable Tensorstore OCDBT driver.
    """
    self._supported_types = checkpoint_utils.STANDARD_ARRAY_TYPES
    self._impl = pytree_checkpoint_handler.PyTreeCheckpointHandler(
        save_concurrent_gb=save_concurrent_gb,
        restore_concurrent_gb=restore_concurrent_gb,
        multiprocessing_options=multiprocessing_options,
        pytree_metadata_options=pytree_metadata_options,
        use_ocdbt=use_ocdbt,
    )

  def _validate_save_state(
      self, item: PyTree, save_args: Optional[PyTree] = None
  ) -> PyTree:
    """Validates `item` for saving and normalizes array-like leaves.

    Args:
      item: the PyTree to save.
      save_args: optional PyTree of per-leaf save arguments (entries may be
        `None`), paired leaf-by-leaf with `item`.

    Returns:
      A PyTree with the structure of `item`. Leaves that expose `__array__`
      but are neither `np.ndarray` nor `jax.Array` are converted with
      `np.asarray`; all other leaves are passed through.

    Raises:
      ValueError: if `item` is None, if `item` itself is a `jax.Array` or a
        number, if a save arg sets `aggregate`, or if a leaf has a type outside
        the supported set.
    """
    if item is None:
      raise ValueError('Must provide item to save.')
    if isinstance(item, (jax.Array, numbers.Number)):
      raise ValueError(
          'StandardCheckpointHandler / StandardSave does not support single '
          'arrays or scalars. Use ArrayCheckpointHandler / ArraySave'
      )
    if save_args is None:
      # Pair every leaf with `None` so the checker below can always be mapped
      # over (item, save_args) together.
      save_args = jax.tree.map(lambda x: None, item)

    def _check_input(k, x, arg):
      if arg is not None:
        if arg.aggregate:
          raise ValueError(f'Unsupported option `aggregate` for key: {k}.')
      if not isinstance(x, (np.ndarray, jax.Array)) and hasattr(x, '__array__'):
        x = np.asarray(x)
      if not isinstance(x, self._supported_types):
        k = tree_utils.tuple_path_from_keypath(k)
        raise ValueError(f'Unsupported type: {type(x)} for key: {k}.')
      return x

    return jax.tree_util.tree_map_with_path(_check_input, item, save_args)

  def _validate_restore_state(self, item: PyTree) -> None:
    """Checks that every leaf of a restore target has an accepted type.

    Args:
      item: the target PyTree passed to `restore`.

    Raises:
      ValueError: if a leaf is neither one of the supported types nor a
        `jax.ShapeDtypeStruct`.
    """

    def _check_input(k, x):
      if not isinstance(x, self._supported_types) and not isinstance(
          x, jax.ShapeDtypeStruct
      ):
        k = tree_utils.tuple_path_from_keypath(k)
        raise ValueError(f'Unsupported type: {type(x)} for key: {k}.')

    # Mapped purely for the validation side effect; the result is discarded.
    jax.tree_util.tree_map_with_path(_check_input, item)

  async def async_save(
      self,
      directory: epath.Path | path_types.PathAwaitingCreation,
      item: Optional[PyTree] = None,
      save_args: Optional[PyTree] = None,
      args: Optional[StandardSaveArgs] = None,
  ) -> Optional[List[future.Future]]:
    """Saves a PyTree of array-like objects.

    See :py:class:`.PyTreeCheckpointHandler`.

    Args:
      directory: path to the directory where the checkpoint will be saved.
      item: Deprecated, use `args`.
      save_args: Deprecated, use `args`.
      args: `StandardSaveArgs` (see below). When given, its fields take
        precedence over `item` and `save_args`.

    Returns:
      A list of futures that will be completed when the save is complete.

    Raises:
      ValueError: if a `CheckpointArgs` object is passed as `item` instead of
        via `args=`, or if the item fails validation.
    """
    if isinstance(item, CheckpointArgs):
      raise ValueError(
          'Make sure to specify kwarg name `args=` when providing'
          ' `StandardSaveArgs`.'
      )
    custom_metadata = None
    if args is not None:
      item = args.item
      save_args = args.save_args
      custom_metadata = args.custom_metadata

    item = self._validate_save_state(item, save_args=save_args)
    return await self._impl.async_save(
        directory,
        args=pytree_checkpoint_handler.PyTreeSaveArgs(
            item=item,
            save_args=save_args,
            custom_metadata=custom_metadata,
        ),
    )

  def save(self, directory: epath.Path, *args, **kwargs):
    """Saves the provided item synchronously."""

    async def _save_and_wait(*args, **kwargs):
      commit_futures = await self.async_save(*args, **kwargs)  # pytype: disable=bad-return-type
      # Futures are already running, so sequential waiting is equivalent to
      # concurrent waiting.
      if commit_futures:  # May be None.
        for f in commit_futures:
          f.result()  # Block on result.

    asyncio_utils.run_sync(_save_and_wait(directory, *args, **kwargs))

  def restore(
      self,
      directory: epath.Path,
      item: Optional[PyTree] = None,
      args: Optional[StandardRestoreArgs] = None,
  ) -> PyTree:
    """Restores a PyTree. See :py:class:`.PyTreeCheckpointHandler`.

    Example::

      ckptr = StandardCheckpointer()
      item = {
          'layer0': {
              'w': jax.Array(...),
              'b': np.ndarray(...),
          },
      }
      ckptr.save(dir, StandardSaveArgs(item))

      target = {
          'layer0': {
              'w': jax.ShapeDtypeStruct(...),
              'b': jax.Array(...),
          },
      }
      ckptr.restore(dir, StandardRestoreArgs(target))

    Args:
      directory: path from which to restore.
      item: Deprecated, use `args`. Only consulted when `args` is None.
      args: `StandardRestoreArgs` (see below).

    Returns:
      a restored PyTree.

    Raises:
      ValueError: if a `CheckpointArgs` object is passed as `item` instead of
        via `args=`, or if the target contains an unsupported leaf type.
    """
    if isinstance(item, CheckpointArgs):
      raise ValueError(
          'Make sure to specify kwarg name `args=` when providing'
          ' `StandardRestoreArgs`.'
      )
    if args is None:
      args = StandardRestoreArgs(item=item)
    if args.item is not None:
      self._validate_restore_state(args.item)

    # The metadata callable is passed lazily so the checkpoint metadata is only
    # read if the target does not already carry everything that is needed.
    restore_args = _construct_restore_args(
        args.item,
        functools.partial(self.metadata, directory),
        args.fallback_sharding,
        args.support_layout,
    )

    def _replace_strict(
        arg: pytree_checkpoint_handler.RestoreArgs,
    ) -> pytree_checkpoint_handler.RestoreArgs:
      # Only some restore-arg types carry a `strict` field.
      if hasattr(arg, 'strict'):
        return dataclasses.replace(arg, strict=False)
      return arg

    if not args.strict:
      restore_args = jax.tree.map(_replace_strict, restore_args)
    return self._impl.restore(
        directory,
        args=pytree_checkpoint_handler.PyTreeRestoreArgs(
            item=args.item, restore_args=restore_args
        ),
    )

  def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata:
    """Returns metadata about the saved item.

    Delegates to the wrapped `PyTreeCheckpointHandler`. `restore` calls this
    method to obtain the saved tree when it needs a target structure or
    shardings that the caller did not provide.

    Args:
      directory: path of the checkpoint whose metadata should be read.
    """
    return self._impl.metadata(directory)

  def finalize(self, directory: epath.Path) -> None:
    """Finalizes the checkpoint by delegating to the wrapped handler."""
    self._impl.finalize(directory)

  def close(self):
    """Closes the wrapped handler."""
    self._impl.close()
@register_with_handler(StandardCheckpointHandler, for_save=True)
@dataclasses.dataclass
class StandardSaveArgs(CheckpointArgs):
  """Arguments used to save a standard PyTree.

  Also see :py:class:`.PyTreeSave` for additional options.

  Attributes:
    item (required): the PyTree to be saved. A `TreeMetadata` object is
      rejected at construction time.
    save_args: a PyTree with the same structure as `item`, whose values are
      `ocp.SaveArgs` objects. `None` can be used for values where no
      `SaveArgs` are specified.
    custom_metadata: User-provided custom metadata. An arbitrary
      JSON-serializable dictionary the user can use to store additional
      information. The field is treated as opaque by Orbax.
  """

  item: PyTree
  save_args: PyTree | None = None
  custom_metadata: tree_types.JsonType | None = None

  def __post_init__(self):
    # Metadata describes a checkpoint; it is not itself a saveable tree.
    item_is_metadata = isinstance(self.item, tree_metadata.TreeMetadata)
    if item_is_metadata:
      raise ValueError('Cannot save TreeMetadata.')
def _construct_restore_args(
    target: PyTree | None,
    metadata: Callable[[], tree_metadata.TreeMetadata],
    fallback_sharding: jax.sharding.Sharding | None,
    support_layout: bool = False,
) -> PyTree:
  """Builds restore args from a target tree plus a derived sharding tree.

  When `target` is None, the tree held by the checkpoint metadata is used as
  the target instead.

  Sharding for each leaf is resolved as follows:
    - the target leaf carries a sharding: that sharding is used;
    - the target leaf has a `sharding` attribute that yields nothing: the
      sharding recorded in the matching metadata leaf is used;
    - the recorded sharding cannot be converted to a JAX sharding (a
      `ValueError` is raised by the conversion): `fallback_sharding` is used
      if provided, otherwise a topology-mismatch error is raised.

  Args:
    target: The target tree; the result matches its structure. May be None.
    metadata: A zero-argument callable returning the checkpoint metadata. It
      is invoked at most once, and only if the metadata is actually needed.
    fallback_sharding: If provided, used in place of a saved sharding that
      fails to convert to a JAX sharding.
    support_layout: Forwarded as `support_format` when reading the sharding
      off a target leaf.

  Returns:
    The result of `checkpoint_utils.construct_restore_args` applied to the
    chosen target tree and the derived sharding tree.

  Raises:
    ValueError: if a target leaf that needs metadata has no entry at its path
      in the metadata tree, or if a saved sharding cannot be converted and no
      `fallback_sharding` was given.
  """

  @functools.lru_cache(maxsize=1)
  def _checkpoint_metadata():
    return metadata()

  @functools.lru_cache(maxsize=1)
  def _metadata_by_path():
    return tree_utils.to_flat_dict(_checkpoint_metadata().tree)

  def _sharding_from_saved(leaf_meta):
    has_saved_sharding = (
        isinstance(leaf_meta, value_metadata.ArrayMetadata)
        and leaf_meta.sharding is not None
    )
    if not has_saved_sharding:
      return None
    try:
      return leaf_meta.sharding.to_jax_sharding()
    except ValueError as e:
      if fallback_sharding is None:
        raise ValueError(
            'Topology mismatch detected. The checkpoint was saved with'
            ' a different topology than the current one. Please provide'
            ' a target tree with the desired topology to restore.'
        ) from e
      return fallback_sharding

  def _sharding_for_target(
      path: tuple[Any, ...], leaf: Any
  ) -> jax.sharding.Sharding | None:
    """Resolves the sharding for one leaf of the target tree."""
    if not hasattr(leaf, 'sharding'):
      return None

    # Prefer whatever the target leaf itself specifies.
    own_sharding = arrays_sharding_lib.get_sharding_or_format(
        leaf, support_format=support_layout
    )
    if own_sharding is not None:
      return own_sharding

    # Otherwise consult the metadata by path. A path-based lookup is used
    # because mapping target and metadata trees together fails when the target
    # contains unregistered custom PyTree nodes (like Flax dataclasses) that
    # don't structurally match the dict-based metadata tree.
    leaf_key = tree_utils.tuple_path_from_keypath(path)
    leaf_meta = _metadata_by_path().get(leaf_key)
    if leaf_meta is None:
      raise ValueError(
          f'Structure mismatch: Target path {leaf_key} not found in'
          ' metadata_tree. Expected identical structures. Metadata keys:'
          f' {_metadata_by_path().keys()}'
      )
    return _sharding_from_saved(leaf_meta)

  if target is None:
    saved = _checkpoint_metadata()
    target_tree = saved.tree
    sharding_tree = jax.tree_util.tree_map(_sharding_from_saved, saved.tree)
  else:
    target_tree = target
    sharding_tree = jax.tree_util.tree_map_with_path(
        _sharding_for_target,
        target_tree,
    )
  return checkpoint_utils.construct_restore_args(target_tree, sharding_tree)
@register_with_handler(StandardCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class StandardRestoreArgs(CheckpointArgs):
  """Arguments used to restore a standard PyTree.

  Also see :py:class:`.PyTreeRestore` for additional options.

  Attributes (all optional):
    item: target PyTree; defaults to None. Values may be either real array or
      scalar values, or they may be jax.ShapeDtypeStruct, or
      `ocp.metadata.value.Metadata` objects (which come from calling the
      `metadata` method). If real values are provided, that value will be
      restored as the given type, with the given properties. If
      jax.ShapeDtypeStruct is provided, the value will be restored as
      np.ndarray, unless `sharding` is specified. If `item` is a custom PyTree
      class, the tree will be restored with the same structure as provided.
      If not provided, restores as a serialized nested dict representation of
      the custom class. `TreeMetadata` is also allowed as the tree used to
      define the restored structure; it is replaced by its `.tree` on
      construction.
    strict: if False, restoration allows silent truncating/padding of arrays
      if the stored array shape does not match the target shape. Otherwise,
      raises an error.
    support_layout: if True, restores with the layouts in `item`.
    fallback_sharding: If provided, this sharding will be used as a fallback
      if the saved sharding fails to load from the checkpoint.
  """

  item: PyTree | None = None
  strict: bool = True
  support_layout: bool = False
  fallback_sharding: jax.sharding.Sharding | None = None

  def __post_init__(self):
    # Accept a TreeMetadata wrapper but store only the tree it carries.
    given = self.item
    if isinstance(given, tree_metadata.TreeMetadata):
      self.item = given.tree