Checkpointers#
Defines exported symbols for the namespace package orbax.checkpoint.
AbstractCheckpointer#
- class orbax.checkpoint.AbstractCheckpointer(*args, **kwargs)
An interface allowing atomic save and restore for a single object.
Typically, an implementation of this class should rely on a CheckpointHandler object, to which type-specific logic can be delegated. In this way, the Checkpointer can be used for many different types, while itself handling only the common logic related to atomicity, synchronization, or asynchronous thread management.
- abstractmethod save(directory, *args, **kwargs)
Saves the given item to the provided directory.
- Parameters:
directory (Union[str, PathLike]) – a path to which to save.
*args – additional args to provide to the CheckpointHandler’s save method.
**kwargs – additional keyword args to provide to the CheckpointHandler’s save method.
- abstractmethod restore(directory, *args, **kwargs)
Restores from the provided directory.
Delegates to underlying handler.
- Parameters:
directory (Union[str, PathLike]) – a path to restore from.
*args – additional args to provide to the CheckpointHandler’s restore method.
**kwargs – additional keyword args to provide to the CheckpointHandler’s restore method.
- Return type:
Any
- Returns:
a restored object
- structure(directory)
DEPRECATED.
The structure of the saved object at directory.
Delegates to underlying handler.
- Parameters:
directory (Union[str, PathLike]) – a path to a saved checkpoint.
- Return type:
Optional[Any, None]
- Returns:
the object structure, or None if the underlying handler does not implement structure.
- metadata(directory)
Returns metadata about the saved item.
Ideally, this is a cheap way to collect information about the checkpoint without requiring a full restoration.
- Parameters:
directory (Union[str, PathLike]) – the directory where the checkpoint is located.
- Return type:
Optional[Any, None]
- Returns:
item metadata
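The split of responsibilities described for AbstractCheckpointer can be sketched with the standard library alone. This is a minimal illustration, not Orbax code: `Handler`, `JsonHandler`, and `SimpleCheckpointer` are hypothetical names, and the only "common logic" shown is directory creation.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Protocol, Union


class Handler(Protocol):
    """Hypothetical stand-in for a CheckpointHandler: type-specific logic only."""

    def save(self, directory: Path, item: Any) -> None: ...

    def restore(self, directory: Path) -> Any: ...


class JsonHandler:
    """Knows how to write and read one kind of object (JSON-serializable data)."""

    def save(self, directory: Path, item: Any) -> None:
        (directory / "item.json").write_text(json.dumps(item))

    def restore(self, directory: Path) -> Any:
        return json.loads((directory / "item.json").read_text())


class SimpleCheckpointer:
    """Owns the common logic and forwards everything else to the handler."""

    def __init__(self, handler: Handler) -> None:
        self._handler = handler

    def save(self, directory: Union[str, os.PathLike], *args: Any, **kwargs: Any) -> None:
        path = Path(directory)
        path.mkdir(parents=True, exist_ok=False)  # common logic: prepare the directory
        self._handler.save(path, *args, **kwargs)  # type-specific logic: delegated

    def restore(self, directory: Union[str, os.PathLike], *args: Any, **kwargs: Any) -> Any:
        return self._handler.restore(Path(directory), *args, **kwargs)


ckpt_dir = Path(tempfile.mkdtemp()) / "step_0"
ckptr = SimpleCheckpointer(JsonHandler())
ckptr.save(ckpt_dir, {"step": 0, "loss": 0.5})
restored = ckptr.restore(ckpt_dir)
```

Swapping in a different handler changes what can be saved without touching the checkpointer itself, which is the point of the delegation.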
Checkpointer#
- class orbax.checkpoint.Checkpointer(handler, *, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), file_options=FileOptions(path_permission_mode=None), checkpoint_metadata_store=None, temporary_path_class=None)
A synchronous implementation of AbstractCheckpointer.
This class saves synchronously to a given directory using an underlying CheckpointHandler. Atomicity of the operation is guaranteed.
IMPORTANT: Async checkpointing can often be faster for saving. Strongly consider using AsyncCheckpointer instead.
IMPORTANT: Remember that to save and restore a checkpoint, one should always use an AbstractCheckpointer coupled with a CheckpointHandler. The specific CheckpointHandler to use depends on the object being saved or restored.
Basic example:
ckptr = Checkpointer(StandardCheckpointHandler())
args = ocp.args.StandardSave(item=pytree_of_arrays)
ckptr.save(path, args=args)
args = ocp.args.StandardRestore(item=abstract_pytree_target)
ckptr.restore(path, args=args)
Each handler includes …SaveArgs and …RestoreArgs classes that document what arguments are expected. When using Checkpointer, you can either use this dataclass directly, or you can provide the arguments in keyword form.
For example:
ckptr = Checkpointer(StandardCheckpointHandler())
ckptr.save(path, state=pytree_of_arrays)
ckptr.restore(path, state=abstract_pytree_target)
- __init__(handler, *, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), file_options=FileOptions(path_permission_mode=None), checkpoint_metadata_store=None, temporary_path_class=None)
- save(directory, *args, force=False, custom_metadata=None, **kwargs)
Saves the given item to the provided directory.
Delegates to the underlying CheckpointHandler. Ensures save operation atomicity.
This method should be called by all hosts - process synchronization and actions that need to be performed on only one host are managed internally.
- Parameters:
directory (Union[str, PathLike]) – a path to which to save.
*args – additional args to provide to the CheckpointHandler’s save method.
force (bool) – if True, allows overwriting an existing directory. May add overhead due to the need to delete any existing files.
custom_metadata (UnionType[dict[str, Any], None]) – a dictionary of custom metadata to be written to the checkpoint directory via StepMetadata.
**kwargs – additional keyword args to provide to the CheckpointHandler’s save method.
- Raises:
ValueError – if the provided directory already exists and force is not True.
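One common way to get the atomicity and force behavior described above is to write into a temporary directory and rename it into place once writing succeeds. The sketch below shows that general technique using only the standard library; it is not Orbax's implementation, and `atomic_save`, `write_fn`, and the `.tmp` suffix are hypothetical choices made for illustration.

```python
import os
import shutil
import tempfile
from pathlib import Path
from typing import Callable


def atomic_save(directory, write_fn: Callable[[Path], None], force: bool = False) -> None:
    """Writes via write_fn into a temporary directory, then renames it into place."""
    final = Path(directory)
    if final.exists():
        if not force:
            raise ValueError(f"Destination {final} already exists.")
        shutil.rmtree(final)  # the deletion overhead that force=True may incur
    tmp = final.with_name(final.name + ".tmp")  # hypothetical naming convention
    if tmp.exists():
        shutil.rmtree(tmp)  # leftover from an earlier interrupted save
    tmp.mkdir(parents=True)
    write_fn(tmp)
    # Renaming a directory is atomic on POSIX when both paths are on the same
    # filesystem, so readers see either no checkpoint or a complete one.
    os.rename(tmp, final)


root = Path(tempfile.mkdtemp())
target = root / "ckpt"
atomic_save(target, lambda d: (d / "data.txt").write_text("v1"))

try:
    atomic_save(target, lambda d: (d / "data.txt").write_text("v2"))
    overwrote_without_force = True
except ValueError:
    overwrote_without_force = False

atomic_save(target, lambda d: (d / "data.txt").write_text("v2"), force=True)
contents = (target / "data.txt").read_text()
```

The second call fails because the destination exists, while the third succeeds at the cost of deleting the old directory first.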
AsyncCheckpointer#
- class orbax.checkpoint.AsyncCheckpointer(handler, timeout_secs=None, *, async_options=AsyncOptions(timeout_secs=600, barrier_sync_fn=None, post_finalization_callback=None, create_directories_asynchronously=True), multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), file_options=FileOptions(path_permission_mode=None), checkpoint_metadata_store=None, temporary_path_class=None)
An asynchronous implementation of Checkpointer.
Save operations take place in a background thread (this functionality is provided by AsyncManager). Users should call wait_until_finished to block until a save operation running in the background is complete.
Like its parent, AsyncCheckpointer also makes use of an underlying CheckpointHandler to deal with type-specific logic.
Please see the Checkpointer documentation for more generic usage instructions.
- __init__(handler, timeout_secs=None, *, async_options=AsyncOptions(timeout_secs=600, barrier_sync_fn=None, post_finalization_callback=None, create_directories_asynchronously=True), multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), file_options=FileOptions(path_permission_mode=None), checkpoint_metadata_store=None, temporary_path_class=None)
- save(directory, *args, force=False, custom_metadata=None, **kwargs)
Saves the given item to the provided directory.
Delegates to the underlying CheckpointHandler. Ensures save operation atomicity. The call first blocks until any previous save operations running in the background have completed.
This method should be called by all hosts - process synchronization and actions that need to be performed on only one host are managed internally.
- Parameters:
directory (Union[str, PathLike]) – a path to which to save.
*args – additional args to provide to the CheckpointHandler’s save method.
force (bool) – if True, allows overwriting an existing directory. May add overhead due to the need to delete any existing files.
custom_metadata (UnionType[dict[str, Any], None]) – a dictionary of custom metadata to be written to the checkpoint directory via StepMetadata.
**kwargs – additional keyword args to provide to the CheckpointHandler’s save method.
- Raises:
ValueError – if the provided directory already exists and force is not True.
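The control flow described for AsyncCheckpointer (save returns quickly, a new save waits for the previous one, and wait_until_finished blocks until the background work is done) can be sketched with a plain thread. This is an illustration of the pattern only, not Orbax's AsyncManager; `BackgroundSaver` and `slow_write` are hypothetical names, and error propagation from the background thread is omitted.

```python
import threading
import time
from typing import Callable, List, Optional


class BackgroundSaver:
    """Runs each save in a background thread, one at a time."""

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None

    def save(self, write_fn: Callable[[], None]) -> None:
        # A new save first blocks until any previous save has completed.
        self.wait_until_finished()
        self._thread = threading.Thread(target=write_fn)
        self._thread.start()  # returns immediately; writing continues in the background

    def wait_until_finished(self) -> None:
        if self._thread is not None:
            self._thread.join()
            self._thread = None


log: List[str] = []


def slow_write(tag: str) -> Callable[[], None]:
    def _write() -> None:
        time.sleep(0.05)  # stands in for slow checkpoint I/O
        log.append(tag)

    return _write


saver = BackgroundSaver()
saver.save(slow_write("step_1"))
saver.save(slow_write("step_2"))  # waits for step_1 before starting
saver.wait_until_finished()  # block before reading the results or exiting
```

Callers that need the checkpoint on disk, for example before restoring it or ending the program, should wait explicitly, as in the last line.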
StandardCheckpointer#
- class orbax.checkpoint.StandardCheckpointer(*, async_options=AsyncOptions(timeout_secs=600, barrier_sync_fn=None, post_finalization_callback=None, create_directories_asynchronously=True), multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), file_options=FileOptions(path_permission_mode=None), checkpoint_metadata_store=None, temporary_path_class=None, **kwargs)
Shorthand class for an AsyncCheckpointer that uses a StandardCheckpointHandler.
Note that this Checkpointer saves asynchronously.
Initialization:
# Instead of:
with AsyncCheckpointer(StandardCheckpointHandler()) as ckptr:
  ...
# We can use:
with StandardCheckpointer() as ckptr:
  ...
This class is convenient because ocp.args does not need to be specified when saving and restoring.
Saving/restoring:
# Instead of:
with AsyncCheckpointer(StandardCheckpointHandler()) as ckptr:
  ckptr.save(directory, args=StandardSave(state, save_args))
  ckptr.restore(directory, args=StandardRestore(abstract_target))
# We can use:
with StandardCheckpointer() as ckptr:
  ckptr.save(directory, state, save_args=save_args)
  ckptr.restore(directory, abstract_target)
- __init__(*, async_options=AsyncOptions(timeout_secs=600, barrier_sync_fn=None, post_finalization_callback=None, create_directories_asynchronously=True), multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), file_options=FileOptions(path_permission_mode=None), checkpoint_metadata_store=None, temporary_path_class=None, **kwargs)
Constructor.
- Parameters:
async_options (AsyncOptions) – See superclass documentation.
multiprocessing_options (MultiprocessingOptions) – See superclass documentation.
file_options (FileOptions) – See superclass documentation.
checkpoint_metadata_store (Optional[MetadataStore, None]) – See superclass documentation.
temporary_path_class (Optional[Type[TemporaryPath], None]) – See superclass documentation.
**kwargs – Additional init args passed to StandardCheckpointHandler. See orbax.checkpoint.standard_checkpoint_handler.StandardCheckpointHandler.
- save(directory, state, *, save_args=None, force=False, custom_metadata=None)
Saves a checkpoint asynchronously (does not block).
- Parameters:
directory (Union[str, PathLike]) – Path where the checkpoint will be saved.
state (Any) – a PyTree of arrays to be saved.
save_args (Optional[Any, None]) – a PyTree with the same structure as state, which consists of ocp.SaveArgs objects as values. None can be used for values where no SaveArgs are specified. Only necessary for fine-grained customization of saving behavior for individual parameters.
force (bool) – See superclass documentation.
custom_metadata (UnionType[dict[str, Any], None]) – a dictionary of custom metadata to be written to the checkpoint directory via StepMetadata.
- restore(directory, target=None, *, strict=True)
Restores a checkpoint.
- Parameters:
directory (Union[str, PathLike]) – Path from which the checkpoint will be restored.
target (Optional[Any, None]) – a PyTree representing the expected structure of the checkpoint. Values may be either real array or scalar values, or they may be jax.ShapeDtypeStruct. If a real value is provided, it will be restored as the given type, with the given properties. If jax.ShapeDtypeStruct is provided, the value will be restored as np.ndarray, unless sharding is specified. If target is a custom PyTree class, the tree will be restored with the same structure as provided. If target is not provided, the checkpoint is restored as a serialized nested dict representation of the custom class.
strict (bool) – if False, restoration allows silent truncating/padding of arrays if the stored array shape does not match the target shape. Otherwise, raises an error.
- Return type:
Any
- Returns:
The restored checkpoint.
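The effect of the strict flag on a shape mismatch can be illustrated on a one-dimensional list. This is a sketch of the described semantics only, not Orbax's implementation: `fit_to_length` is a hypothetical helper, and padding with zeros is an assumption made for the example.

```python
from typing import List


def fit_to_length(stored: List[float], target_len: int, strict: bool = True) -> List[float]:
    """Returns stored values adjusted to target_len, or raises when strict."""
    if len(stored) == target_len:
        return list(stored)
    if strict:
        raise ValueError(
            f"Stored length {len(stored)} does not match target length {target_len}."
        )
    if len(stored) > target_len:
        return list(stored[:target_len])  # silently truncate
    # Silently pad; the zero fill value is an assumption of this sketch.
    return list(stored) + [0.0] * (target_len - len(stored))


truncated = fit_to_length([1.0, 2.0, 3.0], 2, strict=False)
padded = fit_to_length([1.0, 2.0, 3.0], 5, strict=False)

try:
    fit_to_length([1.0, 2.0, 3.0], 2)  # strict defaults to True
    strict_raised = False
except ValueError:
    strict_raised = True
```

Because the non-strict path changes data without any warning, leaving strict at its default is the safer choice unless truncation or padding is intended.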