# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines free-function interface for loading."""
from __future__ import annotations
import functools
import time
from typing import Any, Awaitable, Protocol
from absl import logging
import jax
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.logging import event_tracking
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.layout import registry as layout_registry
from orbax.checkpoint.experimental.v1._src.loading import validation
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
from orbax.checkpoint.experimental.v1._src.synchronization import synchronization
from orbax.checkpoint.experimental.v1._src.synchronization import thread_utils
from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
from typing_extensions import deprecated # pytype: disable=not-supported-yet
# Checkpointable-name constants, aliased from the layout module so that the
# functions below can use them as default argument values.
STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY
AUTO_CHECKPOINTABLE_KEY = checkpoint_layout.AUTO_CHECKPOINTABLE_KEY
# A PyTree whose leaves are abstract leaves (`tree_types.AbstractLeaf`).
AbstractPyTree = tree_types.PyTreeOf[tree_types.AbstractLeaf]
CheckpointMetadata = metadata_types.CheckpointMetadata
# Bound to the `Ellipsis` singleton.
# NOTE(review): not referenced elsewhere in the visible code; presumably a
# sentinel exported for callers - confirm intended usage.
PLACEHOLDER = ...
# Short aliases for handler and async types used in the signatures below.
Checkpointable = handler_types.Checkpointable
AbstractCheckpointable = handler_types.AbstractCheckpointable
AsyncResponse = async_types.AsyncResponse
class LoadFn(Protocol):
  """Structural type of the two-phase load callable consumed by `_load_impl`.

  Calling the object yields a coroutine. Awaiting that coroutine runs the
  first phase (validation and setup) and produces a second awaitable;
  awaiting the second awaitable performs the background load operation (I/O).
  """

  async def __call__(self) -> Awaitable[Any]:
    """Runs the first phase and returns the awaitable for the second phase."""
    ...
def _standardize_abstract_checkpointables(abstract_checkpointables):
  """Normalizes user-supplied abstract checkpointables before loading.

  Exactly one of the following happens:

  * A `CheckpointMetadata` wrapper is unwrapped to its `metadata` attribute,
    which is returned as-is.
  * Any other non-None input is mapped over as a PyTree. Each leaf whose
    `sharding` is a `NamedSharding` defined on an `AbstractMesh` is replaced by
    a `jax.ShapeDtypeStruct` with the same shape, dtype and partition spec, but
    bound to the mesh returned by `jax.sharding.get_mesh()`.
  * `None` (or any input when `jax.sharding.get_mesh` is unavailable) is
    returned unchanged.

  Args:
    abstract_checkpointables: The abstract checkpointables, potentially wrapped
      in CheckpointMetadata.

  Returns:
    The standardized abstract checkpointables.
  """
  # NOTE(review): trees unwrapped from CheckpointMetadata do not go through
  # mesh resolution below - confirm that this is intended.
  if isinstance(abstract_checkpointables, CheckpointMetadata):
    return abstract_checkpointables.metadata

  # `jax.sharding.get_mesh()` only exists in recent JAX releases; on older
  # versions the input is passed through untouched instead of crashing.
  if abstract_checkpointables is None or not hasattr(jax.sharding, 'get_mesh'):
    return abstract_checkpointables

  def _bind_concrete_mesh(leaf):
    """Swaps an AbstractMesh-based NamedSharding for a concrete-mesh one."""
    if not hasattr(leaf, 'sharding'):
      return leaf
    if not isinstance(leaf.sharding, jax.sharding.NamedSharding):
      return leaf
    if not isinstance(leaf.sharding.mesh, jax.sharding.AbstractMesh):
      return leaf
    concrete_sharding = jax.sharding.NamedSharding(
        jax.sharding.get_mesh(), leaf.sharding.spec
    )
    return jax.ShapeDtypeStruct(
        leaf.shape, leaf.dtype, sharding=concrete_sharding
    )

  return jax.tree_util.tree_map(_bind_concrete_mesh, abstract_checkpointables)
[docs]
def load(
    path: path_types.PathLike,
    abstract_state: (
        AbstractPyTree | CheckpointMetadata[AbstractPyTree] | None
    ) = None,
    *,
    checkpointable_name: str | None = AUTO_CHECKPOINTABLE_KEY,
) -> tree_types.PyTreeOf[tree_types.Leaf]:
  """Loads a PyTree from a checkpoint, blocking until it is available.

  A `PyTree` checkpoint is a path that contains a subdirectory whose name is
  given by `checkpointable_name` (default `AUTO`). See the description of
  `checkpointable_name` below for details.

  This function must be called on all available controller processes.

  The call blocks until the load is complete. For improved performance,
  consider using :py:func:`.load_async` instead.

  When `abstract_state` is omitted, the `PyTree` is loaded exactly as saved.

  IMPORTANT: Loading is more brittle and error-prone when `abstract_state` is
  not provided. Always provide `abstract_state` if possible. The tree structure
  of a saved checkpoint can always be obtained with :py:func:`.metadata`.

  Providing `abstract_state` guarantees two things:

  1. The restored tree exactly matches the structure of `abstract_state` (or an
     error is raised if this cannot be guaranteed). For example, if
     `abstract_state` is a custom object registered as a `PyTree`, the
     checkpoint is restored as the same kind of object, if possible.
  2. The leaves of the restored tree have the properties indicated by the
     abstract leaves. For example, if a leaf of `abstract_state` is a
     `jax.ShapeDtypeStruct`, the restored leaf is a `jax.Array` with the same
     shape and `dtype`. Each `AbstractLeaf` has a corresponding `Leaf` that is
     restored. See `orbax.checkpoint.v1.tree` for a table of the standard
     supported leaf types.

  Example Usage:

  Load a saved PyTree with and without providing its abstract structure::

    path = '/tmp/my_checkpoint'

    # Save a checkpoint
    state = {'a': jnp.arange(8), 'b': jnp.zeros(4)}
    ocp.save(path, state)

    # Load the checkpoint
    # Highly recommended to provide the abstract pytree (structure/shapes)
    abstract_state = jax.eval_shape(lambda: state)

    # Method A: Load using the abstract structure.
    # This automatically looks for the 'pytree' subdirectory inside 'path'.
    restored = ocp.load(path, abstract_state)

    # Method B: Infer structure from file (not recommended for production
    # use cases or for complex trees).
    restored_inferred = ocp.load(path)

  Args:
    path: The path to load the checkpoint from. This path must contain a
      subdirectory with name provided by `checkpointable_name`. See
      `checkpointable_name` for more details.
    abstract_state: Provides a tree structure for the checkpoint to be restored
      into. May be omitted to load exactly as saved, but this is much more
      brittle than providing the tree.
    checkpointable_name: The name of the checkpointable to load. A subdirectory
      with this name must exist in `path`. If None, then path itself is expected
      to contain all files relevant for loading the PyTree, rather than any
      subdirectory. Such files include, for example, `manifest.ocdbt`,
      `_METADATA`, `ocp.process_X`. Defaults to `AUTO`. Setting to `AUTO` mode
      dynamically discovers and resolves a pytree checkpointable. It prioritizes
      the standard 'pytree' checkpointable name if present, then sorts any other
      valid pytree checkpointable names alphabetically and returns the first
      valid one, and ultimately falls back to interpreting the path as a flat V0
      root layout if no standard pytree exists.

  Returns:
    The restored `PyTree`.
  """
  # Captured first so that the recorded durations cover the entire call.
  start_time = time.time()
  start_recorder = event_tracking.OperationRecorder(
      path,
      operation_type=event_tracking.OperationType.LOAD,
      async_origin=False,
  )
  start_recorder.record_start()

  normalized_state = _standardize_abstract_checkpointables(abstract_state)
  validation.validate_pytree_checkpointable_name(checkpointable_name)

  context = context_lib.get_context()
  checkpoint_path = context.file_options.path_class(path)

  # Resolve which layout to use and the concrete checkpointable name.
  resolve_coro = layout_registry.CheckpointLayoutResolver.resolve(
      checkpoint_path, context.checkpoint_layout, pytree_name=checkpointable_name
  )
  resolved = asyncio_utils.run_sync(resolve_coro)

  # `_load_impl` expects a zero-argument callable, so bind everything here.
  bound_load = functools.partial(
      resolved.layout.load,
      path=checkpoint_path,
      checkpointable_name=resolved.pytree_name,
      abstract_state=normalized_state,
  )
  return _load_impl(checkpoint_path, bound_load, start_time=start_time)
[docs]
def load_checkpointables(
    path: path_types.PathLike,
    abstract_checkpointables: (
        dict[str, AbstractCheckpointable]
        | CheckpointMetadata[dict[str, AbstractCheckpointable]]
        | None
    ) = None,
) -> dict[str, Checkpointable]:
  """Loads a set of named checkpointables, blocking until they are available.

  See the documentation of :py:func:`.save_checkpointables` for more context on
  what a checkpointable is.

  This function can load any checkpoint saved by
  :py:func:`.save_checkpointables` (or :py:func:`.save`). The path should
  contain a number of subdirectories, each of which represents the name of a
  checkpointable.

  This function must be called on all available controller processes.

  The call blocks until the load is complete. For improved performance,
  consider using :py:func:`.load_checkpointables_async` instead.

  When `abstract_checkpointables` is omitted, the checkpointables are loaded
  exactly as saved.

  IMPORTANT: Loading is more brittle and error-prone when
  `abstract_checkpointables` is not provided. Always provide
  `abstract_checkpointables` if possible. Information about the checkpointables
  can always be obtained with :py:func:`.checkpointables_metadata`.

  When `abstract_checkpointables` is provided, the value given for each key is
  treated as the abstract type of that checkpointable. For example, for a
  `PyTree` of `jax.Array`, the corresponding abstract checkpointable is a
  `PyTree` of `jax.ShapeDtypeStruct`. `None` is always a valid abstract
  checkpointable, and simply indicates that the checkpointable should be loaded
  exactly as saved.

  The keys of `abstract_checkpointables` may be any subset of the
  checkpointables in the checkpoint. Checkpointables whose names are not listed
  in `abstract_checkpointables` are not loaded.

  Example Usage:

  Load checkpointables from a saved checkpoint::

    path = '/tmp/my_checkpoint_step_100'

    # Save multiple components (checkpointables)
    params = {'w': jnp.ones((8, 8)), 'b': jnp.zeros(8)}
    opt_state = {'count': jnp.array(100)}

    # Setup Grain (Stateful Checkpointable)
    import grain
    dataset_iter = iter(
        grain.MapDataset.range(30)
        .batch(3)
        .map(lambda x: x.tolist())
    )

    ocp.save_checkpointables(path, {
        'model': params,
        'optimizer': opt_state,
        'dataset': dataset_iter,
    })

    # Load the checkpointables
    abstract_params = jax.eval_shape(lambda: params)
    abstract_opt = jax.eval_shape(lambda: opt_state)

    abstract_checkpointables = {
        'model': abstract_params,
        'optimizer': abstract_opt,
        # Dataset is restored statefully. An initialized object must be
        # passed, but its position will be set to the position recorded in the
        # checkpoint after restoring.
        'dataset': dataset_iter,
    }

    # Load all components
    restored = ocp.load_checkpointables(path, abstract_checkpointables)

    # Load only a subset
    restored_subset = ocp.load_checkpointables(
        path,
        {'model': abstract_params}
    )

  Args:
    path: The path to load the checkpoint from. This path must contain a
      subdirectory for each checkpointable.
    abstract_checkpointables: A dictionary of abstract checkpointables.
      Dictionary keys represent the names of the checkpointables, while the
      values are the abstract checkpointable objects themselves.

  Returns:
    A dictionary of checkpointables. Dictionary keys represent the names of the
    checkpointables, while the values are the checkpointable objects themselves.

  Raises:
    FileNotFoundError: If the checkpoint path does not exist.
    NotImplementedError: If the resolved layout has no `load_checkpointables`
      attribute.
  """
  # Captured first so that the recorded durations cover the entire call.
  start_time = time.time()
  start_recorder = event_tracking.OperationRecorder(
      path,
      operation_type=event_tracking.OperationType.LOAD,
      async_origin=False,
  )
  start_recorder.record_start()

  normalized = _standardize_abstract_checkpointables(abstract_checkpointables)
  validation.validate_abstract_checkpointables(normalized)

  context = context_lib.get_context()
  checkpoint_path = context.file_options.path_class(path)
  layout_coro = layout_registry.get_checkpoint_layout(
      checkpoint_path, context.checkpoint_layout
  )
  layout = asyncio_utils.run_sync(layout_coro)

  # Not every layout can load named checkpointables; fail with a clear error.
  if not hasattr(layout, 'load_checkpointables'):
    raise NotImplementedError(
        f'Layout {type(layout)} does not support loading checkpointables.'
    )

  # `_load_impl` expects a zero-argument callable, so bind everything here.
  bound_load = functools.partial(
      layout.load_checkpointables,
      path=checkpoint_path,
      abstract_checkpointables=normalized,
  )
  return _load_impl(checkpoint_path, bound_load, start_time=start_time)
def _load_impl(
    path: path_types.Path,
    load_fn: LoadFn,
    start_time: float,
) -> dict[str, Checkpointable] | tree_types.PyTreeOf[tree_types.Leaf]:
  """Shared blocking implementation behind `load` and `load_checkpointables`.

  The steps are, in order: advance the synchronized operation ID, run the first
  phase of `load_fn`, record the blocking-completion event, await the second
  (background) phase, wait at a barrier across the active processes, and
  finally record the completion event.

  Args:
    path: The path to the checkpoint.
    load_fn: A function that returns an awaitable for loading the checkpoint
      based on either :py:func:`.load_checkpointables` or :py:func:`.load`.
    start_time: The time when the loading process started.

  Returns:
    The loaded checkpointables or PyTree itself.

  Raises:
    ValueError: If `path` is falsy.
  """
  if not path:
    raise ValueError('Path must not be None.')

  context = context_lib.get_context()

  def _new_recorder() -> Any:
    """Builds a fresh LOAD recorder; one is created per recorded event."""
    return event_tracking.OperationRecorder(
        path,
        operation_type=event_tracking.OperationType.LOAD,
        async_origin=False,
    )

  # Ensure the operation ID is incremented as soon as possible. This must be
  # done uniquely for each load operation.
  next_operation_id = synchronization.synchronize_next_operation_id(
      prefix=context.multiprocessing_options.barrier_sync_key_prefix,
      processes=context.multiprocessing_options.active_processes,
  )
  asyncio_utils.run_sync(next_operation_id)

  async def _run_both_phases() -> Checkpointable:
    """Runs setup, then the background load, then the cross-process barrier."""
    background_load = await load_fn()
    _new_recorder().record_blocking_completion(time.time() - start_time)
    loaded = await background_load
    barrier_key = multihost.unique_barrier_key(
        '_load_impl',
        prefix=context.multiprocessing_options.barrier_sync_key_prefix,
    )
    await multihost.sync_global_processes(
        barrier_key,
        operation_id=synchronization.get_operation_id(),
        processes=context.multiprocessing_options.active_processes,
    )
    return loaded

  loaded_value = asyncio_utils.run_sync(_run_both_phases())
  # Measure before constructing the recorder, matching the reported duration.
  total_duration_secs = time.time() - start_time
  _new_recorder().record_completion(total_duration_secs)
  return loaded_value
class _LoadPyTreeResponse(AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]):
  """An :py:class:`.AsyncResponse` for :py:func:`.load_async`.

  On construction, `_finalize_load()` is handed to a
  `thread_utils.BackgroundThreadRunner`; `result()` delegates to that runner.
  """

  def __init__(
      self,
      operation_id: str,
      path: path_types.Path,
      background_awaitable: Awaitable[tree_types.PyTreeOf[tree_types.Leaf]],
      *,
      start_time: float,
      context: context_lib.Context,
  ):
    """Stores the operation state and starts background finalization.

    Args:
      operation_id: The ID of the load operation this response belongs to. It
        keys the finalization barrier.
      path: The checkpoint path, used for event recording.
      background_awaitable: Awaitable that resolves to the loaded PyTree.
      start_time: The time when the load operation started.
      context: The context whose multiprocessing options configure the
        finalization barrier.
    """
    self._operation_id = operation_id
    self._path = path
    self._background_awaitable = background_awaitable
    self._start_time = start_time
    self._context = context
    # Must come last: the coroutine reads the attributes assigned above.
    self._thread_runner = thread_utils.BackgroundThreadRunner[
        tree_types.PyTreeOf[tree_types.Leaf]
    ](self._finalize_load())

  @classmethod
  def create(
      cls,
      background_awaitable: Awaitable[tree_types.PyTreeOf[tree_types.Leaf]],
      path: path_types.Path,
      start_time: float,
      *,
      context: context_lib.Context,
  ) -> _LoadPyTreeResponse:
    """Creates and returns the final AsyncResponse for a load operation.

    Records the blocking-completion event, then builds the response with the
    operation ID that is current at the time of the call.

    Args:
      background_awaitable: Awaitable that resolves to the loaded PyTree.
      path: The checkpoint path.
      start_time: The time when the load operation started.
      context: The context in effect for the load operation.

    Returns:
      The response wrapping the background load.
    """
    blocking_duration_secs = time.time() - start_time
    event_tracking.OperationRecorder(
        path,
        operation_type=event_tracking.OperationType.LOAD,
        async_origin=True,
    ).record_blocking_completion(blocking_duration_secs)
    return cls(
        synchronization.get_operation_id(),
        path,
        background_awaitable,
        start_time=start_time,
        context=context,
    )

  async def _finalize_load(self) -> tree_types.PyTreeOf[tree_types.Leaf]:
    """Awaits the background load, synchronizes processes, records completion.

    Returns:
      The value produced by the background awaitable.
    """
    logging.info(
        '[process=%s] Waiting for background load operations',
        multihost.process_index(),
    )
    result = await self._background_awaitable
    logging.vlog(
        1,
        '[process=%s] Finished waiting for background load operations.',
        multihost.process_index(),
    )
    await multihost.sync_global_processes(
        multihost.unique_barrier_key(
            '_load_async:finalize',
            prefix=(
                self._context.multiprocessing_options.barrier_sync_key_prefix
            ),
        ),
        # Use the ID captured at creation time. Re-reading the current ID here
        # would pick up the ID of any operation started after this load was
        # launched, because this coroutine runs later in the background.
        operation_id=self._operation_id,
        processes=self._context.multiprocessing_options.active_processes,
    )
    total_duration_secs = time.time() - self._start_time
    event_tracking.OperationRecorder(
        self._path,
        operation_type=event_tracking.OperationType.LOAD,
        async_origin=True,
    ).record_completion(total_duration_secs)
    return result

  def result(
      self, timeout: float | None = None
  ) -> tree_types.PyTreeOf[tree_types.Leaf]:
    """Returns the loaded PyTree, delegating the wait to the thread runner.

    Args:
      timeout: Forwarded unchanged to the background thread runner's `result`.

    Returns:
      The loaded PyTree.
    """
    return self._thread_runner.result(timeout=timeout)
[docs]
def load_async(
    path: path_types.PathLike,
    abstract_state: (
        AbstractPyTree | CheckpointMetadata[AbstractPyTree] | None
    ) = None,
    *,
    checkpointable_name: str | None = STATE_CHECKPOINTABLE_KEY,
) -> async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]:
  """Loads a PyTree asynchronously. Currently has limited support.

  Only the `SAFETENSORS` checkpoint layout is accepted. The first (blocking)
  phase of the load runs before this function returns; the remaining work is
  awaited in the background by the returned response.

  Args:
    path: The path to load the checkpoint from.
    abstract_state: Optional tree structure for the checkpoint to be restored
      into, possibly wrapped in `CheckpointMetadata`.
    checkpointable_name: The name of the checkpointable to load. Note that the
      default here is `STATE_CHECKPOINTABLE_KEY`.

  Returns:
    An `AsyncResponse` whose `result()` yields the restored `PyTree`.

  Raises:
    ValueError: If `path` is falsy.
    NotImplementedError: If the context's checkpoint layout is not
      `SAFETENSORS`.
  """
  # Captured first so that the recorded durations cover the entire call.
  start_time = time.time()
  start_recorder = event_tracking.OperationRecorder(
      path,
      operation_type=event_tracking.OperationType.LOAD,
      async_origin=True,
  )
  start_recorder.record_start()

  ctx = context_lib.get_context()
  # Ensure the operation ID is incremented as soon as possible. This must be
  # done uniquely for each load operation.
  next_operation_id = synchronization.synchronize_next_operation_id(
      prefix=ctx.multiprocessing_options.barrier_sync_key_prefix,
      processes=ctx.multiprocessing_options.active_processes,
  )
  asyncio_utils.run_sync(next_operation_id)

  if not path:
    raise ValueError('Path must not be None.')
  if ctx.checkpoint_layout != options_lib.CheckpointLayout.SAFETENSORS:
    raise NotImplementedError(
        'Asynchronous loading only supported for SAFETENSORS checkpoint '
        f'layout, not {ctx.checkpoint_layout}.'
    )

  checkpoint_path = ctx.file_options.path_class(path)
  normalized_state = _standardize_abstract_checkpointables(abstract_state)
  validation.validate_pytree_checkpointable_name(checkpointable_name)

  async def _run_blocking_phase() -> Any:
    """Resolves the layout and runs the first phase of its load."""
    resolved = await layout_registry.CheckpointLayoutResolver.resolve(
        checkpoint_path, ctx.checkpoint_layout, pytree_name=checkpointable_name
    )
    return await resolved.layout.load(
        checkpoint_path,
        checkpointable_name=resolved.pytree_name,
        abstract_state=normalized_state,
    )

  # The value returned by the blocking phase is itself awaitable; the response
  # object awaits it in the background.
  background_load = asyncio_utils.run_sync(_run_blocking_phase())
  return _LoadPyTreeResponse.create(
      background_load,
      checkpoint_path,
      start_time=start_time,
      context=ctx,
  )
[docs]
def load_checkpointables_async(
    path: path_types.PathLike,
    abstract_checkpointables: (
        dict[str, AbstractCheckpointable]
        | CheckpointMetadata[dict[str, AbstractCheckpointable]]
        | None
    ) = None,
) -> async_types.AsyncResponse[dict[str, Checkpointable]]:
  """Loads checkpointables asynchronously. Not yet implemented.

  The signature mirrors :py:func:`.load_checkpointables`, but every call
  currently fails.

  Args:
    path: Unused.
    abstract_checkpointables: Unused.

  Raises:
    NotImplementedError: Always.
  """
  # Drop the arguments explicitly to mark them as intentionally unused.
  del abstract_checkpointables, path
  message = 'Asynchronous loading is not yet supported.'
  raise NotImplementedError(message)
@deprecated('Use `load` instead.')
def load_pytree(*args, **kwargs):
  """Deprecated alias that forwards all arguments unchanged to `load`."""
  restored = load(*args, **kwargs)
  return restored
@deprecated('Use `load_async` instead.')
def load_pytree_async(*args, **kwargs):
  """Deprecated alias that forwards all arguments unchanged to `load_async`."""
  response = load_async(*args, **kwargs)
  return response