ocp.v1.handlers module

Public API for CheckpointableHandlers.

Types

class orbax.checkpoint.experimental.v1.handlers.CheckpointableHandler(*args, **kwargs)

An interface that defines save/load logic for a checkpointable object.

NOTE: Prefer to use StatefulCheckpointable interface when possible.

A PyTree of arrays representing model parameters is the most basic “checkpointable”. A single array is also a checkpointable.

In most contexts, when dealing with just a PyTree, the API of choice is:

ocp.save(directory, pytree)

The concept of “checkpointable” is not so obvious in this case. When dealing with multiple objects, we can use:

ocp.save_checkpointables(
    directory,
    dict(
        pytree=model_params,
        dataset=dataset_iterator,
        # other checkpointables, e.g. extra metadata, etc.
    ),
)

This makes it easy to skip loading the dataset, as is commonly desired when running evals or inference:

ocp.load_checkpointables(
    directory,
    dict(
        pytree=abstract_model_params,
    ),
)
# Equivalently,
ocp.load(directory, abstract_model_params)

For the methods defined in this Protocol (save, load), the logic inside the method body runs in the main thread and blocks the caller. Additional work can run in the background by returning an Awaitable, which may itself produce a result.
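The blocking-then-background behaviour described above can be illustrated without Orbax at all. The following is a minimal sketch in plain asyncio, assuming a hypothetical `TwoPhaseJsonSaver` class and a hypothetical `data.json` file name. It does not reproduce the Protocol's actual signatures; it only shows the shape of "do the blocking part first, then hand back an awaitable for the rest".

```python
import asyncio
import json
import pathlib
import tempfile
from typing import Any, Awaitable


class TwoPhaseJsonSaver:
  """Hypothetical saver used for illustration; not an Orbax class."""

  async def save(
      self, directory: pathlib.Path, obj: dict[str, Any]
  ) -> Awaitable[None]:
    # Phase 1: runs before `save` returns, so the caller waits for it.
    # Snapshot the object so that later mutations cannot leak into the file.
    payload = json.dumps(obj)

    async def _background() -> None:
      # Phase 2: the slow I/O, executed only when the caller awaits it.
      await asyncio.to_thread((directory / 'data.json').write_text, payload)

    return _background()


async def main() -> str:
  with tempfile.TemporaryDirectory() as tmp:
    directory = pathlib.Path(tmp)
    state = {'step': 1}
    pending = await TwoPhaseJsonSaver().save(directory, state)
    state['step'] = 2  # Mutation after phase 1 does not affect the snapshot.
    await pending  # Wait for the background write to finish.
    return (directory / 'data.json').read_text()


result = asyncio.run(main())
```

The design point is that anything which must observe the object as it was at call time belongs in the first phase, while anything slow but independent of later mutations can be deferred to the returned awaitable.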

Let’s look at some suggestions on how to implement a CheckpointableHandler.

To create a custom handler, you must define a class that implements the methods defined in this Protocol. The class should be generic over the concrete type Checkpointable (the object being saved/loaded) and the abstract type AbstractCheckpointable (the lightweight metadata representation).

Crucially, once implemented, the handler must be registered with the global registry or a context-local registry so that save_checkpointables and load_checkpointables can automatically detect and use it for the corresponding types. Use orbax.checkpoint.v1.handlers.register_handler for global registration, or provide handlers via orbax.checkpoint.v1.context.CheckpointablesOptions for context-local registration.

First, take a look at orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py for some toy implementations used for unit testing.

Here are some details on how to implement is_handleable and is_abstract_handleable.

For example, a handler may be defined as follows:

class FooHandler(CheckpointableHandler[Foo, AbstractFoo]):

  def is_handleable(self, checkpointable: Foo) -> bool:
    return isinstance(checkpointable, Foo)

  def is_abstract_handleable(
      self, abstract_checkpointable: AbstractFoo) -> bool:
    return isinstance(abstract_checkpointable, AbstractFoo)

This is simple because the handler only works with Foo and AbstractFoo. But the handler may work on more generic types. In a toy example, let’s say we’ve developed an improved way of storing very large arrays, which is still suboptimal for more normal-sized arrays. We can implement the handler as:

class FooHandler(CheckpointableHandler[jax.Array, jax.ShapeDtypeStruct]):

  def is_handleable(self, checkpointable: jax.Array) -> bool:
    return (
        isinstance(checkpointable, jax.Array)
        and checkpointable.size > LARGE_ARRAY_THRESHOLD
    )

  def is_abstract_handleable(
      self, abstract_checkpointable: jax.ShapeDtypeStruct) -> bool:
    return (
        isinstance(abstract_checkpointable, jax.ShapeDtypeStruct)
        and abstract_checkpointable.size > LARGE_ARRAY_THRESHOLD
    )
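The size-threshold idea can be tried out with stand-in types, so that it runs without JAX installed. This is a sketch under stated assumptions: `FakeArray`, `FakeAbstractArray`, `LargeArrayHandler`, and the value of `LARGE_ARRAY_THRESHOLD` are all hypothetical. Keeping the two predicates in agreement is a design suggestion made here, not something the source states.

```python
import dataclasses
import math

LARGE_ARRAY_THRESHOLD = 1_000_000  # Hypothetical cutoff, in elements.


@dataclasses.dataclass(frozen=True)
class FakeArray:
  """Stand-in for a concrete array; only the shape matters here."""
  shape: tuple[int, ...]

  @property
  def size(self) -> int:
    return math.prod(self.shape)


@dataclasses.dataclass(frozen=True)
class FakeAbstractArray:
  """Stand-in for an abstract, shape-only description of an array."""
  shape: tuple[int, ...]

  @property
  def size(self) -> int:
    return math.prod(self.shape)


class LargeArrayHandler:
  """Hypothetical handler that only claims very large arrays."""

  def is_handleable(self, checkpointable: object) -> bool:
    return (
        isinstance(checkpointable, FakeArray)
        and checkpointable.size > LARGE_ARRAY_THRESHOLD
    )

  def is_abstract_handleable(self, abstract_checkpointable: object) -> bool:
    return (
        isinstance(abstract_checkpointable, FakeAbstractArray)
        and abstract_checkpointable.size > LARGE_ARRAY_THRESHOLD
    )


handler = LargeArrayHandler()
small, large = (128, 128), (4096, 4096)
save_decisions = [handler.is_handleable(FakeArray(s)) for s in (small, large)]
load_decisions = [
    handler.is_abstract_handleable(FakeAbstractArray(s))
    for s in (small, large)
]
```

Because both predicates apply the same size test, an array that this handler claims at save time is also claimed by it when the matching abstract description is supplied at load time.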

In many cases, no information is needed for loading. AbstractCheckpointable may then be defined as None. For example:

class FooHandler(CheckpointableHandler[Foo, None]):

  def is_handleable(self, checkpointable: Foo) -> bool:
    return isinstance(checkpointable, Foo)

  def is_abstract_handleable(self, abstract_checkpointable: None) -> bool:
    return abstract_checkpointable is None

class orbax.checkpoint.experimental.v1.handlers.StatefulCheckpointable(*args, **kwargs)

An interface that defines save/load logic for a checkpointable object.

Handlers

final class orbax.checkpoint.experimental.v1.handlers.PyTreeHandler(*, context=None, array_metadata_validator=<orbax.checkpoint._src.metadata.array_metadata_store.Validator object>, leaf_handler_registry=None, partial_save_mode=False)

An implementation of CheckpointableHandler for PyTrees.

PyTreeHandler manages the decomposition of JAX PyTree structures into leaf-level parameters for persistence. It utilizes an asynchronous two-tier execution model to allow for background I/O, ensuring that heavy array serialization does not block the main training process.
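To make "decomposition into leaf-level parameters" concrete, here is a minimal pure-Python sketch that walks a nested structure and yields one entry per leaf. The helper `iter_leaves` is hypothetical and is not how PyTreeHandler is implemented; in real JAX code the utilities in `jax.tree_util` serve this purpose.

```python
from typing import Any, Iterator


def iter_leaves(
    tree: Any, prefix: tuple[str, ...] = ()
) -> Iterator[tuple[tuple[str, ...], Any]]:
  """Yields (path, leaf) pairs for nested dicts, lists, and tuples."""
  if isinstance(tree, dict):
    # Sort keys so that the traversal order is deterministic.
    for key in sorted(tree):
      yield from iter_leaves(tree[key], prefix + (str(key),))
  elif isinstance(tree, (list, tuple)):
    for index, child in enumerate(tree):
      yield from iter_leaves(child, prefix + (str(index),))
  else:
    # Anything that is not a container is treated as a leaf.
    yield prefix, tree


state_pytree = {'weights': [1.0, 2.0], 'bias': 0.0}
leaves = dict(iter_leaves(state_pytree))
```

Each resulting path identifies one leaf, which is the granularity at which a handler could persist parameters independently.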

Note: Users are strongly discouraged from instantiating or using this handler directly. Always use the top-level APIs like ocp.save_checkpointables and ocp.load_checkpointables. Orbax uses this handler by default for standard JAX PyTrees (like nested dictionaries of arrays).

To configure serialization for a specific PyTree and force Orbax to use a customized PyTreeHandler, the recommended approach is to use ocp.Context with CheckpointablesOptions. This allows you to bind the handler to a specific dictionary key within the Context scope.

See CheckpointablesOptions for more details on handler registration.

Example Usage:

Save a state dictionary configuration:

from orbax.checkpoint import v1 as ocp

state_pytree = {'weights': [1.0, 2.0], 'bias': 0.0}

checkpointables_options = (
    ocp.options.CheckpointablesOptions.create_with_handlers(
        model_state=ocp.handlers.PyTreeHandler()
    )
)
with ocp.Context(checkpointables_options=checkpointables_options):
    ocp.save_checkpointables(path, dict(model_state=state_pytree))

context

Optional V1 Context providing configuration for serialization, array options, and multiprocessing coordination.

Type:

Optional[Context]

array_metadata_validator

A validator object used to verify consistency of array metadata during restoration.

Type:

Validator

final class orbax.checkpoint.experimental.v1.handlers.ProtoHandler(filename='proto.pbtxt')

Implementation of CheckpointableHandler for protocol buffers.

ProtoHandler manages the serialization and deserialization of Protocol Buffer messages in text format. It utilizes an asynchronous two-tier execution model to offload I/O operations, ensuring background writing does not block the main process. In distributed environments, it provides multihost coordination to ensure that only the primary host performs the write operation.

Note: Users are strongly discouraged from instantiating or using this handler directly. Always use the top-level APIs like ocp.save_checkpointables and ocp.load_checkpointables. Orbax uses this handler by default for standard protocol buffer messages.

To save a custom Protocol Buffer message and force Orbax to use the ProtoHandler (e.g., to specify a custom filename), the recommended approach is to use ocp.Context with CheckpointablesOptions. This allows you to bind the handler to a specific dictionary key within the Context scope.

See CheckpointablesOptions for more details on handler registration.

Example Usage:

Save a protobuf message configuration:

from orbax.checkpoint import v1 as ocp

# Assuming MyProtoMessage is your compiled protobuf class
my_proto_msg = MyProtoMessage(config_field="value")

checkpointables_options = (
    ocp.options.CheckpointablesOptions.create_with_handlers(
        proto_config=ocp.handlers.ProtoHandler(
            filename="model_config.pbtxt"
        )
    )
)
with ocp.Context(checkpointables_options=checkpointables_options):
    ocp.save_checkpointables(path, dict(proto_config=my_proto_msg))

filename

An optional filename used for saving and loading the protobuf data. If not provided, it defaults to 'proto.pbtxt'.

Type:

str

final class orbax.checkpoint.experimental.v1.handlers.JsonHandler(filename=None)

An implementation of CheckpointableHandler for JSON.

JsonHandler enables the persistence of standard Python structures (dicts, lists, and primitives) that are JSON-serializable. It utilizes an asynchronous two-tier execution model to offload I/O operations, ensuring background writing does not block the main process. It also provides multihost coordination to ensure that only the primary host performs the write operation.
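The "only the primary host writes" coordination mentioned above can be sketched in a few lines. This is an illustration under assumptions, not JsonHandler's implementation: the function `write_json_on_primary`, its `process_index` and `primary_host` parameters, and the default filename are hypothetical, and a real multihost setup would obtain the process index from its runtime rather than from a loop.

```python
import json
import pathlib
import tempfile
from typing import Any


def write_json_on_primary(
    directory: pathlib.Path,
    obj: Any,
    *,
    process_index: int,
    primary_host: int = 0,
    filename: str = 'experiment_config.json',
) -> bool:
  """Writes `obj` as JSON only on the primary host; returns True if written."""
  if process_index != primary_host:
    # Non-primary hosts skip the write to avoid redundant or racing I/O.
    return False
  (directory / filename).write_text(json.dumps(obj, sort_keys=True))
  return True


with tempfile.TemporaryDirectory() as tmp:
  directory = pathlib.Path(tmp)
  config = {'learning_rate': 0.01, 'batch_size': 32}
  # Simulate four hosts calling the same save logic.
  wrote = [
      write_json_on_primary(directory, config, process_index=i)
      for i in range(4)
  ]
  files = sorted(p.name for p in directory.iterdir())
  restored = json.loads((directory / 'experiment_config.json').read_text())
```

Every simulated host runs the same call, but only one file is produced, which is the property the multihost coordination is meant to guarantee.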

Note: Users are strongly discouraged from instantiating or using this handler directly. Always use the top-level APIs like ocp.save_checkpointables and ocp.load_checkpointables. Orbax uses this handler by default for standard JSON-serializable objects.

To save a custom JSON-serializable object (like a specific dictionary containing metadata) and force Orbax to use the JsonHandler, the recommended approach is to use ocp.Context with CheckpointablesOptions, which applies only to save/load operations within the Context scope.

See CheckpointablesOptions for more details on handler registration.

Example Usage:

Save a dictionary configuration:

from orbax.checkpoint import v1 as ocp

config = {'learning_rate': 0.01, 'batch_size': 32}

checkpointables_options = (
    ocp.options.CheckpointablesOptions.create_with_handlers(
        experiment_config=ocp.handlers.JsonHandler(
            filename='experiment_config.json'
        )
    )
)
with ocp.Context(checkpointables_options=checkpointables_options):
    ocp.save_checkpointables(path, dict(experiment_config=config))

filename

An optional filename to use for saving and loading the JSON data. If not provided, the handler falls back to a default set of supported JSON filenames.

Registration

class orbax.checkpoint.experimental.v1.handlers.CheckpointableHandlerRegistry(*args, **kwargs)

A registry for CheckpointableHandler instances.

This protocol defines the core interface for adding, retrieving, and checking for the existence of handlers that manage the saving and loading of specific checkpointable types within the Orbax framework.

As a Protocol, it serves as a structural type definition. Any class that implements these four methods (add, get, has, and get_all_entries) with the correct signatures is considered a valid registry by static type checkers, without needing to explicitly inherit from this class.

Example

Implementing a custom registry that fulfills this protocol. Note that explicit inheritance is not required for type checkers to recognize it:

from typing import Type, Sequence, Tuple, Optional
from orbax.checkpoint.v1 import handlers

class MyCustomRegistry:
  def __init__(self) -> None:
    self._entries: list[
        Tuple[Type[handlers.CheckpointableHandler], Optional[str]]
    ] = []

  def add(
      self,
      handler_type: Type[handlers.CheckpointableHandler],
      checkpointable: Optional[str] = None,
  ) -> 'MyCustomRegistry':
    self._entries.append((handler_type, checkpointable))
    return self

  def get(
      self, checkpointable: str
  ) -> Type[handlers.CheckpointableHandler]:
    for h_type, name in self._entries:
      if name == checkpointable:
        return h_type
    raise KeyError(f'Not found: {checkpointable}')

  def has(self, checkpointable: str) -> bool:
    return any(name == checkpointable for _, name in self._entries)

  def get_all_entries(
      self,
  ) -> Sequence[
      Tuple[Type[handlers.CheckpointableHandler], Optional[str]]
  ]:
    return self._entries

add(handler_type, checkpointable=None)

Adds an entry to the registry. Returns the registry instance to allow method chaining.
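Method chaining follows directly from `add` returning the registry itself. The sketch below uses a hypothetical `ChainableRegistry` with placeholder handler classes to show the pattern; it mirrors the documented `add` and `get_all_entries` shapes but is not an Orbax class.

```python
from typing import Optional


class ChainableRegistry:
  """Hypothetical minimal registry; implements only what the demo needs."""

  def __init__(self) -> None:
    self._entries: list[tuple[type, Optional[str]]] = []

  def add(
      self, handler_type: type, checkpointable: Optional[str] = None
  ) -> 'ChainableRegistry':
    self._entries.append((handler_type, checkpointable))
    return self  # Returning self is what makes chaining possible.

  def get_all_entries(self) -> list[tuple[type, Optional[str]]]:
    return list(self._entries)


class FooHandler:
  """Placeholder handler type."""


class BarHandler:
  """Placeholder handler type."""


# Two registrations in a single expression.
registry = ChainableRegistry().add(FooHandler, 'foo').add(BarHandler)
entries = registry.get_all_entries()
```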

get(checkpointable)

Gets the type of a CheckpointableHandler from the registry by its associated checkpointable name.

has(checkpointable)

Checks if an entry exists in the registry for the given checkpointable name. Returns True if it exists, False otherwise.

get_all_entries()

Returns a sequence of all registered entries as (handler_type, name) tuples.

orbax.checkpoint.experimental.v1.handlers.global_registry()

Returns the global registry.

The global registry serves as the default, singleton storage for all handlers registered throughout the application’s lifecycle via register_handler.

Example

Retrieve the global registry to inspect available handlers:

from orbax.checkpoint.v1 import handlers

# Fetch the singleton global registry
registry = handlers.global_registry()

# Check whether a handler is registered globally for a given checkpointable name
is_registered = registry.has("my_custom_model_handler")

Returns:

The global singleton registry instance.

Return type:

CheckpointableHandlerRegistry

orbax.checkpoint.experimental.v1.handlers.local_registry(other_registry=None, *, include_global_registry=True)

Creates a local registry.

This function builds a new registry by optionally combining the existing global registry with a provided custom registry. It is useful for overriding handlers for a specific checkpointing operation without mutating the global state.

Example

Create a registry with custom handlers, potentially including global ones:

from orbax.checkpoint.v1 import handlers

class MyHandler(handlers.CheckpointableHandler):
  pass

# Create a registry and add a handler. By default, it includes
# globally-registered handlers.
my_registry = handlers.local_registry()
my_registry.add(MyHandler)

# To start with an empty registry, use:
# my_registry = handlers.local_registry(include_global_registry=False)

Parameters:
  • other_registry (CheckpointableHandlerRegistry | None) – An optional registry of handlers to include in the returned registry.

  • include_global_registry (bool) – If True (the default), includes globally-registered handlers in the returned registry.

Return type:

CheckpointableHandlerRegistry

Returns:

A local registry.

orbax.checkpoint.experimental.v1.handlers.register_handler(cls, *, checkpointable_name=None, secondary_typestrs=None)

Registers a CheckpointableHandler globally.

The order in which handlers are registered matters. If multiple handlers could be used to save or load an object (i.e., each is capable of handling the checkpointable according to is_handleable for save or is_abstract_handleable for load), the framework resolves them in Last-In, First-Out (LIFO) order. This means the handler added most recently will be selected.
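The LIFO rule can be demonstrated with a toy resolver. This is a sketch, not Orbax's resolution code: `register`, `resolve_for_save`, the module-level list, and both handler classes are hypothetical, and only the save-side predicate is modelled.

```python
from typing import Any, Optional


class AnyListHandler:
  """Hypothetical handler that accepts every list."""

  def is_handleable(self, checkpointable: Any) -> bool:
    return isinstance(checkpointable, list)


class IntListHandler:
  """Hypothetical handler that accepts only lists of ints."""

  def is_handleable(self, checkpointable: Any) -> bool:
    return isinstance(checkpointable, list) and all(
        isinstance(item, int) for item in checkpointable
    )


_REGISTERED: list[type] = []  # Registration order, oldest first.


def register(handler_type: type) -> type:
  _REGISTERED.append(handler_type)
  return handler_type


def resolve_for_save(checkpointable: Any) -> Optional[type]:
  """Returns the most recently registered handler that accepts the object."""
  for handler_type in reversed(_REGISTERED):  # Newest first: LIFO.
    if handler_type().is_handleable(checkpointable):
      return handler_type
  return None


register(AnyListHandler)
register(IntListHandler)  # Registered last, so it is consulted first.

chosen_for_ints = resolve_for_save([1, 2, 3])
chosen_for_strs = resolve_for_save(['a', 'b'])
chosen_for_dict = resolve_for_save({'a': 1})
```

Both handlers accept `[1, 2, 3]`, so the later registration wins. The list of strings falls through to the earlier, more general handler, which suggests registering general handlers first and specialized ones afterwards.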

Example

Registering a custom handler using a direct function call. Note the import path from the v1 namespace:

from orbax.checkpoint.v1 import handlers

class BarHandler(handlers.CheckpointableHandler):
  pass

handlers.register_handler(BarHandler)

Parameters:
  • cls (CheckpointableHandlerType) – The handler class to register globally.

  • checkpointable_name (str | None) – The checkpointable name. If not None, the registered handler is scoped to that specific name. Otherwise, the handler is available for any checkpointable name.

  • secondary_typestrs (Sequence[str] | None) – A sequence of alternate handler typestrs that serve as secondary identifiers for the handler.

Return type:

CheckpointableHandlerType

Returns:

The handler class.