Checkpointers

Defines exported symbols for the namespace package orbax.checkpoint.

AbstractCheckpointer

class orbax.checkpoint.AbstractCheckpointer

An interface allowing atomic save and restore for a single object.

Typically, an implementation of this class should rely on a CheckpointHandler object, to which type-specific logic can be delegated. In this way, the Checkpointer can be used for many different types, while itself handling only the common logic related to atomicity, synchronization, or asynchronous thread management.
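The division of labor described above can be illustrated with a minimal, standard-library-only sketch. `ToyCheckpointer` and `JsonHandler` are hypothetical names, not part of orbax.checkpoint: the handler owns the type-specific serialization (JSON here), while the checkpointer owns only generic steps such as creating the target directory, and forwards everything else.

```python
import json
import os
import tempfile
from typing import Any


class JsonHandler:
    """Hypothetical handler: owns the type-specific (JSON) serialization logic."""

    def save(self, directory: str, item: Any) -> None:
        with open(os.path.join(directory, "item.json"), "w") as f:
            json.dump(item, f)

    def restore(self, directory: str) -> Any:
        with open(os.path.join(directory, "item.json")) as f:
            return json.load(f)


class ToyCheckpointer:
    """Hypothetical checkpointer: handles generic logic, delegates the rest."""

    def __init__(self, handler: JsonHandler) -> None:
        self._handler = handler

    def save(self, directory: str, *args: Any, **kwargs: Any) -> None:
        os.makedirs(directory)  # generic step: create the target directory
        self._handler.save(directory, *args, **kwargs)  # type-specific step

    def restore(self, directory: str, *args: Any, **kwargs: Any) -> Any:
        return self._handler.restore(directory, *args, **kwargs)


root = tempfile.mkdtemp()
ckptr = ToyCheckpointer(JsonHandler())
ckptr.save(os.path.join(root, "step_1"), {"lr": 0.1, "step": 1})
restored = ckptr.restore(os.path.join(root, "step_1"))
```

Swapping `JsonHandler` for a different handler changes what can be saved without touching the checkpointer, which is the point of the split.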

abstract save(directory, *args, **kwargs)

Saves the given item to the provided directory.

Parameters:
  • directory (Union[str, PathLike]) – a path to which to save.

  • *args – additional args to provide to the CheckpointHandler’s save method.

  • **kwargs – additional keyword args to provide to the CheckpointHandler’s save method.

abstract restore(directory, *args, **kwargs)

Restores from the provided directory.

Delegates to the underlying handler.

Parameters:
  • directory (Union[str, PathLike]) – a path to restore from.

  • *args – additional args to provide to the CheckpointHandler’s restore method.

  • **kwargs – additional keyword args to provide to the CheckpointHandler’s restore method.

Return type:

Any

Returns:

a restored object

structure(directory)

DEPRECATED.

The structure of the saved object at directory.

Delegates to the underlying handler.

Parameters:

directory (Union[str, PathLike]) – a path to a saved checkpoint.

Return type:

Optional[Any]

Returns:

the object structure or None, if the underlying handler does not implement structure.

metadata(directory)

Returns metadata about the saved item.

Ideally, this is a cheap way to collect information about the checkpoint without requiring a full restoration.

Parameters:

directory (Union[str, PathLike]) – the directory where the checkpoint is located.

Return type:

Optional[Any]

Returns:

item metadata
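One way a handler can make metadata cheap is to write a small summary file next to the (potentially large) payload, so that reading metadata never touches the payload. The sketch below assumes that layout; `save_with_metadata`, `read_metadata`, and the `_METADATA` file name are hypothetical illustrations, not the orbax on-disk format.

```python
import json
import os
import tempfile


def save_with_metadata(directory: str, arrays: dict) -> None:
    """Hypothetical save: one payload file per entry plus one small summary file."""
    os.makedirs(directory, exist_ok=True)
    meta = {}
    for name, values in arrays.items():
        with open(os.path.join(directory, f"{name}.json"), "w") as f:
            json.dump(values, f)  # potentially large payload
        meta[name] = {"length": len(values)}  # cheap summary of the payload
    with open(os.path.join(directory, "_METADATA"), "w") as f:
        json.dump(meta, f)


def read_metadata(directory: str) -> dict:
    """Reads only the small summary file; no payload is loaded."""
    with open(os.path.join(directory, "_METADATA")) as f:
        return json.load(f)


ckpt_dir = os.path.join(tempfile.mkdtemp(), "ckpt")
save_with_metadata(ckpt_dir, {"weights": [0.0] * 1000, "bias": [1.0, 2.0]})
info = read_metadata(ckpt_dir)
```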

close()

Closes the Checkpointer.

Checkpointer

class orbax.checkpoint.Checkpointer(handler, primary_host=0, active_processes=None)

A synchronous implementation of AbstractCheckpointer.

This class saves synchronously to a given directory using an underlying CheckpointHandler. Atomicity of the operation is guaranteed.
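A common way to obtain the atomicity guarantee mentioned above is to write into a temporary sibling directory and rename it into place only once every file is written. The following is a minimal sketch of that technique for a local POSIX filesystem; `atomic_save` is a hypothetical helper, and orbax's own mechanism (which must also cover remote storage) may differ.

```python
import json
import os
import shutil
import tempfile


def atomic_save(directory: str, item) -> None:
    """Hypothetical helper: write to a temp directory, then rename it into place.

    Readers never observe a partially written `directory`: either the rename
    happened (complete checkpoint) or it did not (no checkpoint at all).
    """
    parent = os.path.dirname(os.path.abspath(directory))
    tmp_dir = tempfile.mkdtemp(prefix=".tmp_", dir=parent)
    try:
        with open(os.path.join(tmp_dir, "item.json"), "w") as f:
            json.dump(item, f)
        os.rename(tmp_dir, directory)  # atomic on a local POSIX filesystem
    except BaseException:
        shutil.rmtree(tmp_dir, ignore_errors=True)  # leave no partial output
        raise


root = tempfile.mkdtemp()
atomic_save(os.path.join(root, "step_5"), {"step": 5})
```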

IMPORTANT: Async checkpointing can often be faster for saving. Strongly consider using AsyncCheckpointer instead.

IMPORTANT: Remember that to save and restore a checkpoint, one should always use an AbstractCheckpointer coupled with a CheckpointHandler. The specific CheckpointHandler to use depends on the object being saved or restored.

Basic example:

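import orbax.checkpoint as ocp

# `path`, `pytree_of_arrays`, and `abstract_pytree_target` are assumed to be defined.
ckptr = ocp.Checkpointer(ocp.StandardCheckpointHandler())
args = ocp.args.StandardSave(pytree_of_arrays)
ckptr.save(path, args=args)
args = ocp.args.StandardRestore(abstract_pytree_target)
restored = ckptr.restore(path, args=args)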

Each handler includes …SaveArgs and …RestoreArgs classes that document the expected arguments. When using Checkpointer, you can either pass these dataclasses directly or provide their fields as keyword arguments.

For example:

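import orbax.checkpoint as ocp

ckptr = ocp.Checkpointer(ocp.StandardCheckpointHandler())
# Keywords map onto the fields of the handler's args dataclasses (here `item`).
ckptr.save(path, item=pytree_of_arrays)
restored = ckptr.restore(path, item=abstract_pytree_target)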
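Conceptually, accepting either form amounts to building the args dataclass from whatever was passed. The sketch below shows that idea with a hypothetical `ToySaveArgs` dataclass and a hypothetical `to_save_args` helper; it is not the orbax implementation, only an illustration of why both call styles can end up equivalent.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class ToySaveArgs:
    """Hypothetical stand-in for a handler's ...SaveArgs dataclass."""

    item: Any


def to_save_args(*positional: Any, **kwargs: Any) -> ToySaveArgs:
    """Accepts `args=ToySaveArgs(...)`, or the dataclass fields themselves."""
    if "args" in kwargs:
        return kwargs["args"]  # dataclass passed explicitly
    return ToySaveArgs(*positional, **kwargs)  # build it from raw arguments


explicit = to_save_args(args=ToySaveArgs(item={"w": 1}))
by_keyword = to_save_args(item={"w": 1})
by_position = to_save_args({"w": 1})
```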

__init__(handler, primary_host=0, active_processes=None)
save(directory, *args, force=False, **kwargs)

Saves the given item to the provided directory.

Delegates to the underlying CheckpointHandler. Ensures save operation atomicity.

This method should be called by all hosts; process synchronization and actions that need to be performed on only one host are managed internally.
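The multi-host pattern can be pictured as: every process calls save, only the primary host performs one-time steps (creating the shared temporary directory, finalizing the checkpoint), and barriers keep the processes in step. This is a single-machine simulation under stated assumptions: threads stand in for processes, `threading.Barrier` stands in for a cross-process barrier, and `save_on_process`, `NUM_PROCESSES`, and `PRIMARY_HOST` are hypothetical names loosely mirroring the `primary_host` argument.

```python
import os
import tempfile
import threading

NUM_PROCESSES = 4  # hypothetical: threads simulate separate hosts
PRIMARY_HOST = 0   # hypothetical: mirrors the `primary_host` argument
barrier = threading.Barrier(NUM_PROCESSES)  # stand-in for a cross-process barrier
root = tempfile.mkdtemp()
tmp_dir = os.path.join(root, "step_10.tmp")
final_dir = os.path.join(root, "step_10")


def save_on_process(process_index: int) -> None:
    # Only the primary host creates the shared temporary directory.
    if process_index == PRIMARY_HOST:
        os.makedirs(tmp_dir)
    barrier.wait()  # all processes wait until the directory exists
    # Every process writes its own shard.
    with open(os.path.join(tmp_dir, f"shard_{process_index}"), "w") as f:
        f.write(str(process_index))
    barrier.wait()  # all processes wait until every shard is written
    # Only the primary host finalizes (commits) the checkpoint.
    if process_index == PRIMARY_HOST:
        os.rename(tmp_dir, final_dir)


threads = [
    threading.Thread(target=save_on_process, args=(i,))
    for i in range(NUM_PROCESSES)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```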

Parameters:
  • directory (Union[str, PathLike]) – a path to which to save.

  • *args – additional args to provide to the CheckpointHandler’s save method.

  • force (bool) – if True, allows overwriting an existing directory. May add overhead due to the need to delete any existing files.

  • **kwargs – additional keyword args to provide to the CheckpointHandler’s save method.

Raises:

ValueError – if the provided directory already exists and force is False.
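The interaction between an existing directory, `force`, and the ValueError above can be sketched as follows. `prepare_directory` is a hypothetical helper that mirrors the documented behavior; it is not an orbax function.

```python
import os
import shutil
import tempfile


def prepare_directory(directory: str, force: bool = False) -> None:
    """Hypothetical helper mirroring the documented `force` semantics."""
    if os.path.exists(directory):
        if not force:
            raise ValueError(f"Destination {directory} already exists.")
        shutil.rmtree(directory)  # the extra cost of `force`: delete old files
    os.makedirs(directory)


target = os.path.join(tempfile.mkdtemp(), "step_1")
prepare_directory(target)  # first save: directory is created
open(os.path.join(target, "old_file"), "w").close()

try:
    prepare_directory(target)  # second save without force
    raised = False
except ValueError:
    raised = True

prepare_directory(target, force=True)  # overwrite: old contents are removed
```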

restore(directory, *args, **kwargs)

See superclass documentation.

Return type:

Any

metadata(directory)

See superclass documentation.

Return type:

Optional[Any]

close()

Closes the underlying CheckpointHandler.

AsyncCheckpointer

class orbax.checkpoint.AsyncCheckpointer(handler, timeout_secs=300, *, primary_host=0, active_processes=None, barrier_sync_fn=None)

An asynchronous implementation of Checkpointer.

Save operations take place in a background thread (this functionality is provided by AsyncManager). Users should call wait_until_finished to block until a save operation running in the background is complete.

Like its parent, AsyncCheckpointer also makes use of an underlying CheckpointHandler to deal with type-specific logic.

Please see Checkpointer documentation for more generic usage instructions.
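The background-save behavior described above (save returns immediately, a new save first waits for the previous one, errors surface later) can be sketched with a single worker thread. `ToyAsyncSaver` and its methods are hypothetical; they mimic the documented semantics of save, check_for_errors, wait_until_finished, and close, but this is not how AsyncManager is actually implemented.

```python
import threading
import time
from typing import Callable, Optional


class ToyAsyncSaver:
    """Hypothetical sketch of background saving; not the AsyncManager API."""

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None
        self._error: Optional[BaseException] = None

    def save(self, write_fn: Callable[[], None]) -> None:
        self.wait_until_finished()  # block on any previous background save
        self._error = None

        def _run() -> None:
            try:
                write_fn()
            except BaseException as e:  # recorded for check_for_errors()
                self._error = e

        self._thread = threading.Thread(target=_run, daemon=True)
        self._thread.start()  # return immediately; writing continues in background

    def check_for_errors(self) -> None:
        if self._error is not None:
            raise RuntimeError("background save failed") from self._error

    def wait_until_finished(self) -> None:
        if self._thread is not None:
            self._thread.join()
            self._thread = None
        self.check_for_errors()

    def close(self) -> None:
        self.wait_until_finished()  # finish outstanding work before closing


written = []


def write_step_1() -> None:
    time.sleep(0.05)  # simulate slow I/O
    written.append("step_1")


def write_step_2() -> None:
    written.append("step_2")


saver = ToyAsyncSaver()
saver.save(write_step_1)  # returns before step_1 is written
saver.save(write_step_2)  # blocks until step_1 finishes, then starts step_2
saver.close()


def failing_write() -> None:
    raise OSError("disk full")


saver2 = ToyAsyncSaver()
saver2.save(failing_write)
try:
    saver2.wait_until_finished()
    failed = False
except RuntimeError:
    failed = True
```

The ordering guarantee comes entirely from `save` calling `wait_until_finished` first, which is also why a failed background write is reported on the next save or wait rather than being lost.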

__init__(handler, timeout_secs=300, *, primary_host=0, active_processes=None, barrier_sync_fn=None)
save(directory, *args, force=False, **kwargs)

Saves the given item to the provided directory.

Delegates to the underlying CheckpointHandler. Ensures save operation atomicity. First blocks until any previous save operations running in the background have completed.

This method should be called by all hosts; process synchronization and actions that need to be performed on only one host are managed internally.

Parameters:
  • directory (Union[str, PathLike]) – a path to which to save.

  • *args – additional args to provide to the CheckpointHandler’s save method.

  • force (bool) – if True, allows overwriting an existing directory. May add overhead due to the need to delete any existing files.

  • **kwargs – additional keyword args to provide to the CheckpointHandler’s save method.

Raises:

ValueError – if the provided directory already exists and force is False.

restore(directory, *args, **kwargs)

See superclass documentation.

Return type:

Any

check_for_errors()

Surfaces any errors from the background commit operations.

wait_until_finished()

Waits for any outstanding operations to finish.

close()

Waits to finish any outstanding operations before closing.

StandardCheckpointer

class orbax.checkpoint.StandardCheckpointer(primary_host=0)

Shorthand class.

Instead of:

ckptr = Checkpointer(StandardCheckpointHandler())

you can use:

ckptr = StandardCheckpointer()
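A shorthand class of this kind is essentially a subclass whose constructor pre-binds the handler. The sketch below shows that pattern with hypothetical `ToyHandler`, `ToyCheckpointer`, and `ToyShorthandCheckpointer` classes; it assumes, but does not confirm, that the orbax shorthands are built the same way.

```python
from typing import Any


class ToyHandler:
    """Hypothetical handler stand-in (e.g. for StandardCheckpointHandler)."""


class ToyCheckpointer:
    """Hypothetical checkpointer that simply stores its handler and options."""

    def __init__(self, handler: Any, primary_host: int = 0) -> None:
        self.handler = handler
        self.primary_host = primary_host


class ToyShorthandCheckpointer(ToyCheckpointer):
    """Pre-binds the handler, so callers pass only the remaining options."""

    def __init__(self, primary_host: int = 0) -> None:
        super().__init__(ToyHandler(), primary_host=primary_host)


ckptr = ToyShorthandCheckpointer(primary_host=1)
```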

__init__(primary_host=0)

PyTreeCheckpointer

class orbax.checkpoint.PyTreeCheckpointer(primary_host=0, use_ocdbt=True, use_zarr3=False)

Shorthand class.

Instead of:

ckptr = Checkpointer(PyTreeCheckpointHandler())

you can use:

ckptr = PyTreeCheckpointer()

__init__(primary_host=0, use_ocdbt=True, use_zarr3=False)