Proto Checkpointable Handler

Public API for CheckpointableHandlers.

ProtoHandler

final class orbax.checkpoint.experimental.v1.handlers.ProtoHandler(filename='proto.pbtxt')

Implementation of CheckpointableHandler for protocol buffers.

ProtoHandler manages the serialization and deserialization of Protocol Buffer messages in text format. It utilizes an asynchronous two-tier execution model to offload I/O operations, ensuring background writing does not block the main process. In distributed environments, it provides multihost coordination to ensure that only the primary host performs the write operation.

Note: Users are encouraged not to instantiate or use this handler directly. Instead, use the top-level APIs such as ocp.save_checkpointables and ocp.load_checkpointables. Orbax uses this handler by default for standard protocol buffer messages.

To save a custom Protocol Buffer message and explicitly force Orbax to use the ProtoHandler (e.g., to specify a custom filename), the recommended approach is to use ocp.Context with CheckpointablesOptions. This lets you bind the handler to a specific dictionary key within the Context scope.

See CheckpointablesOptions for more details on handler registration.

Example Usage:

Save a protobuf message configuration:

import orbax.checkpoint.experimental.v1 as ocp

# Assuming MyProtoMessage is your compiled protobuf class and `path` is the
# destination checkpoint directory.
my_proto_msg = MyProtoMessage(config_field="value")

registry = ocp.handlers.local_registry()
registry.add(
    ocp.handlers.ProtoHandler, checkpointable_name="proto_config"
)
ctx = ocp.Context()
ctx.checkpointables.registry = registry
with ctx:
    ocp.save_checkpointables(path, dict(proto_config=my_proto_msg))

filename

An optional filename used for saving and loading the protobuf data. If not provided, it defaults to 'proto.pbtxt'.

Type:

str

__init__(filename='proto.pbtxt')

Initializes ProtoHandler.

async save(directory, checkpointable)

Saves the given checkpointable to the given directory.

Save should perform any operations that need to block the main thread, such as device-to-host copying of on-device arrays. It then creates a background operation to continue writing the object to the storage location.

IMPORTANT: Do not assume that directory already exists at the start of this method. All directories are created by upper layers of the Orbax library, for performance reasons in a multihost setting and because upper layers also need to modify the directories. Before engaging in any filesystem operations, wait for the directory to exist. For example:

async def _background_save(
    self,
    directory: path_types.PathAwaitingCreation,
    checkpointable: T,
) -> None:
  directory = await directory.await_creation()
  # Write to `directory` here.
  ...

async def save(
    self,
    directory: path_types.PathAwaitingCreation,
    checkpointable: T,
) -> Awaitable[None]:
  # OK to access path properties, as long as we don't touch the actual
  # directory in the filesystem.
  logging.info(directory.name)
  return self._background_save(directory, checkpointable)

Parameters:
  • directory (PathAwaitingCreation) – The directory to save the checkpoint to. Note that the directory should not be expected to exist yet - it is in the process of being created. To wait for it to be created, use await_creation, preferably in a background awaitable to avoid blocking the main thread.

  • checkpointable (Message) – The checkpointable object to save.

Return type:

Awaitable[None]

Returns:

An Awaitable. This object represents the result of the save operation running in the background.
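The two-tier shape described above can be tried out without Orbax. The following is a minimal sketch using only the standard library: FakePathAwaitingCreation and TextHandler are hypothetical stand-ins invented for illustration, not Orbax classes, and the fake path creates its own directory, whereas in Orbax an upper layer does that.

```python
import asyncio
import pathlib
import tempfile
from typing import Awaitable


class FakePathAwaitingCreation:
  """Hypothetical stand-in for PathAwaitingCreation (not the Orbax class)."""

  def __init__(self, path: pathlib.Path):
    self._path = path

  @property
  def name(self) -> str:
    # Pure path property: does not touch the filesystem.
    return self._path.name

  async def await_creation(self) -> pathlib.Path:
    # Assumption: we create the directory ourselves to stay self-contained.
    await asyncio.to_thread(self._path.mkdir, parents=True, exist_ok=True)
    return self._path


class TextHandler:
  """Toy handler (hypothetical) that writes a string in two tiers."""

  def __init__(self, filename: str = 'data.txt'):
    self._filename = filename

  async def _background_save(self, directory, checkpointable: str) -> None:
    path = await directory.await_creation()
    await asyncio.to_thread((path / self._filename).write_text, checkpointable)

  async def save(self, directory, checkpointable: str) -> Awaitable[None]:
    # Tier 1: any blocking work would happen here (none needed for a string).
    # Tier 2: hand back the un-awaited coroutine for the caller to finish.
    return self._background_save(directory, checkpointable)


async def demo() -> str:
  with tempfile.TemporaryDirectory() as tmp:
    target = pathlib.Path(tmp) / 'ckpt'
    handler = TextHandler()
    background = await handler.save(FakePathAwaitingCreation(target), 'hello')
    # Nothing has been written yet; the directory does not exist.
    assert not target.exists()
    await background  # Complete the background write.
    return (target / 'data.txt').read_text()


result = asyncio.run(demo())
```

Awaiting save only completes the first tier; the returned awaitable must itself be awaited before the data can be assumed to be on disk.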

async load(directory, abstract_checkpointable=None)

Loads the checkpointable from the given directory.

Parameters:
  • directory (Path) – The directory to load the checkpoint from.

  • abstract_checkpointable (Optional[Type[Message], None]) – An optional abstract representation of the checkpointable to load. If provided, this is used to provide properties to guide the restoration logic of the checkpoint. In the case of arrays, for example, this conveys properties like shape and dtype, for casting and reshaping. In some cases, no information is needed, and AbstractCheckpointable may always be None. In other cases, the abstract representation may be a hard requirement for loading.

Return type:

Awaitable[Message]

Returns:

An Awaitable that continues to load the checkpointable in the background and returns the loaded checkpointable when complete.
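As a rough illustration of the load contract, the sketch below uses a hypothetical ValueHandler (not part of Orbax) that reads a text file. The optional abstract_checkpointable is modeled here as a target type, loosely mirroring how a message class could guide parsing; this mapping is an assumption made for the example.

```python
import asyncio
import pathlib
import tempfile
from typing import Any, Awaitable, Optional


class ValueHandler:
  """Toy handler (hypothetical) that loads a value from a text file."""

  def __init__(self, filename: str = 'value.txt'):
    self._filename = filename

  async def _background_load(
      self, directory: pathlib.Path, target_type: type
  ) -> Any:
    text = await asyncio.to_thread((directory / self._filename).read_text)
    return target_type(text)

  async def load(
      self,
      directory: pathlib.Path,
      abstract_checkpointable: Optional[type] = None,
  ) -> Awaitable[Any]:
    # With no abstract checkpointable, fall back to a default target type.
    target_type = abstract_checkpointable or str
    return self._background_load(directory, target_type)


async def demo():
  with tempfile.TemporaryDirectory() as tmp:
    directory = pathlib.Path(tmp)
    (directory / 'value.txt').write_text('42')
    handler = ValueHandler()
    # The first await starts the load; the second waits for the result.
    as_text = await (await handler.load(directory))
    as_int = await (await handler.load(directory, int))
    return as_text, as_int


loaded = asyncio.run(demo())
```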

async metadata(directory)

Returns the metadata for the given directory.

The logic in this method must be executed fully in the main thread; metadata access is expected to be cheap and fast.

In many cases it is desirable to return additional metadata properties beyond the limited set in AbstractCheckpointable. In this case, AbstractCheckpointable should be subclassed, and this subclass can be returned from metadata.

Parameters:

directory (Path) – The directory where the checkpoint is located.

Returns:

The metadata is an AbstractCheckpointable, which is the abstract representation of the checkpointable.

Return type:

AbstractT

is_handleable(checkpointable)

Returns whether the handler can handle the given checkpointable.

The method should return True if it is possible to save such an object.

See class docstring for more details.

Parameters:

checkpointable (Any) – A concrete checkpointable, for saving.

Return type:

bool

Returns:

True if the handler can handle the given checkpointable.

is_abstract_handleable(abstract_checkpointable)

Returns whether the handler can handle the abstract checkpointable.

The method should return True if it is possible to use the given abstract_checkpointable for loading a concrete Checkpointable. Note that None is always considered handleable for loading, so this method does not need to check for it. If an implementation defines AbstractCheckpointable as None, then this method should only return True for values of None.

See class docstring for more details.

Parameters:

abstract_checkpointable (Any) – An abstract checkpointable, for loading.

Return type:

bool

Returns:

True if the handler can handle the given abstract checkpointable. None if the handler cannot decide whether it can handle the abstract checkpointable and defers to the typestr.

classmethod __subclasshook__(other)

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
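To make the return-value convention concrete, here is a small sketch using the standard abc module. SupportsSave, Writer, and Unrelated are hypothetical names chosen for illustration and are unrelated to Orbax's own class hierarchy.

```python
import abc


class SupportsSave(abc.ABC):
  """Hypothetical ABC: any class defining `save` counts as a subclass."""

  @classmethod
  def __subclasshook__(cls, other):
    if cls is SupportsSave:
      # Look for a `save` attribute anywhere in the candidate's MRO.
      if any('save' in vars(base) for base in other.__mro__):
        return True
    # Fall back to the normal issubclass() algorithm.
    return NotImplemented


class Writer:
  def save(self, directory, checkpointable):
    return None


class Unrelated:
  pass


writer_ok = issubclass(Writer, SupportsSave)
unrelated_ok = issubclass(Unrelated, SupportsSave)
```

Writer never inherits from SupportsSave, yet the hook makes issubclass report it as a subclass; Unrelated falls through to the normal algorithm because the hook returns NotImplemented.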