PyTree Checkpointable Handler

Public API for CheckpointableHandlers.

PyTreeHandler

final class orbax.checkpoint.experimental.v1.handlers.PyTreeHandler(*, context=None, array_metadata_validator=<orbax.checkpoint._src.metadata.array_metadata_store.Validator object>, leaf_handler_registry=None, partial_save_mode=False)

An implementation of CheckpointableHandler for PyTrees.

PyTreeHandler manages the decomposition of JAX PyTree structures into leaf-level parameters for persistence. It uses an asynchronous two-tier execution model to allow background I/O, so that heavy array serialization does not block the main training process.
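The decomposition and two-tier model described above can be illustrated with a standalone sketch in plain Python (no Orbax or JAX required). The names flatten_with_paths and save_pytree are hypothetical, and the dict named storage stands in for the real checkpoint directory. Orbax's actual implementation differs. The sketch only shows a blocking snapshot of every leaf followed by a background task that performs the writes.

```python
import asyncio
import copy
from typing import Any


def flatten_with_paths(tree: Any, prefix: tuple = ()) -> dict:
    """Flattens nested dicts/lists/tuples into {path: leaf} pairs."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        return {prefix: tree}  # A leaf: record it under its full path.
    leaves = {}
    for key, child in items:
        leaves.update(flatten_with_paths(child, prefix + (key,)))
    return leaves


async def save_pytree(tree: Any, storage: dict) -> asyncio.Task:
    """Hypothetical two-tier save: blocking snapshot, then background writes."""
    # Tier 1 (blocking): snapshot every leaf so the caller may keep mutating
    # `tree`. In Orbax, this is where device-to-host copies would happen.
    snapshot = {path: copy.deepcopy(leaf)
                for path, leaf in flatten_with_paths(tree).items()}

    async def _background_write() -> None:
        # Tier 2 (background): persist each leaf, here into a plain dict.
        for path, leaf in snapshot.items():
            await asyncio.sleep(0)  # Yield control, standing in for slow I/O.
            storage['/'.join(map(str, path))] = leaf

    return asyncio.create_task(_background_write())


async def main() -> dict:
    state = {'weights': [1.0, 2.0], 'bias': 0.0}
    storage: dict = {}
    task = await save_pytree(state, storage)
    state['bias'] = 99.0  # Mutating after tier 1 does not affect the save.
    await task
    return storage


result = asyncio.run(main())
print(result)  # {'weights/0': 1.0, 'weights/1': 2.0, 'bias': 0.0}
```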

Note: Users should never instantiate or use this handler directly. Always use the top-level APIs such as ocp.save_checkpointables and ocp.load_checkpointables. Orbax uses this handler by default for standard JAX PyTrees (such as nested dictionaries of arrays).

To configure a specific serialization context for a PyTree and explicitly force Orbax to use a customized PyTreeHandler, the recommended approach is to use ocp.Context with CheckpointablesOptions. This binds the handler to a specific dictionary key (checkpointable name) for the duration of the Context scope.

See CheckpointablesOptions for more details on handler registration.

Usage Example:

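Save a state dictionary under the checkpointable name model_state:

```python
import orbax.checkpoint.experimental.v1 as ocp

path = '/tmp/my_checkpoint'  # Hypothetical destination directory.

state_pytree = {'weights': [1.0, 2.0], 'bias': 0.0}

# Bind PyTreeHandler to the 'model_state' key of the checkpointables dict.
registry = ocp.handlers.local_registry()
registry.add(
    ocp.handlers.PyTreeHandler, checkpointable_name='model_state'
)
ctx = ocp.Context()
ctx.checkpointables.registry = registry
with ctx:
    ocp.save_checkpointables(path, dict(model_state=state_pytree))
```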
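Binding a handler to a dictionary key within a Context scope can be pictured with the following standalone sketch. It does not use Orbax. The names scoped_registry and resolve_handler are hypothetical, Python's contextvars module stands in for whatever mechanism ocp.Context actually uses, and handlers are represented by plain strings. The point is only that the binding applies inside the with block and disappears afterwards.

```python
import contextlib
import contextvars
from typing import Iterator

# Hypothetical stand-in for a handler registry: checkpointable name -> handler.
_REGISTRY = contextvars.ContextVar('_REGISTRY', default={})


@contextlib.contextmanager
def scoped_registry(bindings: dict) -> Iterator[None]:
    """Binds handlers to checkpointable names only inside the `with` block."""
    # Build a new dict instead of mutating the shared default.
    token = _REGISTRY.set({**_REGISTRY.get(), **bindings})
    try:
        yield
    finally:
        _REGISTRY.reset(token)  # Restore the previous bindings on exit.


def resolve_handler(name: str, default: str = 'DefaultHandler') -> str:
    """Returns the handler bound to `name` in the current scope, else default."""
    return _REGISTRY.get().get(name, default)


with scoped_registry({'model_state': 'PyTreeHandler'}):
    inside = resolve_handler('model_state')
outside = resolve_handler('model_state')
print(inside, outside)  # PyTreeHandler DefaultHandler
```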
context

Optional V1 Context providing configuration for serialization, array options, and multiprocessing coordination.

Type:

Optional[Context]

array_metadata_validator

A validator object used to verify consistency of array metadata during restoration.

Type:

Validator

__init__(*, context=None, array_metadata_validator=<orbax.checkpoint._src.metadata.array_metadata_store.Validator object>, leaf_handler_registry=None, partial_save_mode=False)

async save(directory, checkpointable)

Saves the given checkpointable to the given directory.

Save should perform any operations that need to block the main thread, such as device-to-host copying of on-device arrays. It then creates a background operation to continue writing the object to the storage location.

IMPORTANT: Do not assume that directory already exists at the start of this method. All directories are created by upper layers of the Orbax library, for performance reasons in a multihost setting and because upper layers also need to modify the directories. Before engaging in any filesystem operations, wait for the directory to exist. For example:

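```python
import logging
from typing import Awaitable

# Methods of a CheckpointableHandler implementation; `path_types` and the
# type variable `T` are assumed to be defined in the enclosing module.


async def _background_save(
    self,
    directory: path_types.PathAwaitingCreation,
    checkpointable: T,
) -> None:
  # Block (in the background) until upper layers have created the directory.
  directory = await directory.await_creation()
  # Write to `directory` here.
  ...

async def save(
    self,
    directory: path_types.PathAwaitingCreation,
    checkpointable: T,
) -> Awaitable[None]:
  # OK to access path properties, as long as we don't touch the actual
  # directory in the filesystem.
  logging.info(directory.name)
  return self._background_save(directory, checkpointable)
```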
Parameters:
  • directory (PathAwaitingCreation) – The directory to save the checkpoint to. Do not expect the directory to exist yet; it may still be in the process of being created. To wait for its creation, use await_creation, preferably in a background awaitable to avoid blocking the main thread.

  • checkpointable (PyTree) – The checkpointable object to save.

Return type:

Awaitable[None]

Returns:

An Awaitable. This object represents the result of the save operation running in the background.
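The save contract above (touch only path properties in save, wait for the directory in the background awaitable) can be exercised with a toy handler. FakePathAwaitingCreation and TextHandler are hypothetical classes written for this sketch. They are not Orbax types and do not implement the full CheckpointableHandler protocol. An asyncio.Event models the upper layer signalling that the directory now exists.

```python
import asyncio
import pathlib
import tempfile


class FakePathAwaitingCreation:
    """Hypothetical stand-in for PathAwaitingCreation (not the Orbax class)."""

    def __init__(self, path: pathlib.Path):
        self._path = path
        self._created = asyncio.Event()

    @property
    def name(self) -> str:
        # Path properties are safe to read before the directory exists.
        return self._path.name

    def mark_created(self) -> None:
        # Simulates the "upper layer" creating the directory.
        self._path.mkdir(parents=True, exist_ok=True)
        self._created.set()

    async def await_creation(self) -> pathlib.Path:
        await self._created.wait()
        return self._path


class TextHandler:
    """Toy handler following the save() contract described above."""

    async def _background_save(self, directory, checkpointable: str) -> None:
        path = await directory.await_creation()  # Wait before touching disk.
        (path / 'data.txt').write_text(checkpointable)

    async def save(self, directory, checkpointable: str):
        print('saving to', directory.name)  # No filesystem access yet.
        return self._background_save(directory, checkpointable)


async def main(root: pathlib.Path) -> str:
    directory = FakePathAwaitingCreation(root / 'ckpt')
    background = await TextHandler().save(directory, 'hello')
    task = asyncio.create_task(background)  # Schedule the background write.
    await asyncio.sleep(0)  # Let it run; it now blocks in await_creation().
    directory.mark_created()  # The directory appears only now.
    await task
    return (root / 'ckpt' / 'data.txt').read_text()


with tempfile.TemporaryDirectory() as tmp:
    content = asyncio.run(main(pathlib.Path(tmp)))
print(content)  # hello
```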

async load(directory, abstract_checkpointable=None)

Loads a PyTree from a checkpoint directory.

Parameters:
  • directory (Path) – The directory to load from.

  • abstract_checkpointable (UnionType[PyTree, None]) – The abstract checkpointable to load into. If None, the handler attempts to load the entire checkpoint using the recorded metadata. Otherwise, abstract_checkpointable is expected to be a PyTree of abstract leaves. Each abstract leaf may be a value of type AbstractLeaf, Type[AbstractLeaf], or None; passing either of the latter two indicates that the recorded metadata should be used to restore that leaf.

Return type:

Awaitable[PyTree]

Returns:

An awaitable which can be awaited to complete the load operation and obtain a PyTree.
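The per-leaf rule for abstract_checkpointable (a None leaf means "restore from recorded metadata", a concrete abstract leaf overrides it) can be sketched without Orbax. In this hypothetical example, SAVED plays the role of a checkpoint on disk, dtype strings such as 'int32' stand in for AbstractLeaf values, and restore_leaf and load are illustrative names rather than library functions.

```python
from typing import Any

# Hypothetical saved checkpoint: leaf path -> (stored value, recorded dtype).
SAVED = {
    ('weights', 0): (1.0, 'float32'),
    ('weights', 1): (2.0, 'float32'),
    ('step',): (7, 'int32'),
}


def restore_leaf(path: tuple, abstract_leaf: Any) -> Any:
    """Restores one leaf; a None abstract leaf means 'use the metadata'."""
    value, recorded_dtype = SAVED[path]
    target = recorded_dtype if abstract_leaf is None else abstract_leaf
    return int(value) if target == 'int32' else float(value)


def load(abstract_tree: Any, prefix: tuple = ()) -> Any:
    """Walks the abstract tree and restores every leaf it names."""
    if isinstance(abstract_tree, dict):
        return {k: load(v, prefix + (k,)) for k, v in abstract_tree.items()}
    if isinstance(abstract_tree, list):
        return [load(v, prefix + (i,)) for i, v in enumerate(abstract_tree)]
    return restore_leaf(prefix, abstract_tree)


# weights[0] and step come from metadata; weights[1] is requested as int32.
restored = load({'weights': [None, 'int32'], 'step': None})
print(restored)  # {'weights': [1.0, 2], 'step': 7}
```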

async metadata(directory)

Returns the metadata for the given directory.

The logic in this method must be executed fully in the main thread; metadata access is expected to be cheap and fast.

In many cases it is desirable to return additional metadata properties beyond the limited set in AbstractCheckpointable. In this case, AbstractCheckpointable should be subclassed, and this subclass can be returned from metadata.

Parameters:

directory (Path) – The directory where the checkpoint is located.

Returns:

The metadata is an AbstractCheckpointable, which is the abstract representation of the checkpointable.

Return type:

AbstractT
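The subclassing advice above can be illustrated with frozen dataclasses. AbstractArray and AbstractArrayWithStats are hypothetical types, not Orbax's AbstractCheckpointable. The metadata function here is synchronous and reads only already-recorded information, in keeping with the requirement that metadata access be cheap. Because the extra field lives in a subclass, code that expects the base type keeps working.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class AbstractArray:
    """Hypothetical base abstract representation: structure only."""
    shape: tuple
    dtype: str


@dataclasses.dataclass(frozen=True)
class AbstractArrayWithStats(AbstractArray):
    """Subclass carrying extra metadata beyond the base fields."""
    size_bytes: Optional[int] = None


def metadata(recorded: dict) -> AbstractArrayWithStats:
    """Builds metadata from cheap, already-recorded info; reads no array data."""
    return AbstractArrayWithStats(
        shape=tuple(recorded['shape']),
        dtype=recorded['dtype'],
        size_bytes=recorded.get('size_bytes'),
    )


meta = metadata({'shape': [2, 3], 'dtype': 'float32', 'size_bytes': 24})
print(isinstance(meta, AbstractArray), meta.size_bytes)  # True 24
```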

is_handleable(checkpointable)

Returns whether the handler can handle the given checkpointable.

The method should return True if it is possible to save such an object.

See class docstring for more details.

Parameters:

checkpointable (Any) – A concrete checkpointable, for saving.

Return type:

bool

Returns:

True if the handler can handle the given checkpointable.
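A save-side check of this kind might recursively walk the PyTree and accept it only when every leaf is of a supported type. The sketch below assumes a fixed, hypothetical LEAF_TYPES tuple. The real PyTreeHandler's check may differ; for example, it may consult its configured leaf handlers.

```python
from typing import Any

LEAF_TYPES = (int, float, bool, str)  # Hypothetical set of supported leaves.


def is_handleable(checkpointable: Any) -> bool:
    """True if every leaf of a nested dict/list/tuple is a supported type."""
    if isinstance(checkpointable, dict):
        return all(is_handleable(v) for v in checkpointable.values())
    if isinstance(checkpointable, (list, tuple)):
        return all(is_handleable(v) for v in checkpointable)
    return isinstance(checkpointable, LEAF_TYPES)


print(is_handleable({'weights': [1.0, 2.0], 'bias': 0.0}))  # True
print(is_handleable({'weights': object()}))                 # False
```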

is_abstract_handleable(abstract_checkpointable)

Returns whether the handler can handle the abstract checkpointable.

The method should return True if it is possible to use the given abstract_checkpointable for loading a concrete Checkpointable. Note that None is always considered handleable for loading, so this method does not need to check for it. If an implementation defines AbstractCheckpointable as None, then this method should only return True for values of None.

See class docstring for more details.

Parameters:

abstract_checkpointable (Any) – An abstract checkpointable, for loading.

Return type:

bool

Returns:

True if the handler can handle the given abstract checkpointable. None if the handler cannot decide whether it can handle the abstract checkpointable and defers to the typestr.
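The three possible outcomes (True, False, or None to defer to the typestr) and the rule that None is always loadable can be sketched from the caller's side. The functions below are hypothetical, and so are their decision rules. pick_handler is not an Orbax API; it only shows how a caller might branch on each result.

```python
from typing import Any, Optional


def is_abstract_handleable(abstract: Any) -> Optional[bool]:
    """Hypothetical tri-state check for a PyTree-style handler."""
    if isinstance(abstract, (dict, list, tuple)):
        return True   # Looks like a PyTree of abstract leaves.
    if isinstance(abstract, (int, float, str)):
        return False  # Hypothetical rule: concrete scalars are rejected.
    return None       # Undecided: defer to the typestr.


def pick_handler(abstract: Any, saved_typestr: str) -> str:
    """Hypothetical caller: None is always loadable; a None verdict defers."""
    if abstract is None:
        return saved_typestr  # No abstract value: trust what was saved.
    verdict = is_abstract_handleable(abstract)
    if verdict is None:
        return saved_typestr
    return 'PyTreeHandler' if verdict else 'rejected'


print(pick_handler({'bias': None}, 'SavedHandler'))  # PyTreeHandler
print(pick_handler(3.0, 'SavedHandler'))             # rejected
print(pick_handler(object(), 'SavedHandler'))        # SavedHandler
```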

classmethod __subclasshook__(other)

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
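The hook described above is standard abc behavior. The plain-Python sketch below uses illustrative names (Loadable, DuckHandler, Unrelated), not Orbax classes, and does not describe how PyTreeHandler's own hook decides. It returns True for any class that defines a load method and otherwise returns NotImplemented, so that the normal subclass algorithm runs.

```python
import abc


class Loadable(abc.ABC):
    """ABC that treats any class defining `load` as a virtual subclass."""

    @abc.abstractmethod
    def load(self, directory): ...

    @classmethod
    def __subclasshook__(cls, other):
        if cls is Loadable:
            # Structural check: any class in the MRO defining `load` counts.
            if any('load' in vars(klass) for klass in other.__mro__):
                return True
        return NotImplemented  # Fall back to the normal subclass algorithm.


class DuckHandler:  # Does not inherit from Loadable.
    def load(self, directory):
        return directory


class Unrelated:
    pass


print(issubclass(DuckHandler, Loadable))  # True
print(issubclass(Unrelated, Loadable))    # False
```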