Source code for orbax.checkpoint.experimental.v1._src.loading.loading

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines free-function interface for loading."""

from __future__ import annotations

import functools
import time
from typing import Any, Awaitable, Protocol

from absl import logging
import jax

from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.logging import event_tracking
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
import orbax.checkpoint.experimental.v1._src.handlers.global_registration  # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.layout import registry as layout_registry
from orbax.checkpoint.experimental.v1._src.loading import validation
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
from orbax.checkpoint.experimental.v1._src.synchronization import synchronization
from orbax.checkpoint.experimental.v1._src.synchronization import thread_utils
from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
from typing_extensions import deprecated  # pytype: disable=not-supported-yet


# Checkpointable-name keys re-exported from the layout module.
STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY
AUTO_CHECKPOINTABLE_KEY = checkpoint_layout.AUTO_CHECKPOINTABLE_KEY

# Aliases for types used in this module's signatures and runtime checks.
AbstractPyTree = tree_types.PyTreeOf[tree_types.AbstractLeaf]
CheckpointMetadata = metadata_types.CheckpointMetadata
Checkpointable = handler_types.Checkpointable
AbstractCheckpointable = handler_types.AbstractCheckpointable
AsyncResponse = async_types.AsyncResponse

# Bound to Ellipsis. Not referenced by the loading code in this module;
# presumably used by callers as a placeholder value -- confirm.
PLACEHOLDER = ...


class LoadFn(Protocol):
  """Structural type for the two-phase load callable used by `_load_impl`.

  Calling an implementation and awaiting the result performs the first phase
  (validation and setup). That phase resolves to a second awaitable. Awaiting
  the second awaitable performs the background part of the load (I/O) and
  yields the loaded value.
  """

  async def __call__(self) -> Awaitable[Any]:
    """Runs setup and returns the awaitable for the background I/O phase."""


def _standardize_abstract_checkpointables(abstract_checkpointables):
  """Standardizes abstract checkpointables for loading.

  First, if the input is a `CheckpointMetadata`, it is unwrapped to its
  `metadata` attribute. Then, when the installed JAX provides
  `jax.sharding.get_mesh`, every leaf whose `sharding` is a `NamedSharding`
  built on an `AbstractMesh` is replaced by a `jax.ShapeDtypeStruct`. That
  struct has the same shape, dtype and `weak_type`, with a `NamedSharding` on
  the current concrete mesh and the same partition spec.

  Mesh resolution is applied to both plain and `CheckpointMetadata`-wrapped
  inputs.

  Args:
    abstract_checkpointables: The abstract checkpointables, potentially wrapped
      in CheckpointMetadata. May be None.

  Returns:
    The standardized abstract checkpointables (None stays None).
  """

  def _resolve_abstract_mesh(leaf):
    # Only NamedShardings over an AbstractMesh need a concrete mesh; every other
    # leaf (including non-array checkpointables) passes through untouched.
    sharding = getattr(leaf, 'sharding', None)
    if not (
        isinstance(sharding, jax.sharding.NamedSharding)
        and isinstance(sharding.mesh, jax.sharding.AbstractMesh)
    ):
      return leaf
    concrete_sharding = jax.sharding.NamedSharding(
        jax.sharding.get_mesh(), sharding.spec
    )
    return jax.ShapeDtypeStruct(
        leaf.shape,
        leaf.dtype,
        sharding=concrete_sharding,
        # Preserve weak typing; the rebuilt struct would otherwise default to
        # weak_type=False.
        weak_type=getattr(leaf, 'weak_type', False),
    )

  if isinstance(abstract_checkpointables, CheckpointMetadata):
    abstract_checkpointables = abstract_checkpointables.metadata

  # jax.sharding.get_mesh() was recently added, so skip mesh resolution on
  # older JAX versions rather than crashing.
  if abstract_checkpointables is None or not hasattr(jax.sharding, 'get_mesh'):
    return abstract_checkpointables
  return jax.tree_util.tree_map(
      _resolve_abstract_mesh, abstract_checkpointables
  )


def load(
    path: path_types.PathLike,
    abstract_state: (
        AbstractPyTree | CheckpointMetadata[AbstractPyTree] | None
    ) = None,
    *,
    checkpointable_name: str | None = AUTO_CHECKPOINTABLE_KEY,
) -> tree_types.PyTreeOf[tree_types.Leaf]:
  """Loads a PyTree.

  Loads from a `PyTree` checkpoint. A `PyTree` checkpoint must be a path
  containing a subdirectory with the name provided by `checkpointable_name`,
  with default value `AUTO`. See `checkpointable_name` for more details.

  This function must be called on all available controller processes. The
  operation blocks until complete. For improved performance, consider using
  :py:func:`.load_async` instead.

  If `abstract_state` is not provided, the `PyTree` will be loaded exactly as
  saved.

  IMPORTANT: Loading is more brittle and error-prone when not providing
  `abstract_state`. Always provide `abstract_state` if possible. Note that you
  can always obtain the tree structure from a saved checkpoint using
  :py:func:`.metadata`.

  Providing the `abstract_state` guarantees two things:

  1. The restored tree will exactly match the structure of `abstract_state`
     (or raise an error if it is impossible to guarantee this). For example,
     if `abstract_state` is a custom object registered as a `PyTree`, the
     checkpoint will be restored as the same object, if possible.
  2. The leaves of the restored tree will be restored with the properties
     indicated by the abstract leaves. For example, if a leaf in
     `abstract_state` is a `jax.ShapeDtypeStruct`, the restored leaf will be a
     `jax.Array` with the same shape and `dtype`. Each `AbstractLeaf` has a
     corresponding `Leaf` that is restored.

  See `orbax.checkpoint.v1.tree` for a table of standard supported leaf types.

  Example Usage:

    Load a saved PyTree with and without providing its abstract structure::

      path = '/tmp/my_checkpoint'

      # Save a checkpoint
      state = {'a': jnp.arange(8), 'b': jnp.zeros(4)}
      ocp.save(path, state)

      # Load the checkpoint
      # Highly recommended to provide the abstract pytree (structure/shapes)
      abstract_state = jax.eval_shape(lambda: state)

      # Method A: Load using the abstract structure.
      # This automatically looks for the 'pytree' subdirectory inside 'path'.
      restored = ocp.load(path, abstract_state)

      # Method B: Infer the structure from the checkpoint files (not
      # recommended for production use cases or for complex trees).
      restored_inferred = ocp.load(path)

  Args:
    path: The path to load the checkpoint from. This path must contain a
      subdirectory with name provided by `checkpointable_name`. See
      `checkpointable_name` for more details.
    abstract_state: Provides a tree structure for the checkpoint to be restored
      into. May be omitted to load exactly as saved, but this is much more
      brittle than providing the tree. A `CheckpointMetadata` wrapping such a
      tree is also accepted; its `metadata` is used.
    checkpointable_name: The name of the checkpointable to load. A subdirectory
      with this name must exist in `path`. If None, then `path` itself is
      expected to contain all files relevant for loading the PyTree, rather
      than any subdirectory. Such files include, for example,
      `manifest.ocdbt`, `_METADATA`, `ocp.process_X`. Defaults to `AUTO`.
      Setting to `AUTO` mode dynamically discovers and resolves a pytree
      checkpointable. It prioritizes the standard 'pytree' checkpointable name
      if present, then sorts any other valid pytree checkpointable names
      alphabetically and returns the first valid one, and ultimately falls
      back to interpreting the path as a flat V0 root layout if no standard
      pytree exists.

  Returns:
    The restored `PyTree`.
  """
  start_time = time.time()
  recorder = event_tracking.OperationRecorder(
      path,
      operation_type=event_tracking.OperationType.LOAD,
      async_origin=False,
  )
  recorder.record_start()

  abstract_state = _standardize_abstract_checkpointables(abstract_state)
  validation.validate_pytree_checkpointable_name(checkpointable_name)

  context = context_lib.get_context()
  checkpoint_path = context.file_options.path_class(path)

  # Work out which layout (and which pytree subdirectory) this checkpoint uses.
  resolver = asyncio_utils.run_sync(
      layout_registry.CheckpointLayoutResolver.resolve(
          checkpoint_path,
          context.checkpoint_layout,
          pytree_name=checkpointable_name,
      )
  )
  pytree_load_fn = functools.partial(
      resolver.layout.load,
      path=checkpoint_path,
      checkpointable_name=resolver.pytree_name,
      abstract_state=abstract_state,
  )
  return _load_impl(checkpoint_path, pytree_load_fn, start_time=start_time)
def load_checkpointables(
    path: path_types.PathLike,
    abstract_checkpointables: (
        dict[str, AbstractCheckpointable]
        | CheckpointMetadata[dict[str, AbstractCheckpointable]]
        | None
    ) = None,
) -> dict[str, Checkpointable]:
  """Loads checkpointables.

  See documentation for :py:func:`.save_checkpointables` for more context on
  what a checkpointable is.

  This function can be used to load any checkpoint saved by
  :py:func:`.save_checkpointables` (or :py:func:`.save`). The path should
  contain a number of subdirectories - each of these represents the name of a
  checkpointable.

  This function must be called on all available controller processes. The
  operation blocks until complete. For improved performance, consider using
  :py:func:`.load_checkpointables_async` instead.

  If `abstract_checkpointables` is not provided, the checkpointables will be
  loaded exactly as saved.

  IMPORTANT: Loading is more brittle and error-prone when not providing
  `abstract_checkpointables`. Always provide `abstract_checkpointables` if
  possible. Note that you can always obtain the information about the
  checkpointables using :py:func:`.checkpointables_metadata`.

  If `abstract_checkpointables` is provided, the value provided for each key
  is treated as the abstract type for the given checkpointable. For example,
  for a `PyTree` of `jax.Array`, the corresponding abstract checkpointable is
  a `PyTree` of `jax.ShapeDtypeStruct`. `None` is always a valid abstract
  checkpointable, which just indicates that the checkpointable should be
  loaded exactly as saved.

  The keys provided in `abstract_checkpointables` may be any subset of the
  checkpointables in the checkpoint. Any checkpointables names not provided in
  `abstract_checkpointables` will not be loaded.

  Example Usage:

    Load checkpointables from a saved checkpoint::

      path = '/tmp/my_checkpoint_step_100'

      # Save multiple components (checkpointables)
      params = {'w': jnp.ones((8, 8)), 'b': jnp.zeros(8)}
      opt_state = {'count': jnp.array(100)}

      # Setup Grain (Stateful Checkpointable)
      import grain
      dataset_iter = iter(
          grain.MapDataset.range(30)
          .batch(3)
          .map(lambda x: x.tolist())
      )

      ocp.save_checkpointables(path, {
          'model': params,
          'optimizer': opt_state,
          'dataset': dataset_iter,
      })

      # Load the checkpointables
      abstract_params = jax.eval_shape(lambda: params)
      abstract_opt = jax.eval_shape(lambda: opt_state)

      abstract_checkpointables = {
          'model': abstract_params,
          'optimizer': abstract_opt,
          # Dataset is restored statefully. An initialized object must be
          # passed, but its position will be set to the position recorded in
          # the checkpoint after restoring.
          'dataset': dataset_iter,
      }

      # Load all components
      restored = ocp.load_checkpointables(path, abstract_checkpointables)

      # Load only a subset
      restored_subset = ocp.load_checkpointables(
          path, {'model': abstract_params}
      )

  Args:
    path: The path to load the checkpoint from. This path must contain a
      subdirectory for each checkpointable.
    abstract_checkpointables: A dictionary of abstract checkpointables.
      Dictionary keys represent the names of the checkpointables, while the
      values are the abstract checkpointable objects themselves. A
      `CheckpointMetadata` wrapping such a dictionary is also accepted; its
      `metadata` is used.

  Returns:
    A dictionary of checkpointables. Dictionary keys represent the names of
    the checkpointables, while the values are the checkpointable objects
    themselves.

  Raises:
    FileNotFoundError: If the checkpoint path does not exist.
    NotImplementedError: If the resolved checkpoint layout has no
      `load_checkpointables` method.
  """
  start_time = time.time()
  recorder = event_tracking.OperationRecorder(
      path,
      operation_type=event_tracking.OperationType.LOAD,
      async_origin=False,
  )
  recorder.record_start()

  abstract_checkpointables = _standardize_abstract_checkpointables(
      abstract_checkpointables
  )
  validation.validate_abstract_checkpointables(abstract_checkpointables)

  context = context_lib.get_context()
  checkpoint_path = context.file_options.path_class(path)
  layout = asyncio_utils.run_sync(
      layout_registry.get_checkpoint_layout(
          checkpoint_path, context.checkpoint_layout
      )
  )

  # Not every layout can load multiple named checkpointables.
  supports_checkpointables = hasattr(layout, 'load_checkpointables')
  if not supports_checkpointables:
    raise NotImplementedError(
        f'Layout {type(layout)} does not support loading checkpointables.'
    )

  checkpointables_load_fn = functools.partial(
      layout.load_checkpointables,
      path=checkpoint_path,
      abstract_checkpointables=abstract_checkpointables,
  )
  return _load_impl(
      checkpoint_path, checkpointables_load_fn, start_time=start_time
  )
def _load_impl(
    path: path_types.Path,
    load_fn: LoadFn,
    start_time: float,
) -> dict[str, Checkpointable] | tree_types.PyTreeOf[tree_types.Leaf]:
  """Implementation of loading logic for both :py:func:`.load_checkpointables` and :py:func:`.load`.

  The steps are:

  1. Synchronize the next operation ID across the active processes.
  2. Run both phases of `load_fn`.
  3. Wait on a cross-process barrier.
  4. Record the blocking-completion and completion events.

  Args:
    path: The path to the checkpoint.
    load_fn: A two-phase load function (see `LoadFn`) built by either
      :py:func:`.load_checkpointables` or :py:func:`.load`.
    start_time: The time when the loading process started, as returned by
      `time.time()`.

  Returns:
    The loaded checkpointables or PyTree itself.

  Raises:
    ValueError: If `path` is falsy.
  """
  if not path:
    raise ValueError('Path must not be None.')

  ctx = context_lib.get_context()

  # Ensure the operation ID is incremented as soon as possible. This must be
  # done uniquely for each load operation.
  asyncio_utils.run_sync(
      synchronization.synchronize_next_operation_id(
          prefix=ctx.multiprocessing_options.barrier_sync_key_prefix,
          processes=ctx.multiprocessing_options.active_processes,
      )
  )

  async def _load() -> (
      dict[str, Checkpointable] | tree_types.PyTreeOf[tree_types.Leaf]
  ):
    # Phase 1: validation/setup; resolves to the awaitable for phase 2.
    load_awaitable = await load_fn()
    event_tracking.OperationRecorder(
        path,
        operation_type=event_tracking.OperationType.LOAD,
        async_origin=False,
    ).record_blocking_completion(time.time() - start_time)
    # Phase 2: the background load (I/O).
    result = await load_awaitable
    # Barrier across the active processes before returning the result.
    await multihost.sync_global_processes(
        multihost.unique_barrier_key(
            '_load_impl',
            prefix=ctx.multiprocessing_options.barrier_sync_key_prefix,
        ),
        operation_id=synchronization.get_operation_id(),
        processes=ctx.multiprocessing_options.active_processes,
    )
    return result

  result = asyncio_utils.run_sync(_load())
  duration_secs = time.time() - start_time
  event_tracking.OperationRecorder(
      path,
      operation_type=event_tracking.OperationType.LOAD,
      async_origin=False,
  ).record_completion(duration_secs)
  return result


class _LoadPyTreeResponse(AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]):
  """An :py:class:`.AsyncResponse` for :py:func:`.load_async`.

  On construction, `_finalize_load()` is handed to a `BackgroundThreadRunner`.
  The runner is presumably started immediately -- confirm in `thread_utils`.
  That coroutine:

  - awaits the background load,
  - waits on a cross-process barrier,
  - records the completion event.

  `result()` waits on the runner.
  """

  def __init__(
      self,
      operation_id: str,
      path: path_types.Path,
      background_awaitable: Awaitable[tree_types.PyTreeOf[tree_types.Leaf]],
      *,
      start_time: float,
      context: context_lib.Context,
  ):
    # Operation ID captured when the load was issued. The finalize barrier
    # must use this value rather than whatever ID is current when the
    # background coroutine eventually runs.
    self._operation_id = operation_id
    self._path = path
    self._background_awaitable = background_awaitable
    self._start_time = start_time
    self._context = context
    self._thread_runner = thread_utils.BackgroundThreadRunner[
        tree_types.PyTreeOf[tree_types.Leaf]
    ](self._finalize_load())

  @classmethod
  def create(
      cls,
      background_awaitable: Awaitable[tree_types.PyTreeOf[tree_types.Leaf]],
      path: path_types.Path,
      start_time: float,
      *,
      context: context_lib.Context,
  ) -> _LoadPyTreeResponse:
    """Creates and returns the final AsyncResponse for a load operation.

    Records the blocking-phase duration and captures the current operation ID
    for use by the background finalization barrier.
    """
    blocking_duration_secs = time.time() - start_time
    event_tracking.OperationRecorder(
        path,
        operation_type=event_tracking.OperationType.LOAD,
        async_origin=True,
    ).record_blocking_completion(blocking_duration_secs)
    return cls(
        synchronization.get_operation_id(),
        path,
        background_awaitable,
        start_time=start_time,
        context=context,
    )

  async def _finalize_load(self) -> tree_types.PyTreeOf[tree_types.Leaf]:
    """Awaits the background load, syncs processes, and records completion."""
    logging.info(
        '[process=%s] Waiting for background load operations',
        multihost.process_index(),
    )
    result = await self._background_awaitable
    logging.vlog(
        1,
        '[process=%s] Finished waiting for background load operations.',
        multihost.process_index(),
    )
    await multihost.sync_global_processes(
        multihost.unique_barrier_key(
            '_load_async:finalize',
            prefix=(
                self._context.multiprocessing_options.barrier_sync_key_prefix
            ),
        ),
        # Use the ID captured at creation time. This coroutine runs in the
        # background, and by then the caller may already have started another
        # operation that advanced the global operation ID. All processes must
        # agree on the ID used for this barrier.
        operation_id=self._operation_id,
        processes=self._context.multiprocessing_options.active_processes,
    )
    total_duration_secs = time.time() - self._start_time
    event_tracking.OperationRecorder(
        self._path,
        operation_type=event_tracking.OperationType.LOAD,
        async_origin=True,
    ).record_completion(total_duration_secs)
    return result

  def result(
      self, timeout: float | None = None
  ) -> tree_types.PyTreeOf[tree_types.Leaf]:
    """Returns the loaded PyTree once background finalization completes.

    Args:
      timeout: Forwarded to `BackgroundThreadRunner.result`.
    """
    return self._thread_runner.result(timeout=timeout)
def load_async(
    path: path_types.PathLike,
    abstract_state: (
        AbstractPyTree | CheckpointMetadata[AbstractPyTree] | None
    ) = None,
    *,
    checkpointable_name: str | None = STATE_CHECKPOINTABLE_KEY,
) -> async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]:
  """Loads a PyTree asynchronously. Currently has limited support.

  Only the SAFETENSORS checkpoint layout (as configured on the current
  context) is supported.

  Before this function returns, it resolves the checkpoint layout and runs the
  first phase of the load. The rest of the load is finalized in the
  background. Call `result()` on the returned response to wait for it.

  Args:
    path: The path to load the checkpoint from.
    abstract_state: Optional tree structure to restore into, or a
      `CheckpointMetadata` wrapping one. If omitted, the PyTree is loaded
      exactly as saved.
    checkpointable_name: The name of the checkpointable to load. Defaults to
      `STATE_CHECKPOINTABLE_KEY`.

  Returns:
    An `AsyncResponse` whose `result()` returns the loaded PyTree.

  Raises:
    ValueError: If `path` is empty or None.
    NotImplementedError: If the context's checkpoint layout is not
      SAFETENSORS.
  """
  start_time = time.time()
  event_tracking.OperationRecorder(
      path,
      operation_type=event_tracking.OperationType.LOAD,
      async_origin=True,
  ).record_start()

  ctx = context_lib.get_context()

  # Validate arguments before any cross-process synchronization, matching the
  # ordering in `_load_impl`. Invalid calls then fail fast instead of first
  # entering a multihost barrier and consuming an operation ID.
  if not path:
    raise ValueError('Path must not be None.')
  if ctx.checkpoint_layout != options_lib.CheckpointLayout.SAFETENSORS:
    raise NotImplementedError(
        'Asynchronous loading only supported for SAFETENSORS checkpoint '
        f'layout, not {ctx.checkpoint_layout}.'
    )
  validation.validate_pytree_checkpointable_name(checkpointable_name)

  # Ensure the operation ID is incremented as soon as possible. This must be
  # done uniquely for each load operation.
  asyncio_utils.run_sync(
      synchronization.synchronize_next_operation_id(
          prefix=ctx.multiprocessing_options.barrier_sync_key_prefix,
          processes=ctx.multiprocessing_options.active_processes,
      )
  )

  path = ctx.file_options.path_class(path)
  abstract_state = _standardize_abstract_checkpointables(abstract_state)

  async def _blocking_load() -> Any:
    # Resolve the layout, then run the first (blocking) phase of the load.
    # The returned value is the awaitable for the background phase.
    resolver = await layout_registry.CheckpointLayoutResolver.resolve(
        path, ctx.checkpoint_layout, pytree_name=checkpointable_name
    )
    return await resolver.layout.load(
        path=path,
        checkpointable_name=resolver.pytree_name,
        abstract_state=abstract_state,
    )

  background_awaitable = asyncio_utils.run_sync(_blocking_load())
  return _LoadPyTreeResponse.create(
      background_awaitable,
      path,
      start_time=start_time,
      context=ctx,
  )
def load_checkpointables_async(
    path: path_types.PathLike,
    abstract_checkpointables: (
        dict[str, AbstractCheckpointable]
        | CheckpointMetadata[dict[str, AbstractCheckpointable]]
        | None
    ) = None,
) -> async_types.AsyncResponse[dict[str, Checkpointable]]:
  """Loads checkpointables asynchronously. Not yet implemented.

  Args:
    path: Unused; accepted for signature parity with
      :py:func:`.load_checkpointables`.
    abstract_checkpointables: Unused.

  Raises:
    NotImplementedError: Always.
  """
  del abstract_checkpointables, path  # Unused until this is implemented.
  raise NotImplementedError('Asynchronous loading is not yet supported.')
@deprecated('Use `load` instead.')
def load_pytree(*args, **kwargs):
  """Deprecated alias of :py:func:`.load`; forwards all arguments unchanged."""
  return load(*args, **kwargs)


@deprecated('Use `load_async` instead.')
def load_pytree_async(*args, **kwargs):
  """Deprecated alias of :py:func:`.load_async`; forwards arguments as-is."""
  return load_async(*args, **kwargs)