Checkpointable Handler Types

Public API for CheckpointableHandlers.

CheckpointableHandler

class orbax.checkpoint.experimental.v1.handlers.CheckpointableHandler(*args, **kwargs)

An interface that defines save/load logic for a checkpointable object.

NOTE: Prefer the StatefulCheckpointable interface when possible.

A PyTree of arrays, representing model parameters, is the most basic “checkpointable”. A single array is also a checkpointable.

In most contexts, when dealing with just a PyTree, the API of choice is:

ocp.save(directory, pytree)

The concept of “checkpointable” is not so obvious in this case. When dealing with multiple objects, we can use:

ocp.save_checkpointables(
    directory,
    dict(
        pytree=model_params,
        dataset=dataset_iterator,
        # other checkpointables, e.g. extra metadata, etc.
    ),
)

Now, it is easy to simply skip loading the dataset, as is commonly desired when running evals or inference:

ocp.load_checkpointables(
    directory,
    dict(
        pytree=abstract_model_params,
    ),
)
# Equivalently,
ocp.load(directory, abstract_model_params)

With the methods defined in this Protocol (save, load), logic within the method body itself is executed in the main thread, in a blocking fashion. Additional logic can be executed in the background by returning an Awaitable (which itself may return a result).
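
The blocking/background split described above can be sketched with plain asyncio. This is a minimal illustration under stated assumptions, not Orbax code: `start_operation` and `_finish_in_background` are hypothetical names, and the sleep stands in for slow storage I/O.

```python
import asyncio
from typing import Awaitable


async def _finish_in_background(values: list[int]) -> int:
  # Hypothetical slow phase (stand-in for writing to or reading from storage).
  await asyncio.sleep(0.01)
  return sum(values)


async def start_operation(values: list[int]) -> Awaitable[int]:
  # Phase 1: this body runs to completion before the caller regains control.
  snapshot = list(values)  # Capture state that the caller may mutate later.
  # Phase 2: return an awaitable; it does no work until awaited or scheduled.
  return _finish_in_background(snapshot)


async def main() -> int:
  values = [1, 2, 3]
  background = await start_operation(values)  # Blocking phase ends here.
  values.append(100)  # A later mutation does not affect the snapshot.
  return await background  # Deferred phase completes here.


result = asyncio.run(main())
```

The outer `await` only covers the blocking phase; the caller chooses when to wait on the returned awaitable, which is what allows the slow phase to overlap with other work.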

Let’s look at some suggestions on how to implement a CheckpointableHandler.

To create a custom handler, you must define a class that implements the methods defined in this Protocol. The class should be generic over the concrete type Checkpointable (the object being saved/loaded) and the abstract type AbstractCheckpointable (the lightweight metadata representation).

Crucially, once implemented, the handler must be registered with the global registry or a context-local registry so that save_checkpointables and load_checkpointables can automatically detect and use it for the corresponding types. Use orbax.checkpoint.v1.handlers.register_handler for global registration, or provide handlers via orbax.checkpoint.v1.context.CheckpointablesOptions for context-local registration.

First, take a look at orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py for some toy implementations used for unit testing.

Here are some details on how to implement is_handleable and is_abstract_handleable.

For example, a handler may be defined as follows:

class FooHandler(CheckpointableHandler[Foo, AbstractFoo]):

  def is_handleable(self, checkpointable: Foo) -> bool:
    return isinstance(checkpointable, Foo)

  def is_abstract_handleable(
      self, abstract_checkpointable: AbstractFoo) -> bool:
    return isinstance(abstract_checkpointable, AbstractFoo)

This is simple because the handler only works with Foo and AbstractFoo. But the handler may work on more generic types. In a toy example, let’s say we’ve developed an improved way of storing very large arrays, which is still suboptimal for more normal-sized arrays. We can implement the handler as:

class FooHandler(CheckpointableHandler[jax.Array, jax.ShapeDtypeStruct]):

  def is_handleable(self, checkpointable: jax.Array) -> bool:
    return (
        isinstance(checkpointable, jax.Array)
        and checkpointable.size > LARGE_ARRAY_THRESHOLD
    )

  def is_abstract_handleable(
      self, abstract_checkpointable: jax.ShapeDtypeStruct) -> bool:
    return (
        isinstance(abstract_checkpointable, jax.ShapeDtypeStruct)
        and abstract_checkpointable.size > LARGE_ARRAY_THRESHOLD
    )

In many cases, no information is needed for loading. In this case, AbstractCheckpointable may be defined as None. For example:

class FooHandler(CheckpointableHandler[Foo, None]):

  def is_handleable(self, checkpointable: Foo) -> bool:
    return isinstance(checkpointable, Foo)

  def is_abstract_handleable(self, abstract_checkpointable: None) -> bool:
    return abstract_checkpointable is None
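
To see why a size-dependent predicate matters, the following sketch runs two handlers through a toy resolution loop. Everything here is an assumption made for illustration: `FakeAbstractArray` stands in for an abstract array type, `LARGE_ARRAY_THRESHOLD` is an arbitrary cutoff, and `pick_handler` is a hypothetical first-match rule, not the resolution logic Orbax actually uses.

```python
import dataclasses
import math
from typing import Any, Sequence

LARGE_ARRAY_THRESHOLD = 1_000  # Hypothetical element-count cutoff.


@dataclasses.dataclass(frozen=True)
class FakeAbstractArray:
  """Stand-in for an abstract array type; not a real JAX or Orbax class."""
  shape: tuple[int, ...]
  dtype: str

  @property
  def size(self) -> int:
    return math.prod(self.shape)


class LargeArrayHandler:

  def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool:
    return (
        isinstance(abstract_checkpointable, FakeAbstractArray)
        and abstract_checkpointable.size > LARGE_ARRAY_THRESHOLD
    )


class DefaultArrayHandler:

  def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool:
    return isinstance(abstract_checkpointable, FakeAbstractArray)


def pick_handler(handlers: Sequence[Any], abstract_checkpointable: Any) -> Any:
  """Hypothetical rule: the first handler that accepts the value wins."""
  for handler in handlers:
    if handler.is_abstract_handleable(abstract_checkpointable):
      return handler
  raise ValueError(f'No handler accepts {abstract_checkpointable!r}')


handlers = [LargeArrayHandler(), DefaultArrayHandler()]
big = pick_handler(handlers, FakeAbstractArray(shape=(2000, 10), dtype='f4'))
small = pick_handler(handlers, FakeAbstractArray(shape=(10,), dtype='f4'))
```

Under this toy rule, the specialized handler claims only the arrays it is good at, and everything else falls through to the general one.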

async save(directory, checkpointable)

Saves the given checkpointable to the given directory.

Save should perform any operations that need to block the main thread, such as device-to-host copying of on-device arrays. It then creates a background operation to continue writing the object to the storage location.

IMPORTANT: Do not assume that directory already exists at the start of this method. All directories are created by upper layers of the Orbax library, for performance reasons in a multihost setting and because upper layers also need to modify the directories. Before engaging in any filesystem operations, wait for the directory to exist. For example:

async def _background_save(
    self,
    directory: path_types.PathAwaitingCreation,
    checkpointable: T,
) -> None:
  directory = await directory.await_creation()
  # Write to `directory` here.
  ...

async def save(
    self,
    directory: path_types.PathAwaitingCreation,
    checkpointable: T,
) -> Awaitable[None]:
  # OK to access path properties, as long as we don't touch the actual
  # directory in the filesystem.
  logging.info(directory.name)
  return self._background_save(directory, checkpointable)

Parameters:
  • directory (PathAwaitingCreation) – The directory to save the checkpoint to. Note that the directory should not be expected to exist yet - it is in the process of being created. To wait for it to be created, use await_creation, preferably in a background awaitable to avoid blocking the main thread.

  • checkpointable (~_Checkpointable) – The checkpointable object to save.

Return type:

Awaitable[None]

Returns:

An Awaitable. This object represents the result of the save operation running in the background.
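
The rule about not touching the directory before it exists can be exercised end to end with a stand-in path type. This is a sketch under stated assumptions: `FakePathAwaitingCreation` and its `mark_created` method are hypothetical and only mimic the `await_creation` behavior described above, and `JsonHandler` is a toy handler that does not subclass the real Protocol.

```python
import asyncio
import json
import pathlib
import tempfile
from typing import Any, Awaitable


class FakePathAwaitingCreation:
  """Hypothetical stand-in: a path whose directory is created by someone else."""

  def __init__(self, path: pathlib.Path):
    self._path = path
    self._created = asyncio.Event()

  def mark_created(self) -> None:
    # Plays the role of the upper layer that creates directories.
    self._path.mkdir(parents=True, exist_ok=True)
    self._created.set()

  async def await_creation(self) -> pathlib.Path:
    await self._created.wait()
    return self._path


class JsonHandler:

  async def _background_save(
      self, directory: FakePathAwaitingCreation, checkpointable: dict
  ) -> None:
    path = await directory.await_creation()  # Wait before any filesystem I/O.
    (path / 'data.json').write_text(json.dumps(checkpointable))

  async def save(
      self, directory: FakePathAwaitingCreation, checkpointable: dict
  ) -> Awaitable[None]:
    snapshot = dict(checkpointable)  # Blocking phase: capture the state.
    return self._background_save(directory, snapshot)


async def main() -> dict[str, Any]:
  with tempfile.TemporaryDirectory() as tmp:
    target = pathlib.Path(tmp) / 'ckpt'
    directory = FakePathAwaitingCreation(target)
    background = await JsonHandler().save(directory, {'step': 3})
    existed_before = target.exists()  # save() did not create the directory.
    task = asyncio.create_task(background)
    await asyncio.sleep(0)  # Let the task run until it waits for creation.
    directory.mark_created()
    await task
    saved = json.loads((target / 'data.json').read_text())
  return {'existed_before': existed_before, 'saved': saved}


result = asyncio.run(main())
```

The background task parks on `await_creation` until the directory appears, so `save` itself never performs filesystem operations.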

async load(directory, abstract_checkpointable=None)

Loads the checkpointable from the given directory.

Parameters:
  • directory (Path) – The directory to load the checkpoint from.

  • abstract_checkpointable (Optional[~_AbstractCheckpointable, None]) – An optional abstract representation of the checkpointable to load. If provided, this is used to provide properties to guide the restoration logic of the checkpoint. In the case of arrays, for example, this conveys properties like shape and dtype, for casting and reshaping. In some cases, no information is needed, and AbstractCheckpointable may always be None. In other cases, the abstract representation may be a hard requirement for loading.

Return type:

Awaitable[~_Checkpointable]

Returns:

An Awaitable that continues to load the checkpointable in the background and returns the loaded checkpointable when complete.
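
As a rough illustration of an optional abstract representation guiding restoration, the sketch below casts loaded values only when an abstract value is supplied. `AbstractVector`, `VectorHandler`, and the `vector.json` file name are hypothetical; the handler is a toy and does not subclass the real Protocol.

```python
import asyncio
import dataclasses
import json
import pathlib
import tempfile
from typing import Awaitable, Optional


@dataclasses.dataclass(frozen=True)
class AbstractVector:
  """Hypothetical abstract type: the element type to cast loaded values to."""
  dtype: type


class VectorHandler:

  async def _background_load(
      self, directory: pathlib.Path, abstract: Optional[AbstractVector]
  ) -> list:
    values = json.loads((directory / 'vector.json').read_text())
    if abstract is None:
      return values  # No guidance: return the values as stored.
    return [abstract.dtype(v) for v in values]  # Cast as requested.

  async def load(
      self,
      directory: pathlib.Path,
      abstract_checkpointable: Optional[AbstractVector] = None,
  ) -> Awaitable[list]:
    # Nothing needs to block here, so defer all work to the awaitable.
    return self._background_load(directory, abstract_checkpointable)


async def main() -> tuple[list, list]:
  with tempfile.TemporaryDirectory() as tmp:
    directory = pathlib.Path(tmp)
    (directory / 'vector.json').write_text(json.dumps([1, 2, 3]))
    handler = VectorHandler()
    as_saved = await (await handler.load(directory))
    as_float = await (await handler.load(directory, AbstractVector(float)))
  return as_saved, as_float


as_saved, as_float = asyncio.run(main())
```

Note the double `await`: the first completes the blocking phase of `load`, and the second waits for the returned awaitable to produce the loaded object.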

async metadata(directory)

Returns the metadata for the given directory.

The logic in this method must be executed fully in the main thread; metadata access is expected to be cheap and fast.

In many cases it is desirable to return additional metadata properties beyond the limited set in AbstractCheckpointable. In this case, AbstractCheckpointable should be subclassed, and this subclass can be returned from metadata.

Parameters:

directory (Path) – The directory where the checkpoint is located.

Returns:

The metadata is an AbstractCheckpointable, which is the abstract representation of the checkpointable.

Return type:

AbstractT
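
The subclassing suggestion can be sketched with dataclasses. The names `AbstractFoo`, `FooMetadata`, `created_by`, and the `metadata.json` sidecar file are all hypothetical, chosen only to show a metadata type that remains usable wherever the abstract type is expected.

```python
import asyncio
import dataclasses
import json
import pathlib
import tempfile


@dataclasses.dataclass(frozen=True)
class AbstractFoo:
  """Hypothetical abstract type: just enough information to guide loading."""
  length: int


@dataclasses.dataclass(frozen=True)
class FooMetadata(AbstractFoo):
  """Subclass carrying extra properties that loading does not need."""
  created_by: str


class FooHandler:

  async def metadata(self, directory: pathlib.Path) -> FooMetadata:
    # Reads one small sidecar file and returns the result directly;
    # no background awaitable is involved.
    raw = json.loads((directory / 'metadata.json').read_text())
    return FooMetadata(length=raw['length'], created_by=raw['created_by'])


async def main() -> FooMetadata:
  with tempfile.TemporaryDirectory() as tmp:
    directory = pathlib.Path(tmp)
    (directory / 'metadata.json').write_text(
        json.dumps({'length': 4, 'created_by': 'trainer'})
    )
    return await FooHandler().metadata(directory)


meta = asyncio.run(main())
```

Because `FooMetadata` is an `AbstractFoo`, a value returned from `metadata` could in principle be passed back as the abstract checkpointable for a later load.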

is_handleable(checkpointable)

Returns whether the handler can handle the given checkpointable.

The method should return True if it is possible to save such an object.

See class docstring for more details.

Parameters:

checkpointable (Any) – A concrete checkpointable, for saving.

Return type:

bool

Returns:

True if the handler can handle the given checkpointable.

__init__(*args, **kwargs)

classmethod __subclasshook__(other)

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).

is_abstract_handleable(abstract_checkpointable)

Returns whether the handler can handle the abstract checkpointable.

The method should return True if it is possible to use the given abstract_checkpointable for loading a concrete Checkpointable. Note that None is always considered handleable for loading, so this method does not need to check for it. If an implementation defines AbstractCheckpointable as None, then this method should only return True for values of None.

See class docstring for more details.

Parameters:

abstract_checkpointable (Any) – An abstract checkpointable, for loading.

Return type:

UnionType[bool, None]

Returns:

True if the handler can handle the given abstract checkpointable. None if the handler cannot decide whether it can handle the abstract checkpointable and defers to the typestr.