AbstractCheckpointManager

Abstract class to manage checkpoints: AbstractCheckpointManager.

class orbax.checkpoint.abstract_checkpoint_manager.AbstractCheckpointManager(*args, **kwargs)

Interface to manage checkpoints.

Allows a user to save and restore objects for which a Checkpointer implementation exists (e.g. PyTreeCheckpointer for PyTrees). The class keeps track of multiple checkpointable objects in the following structure:

path/to/directory/    (top-level directory)
  0/    (step)
    params/    (first saveable)
      ...
    metadata/    (second saveable)
      ...
  1/    (step)
    ...
  2/    (step)
    ...
  ...
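
As a rough illustration of the layout above, the sketch below builds a throwaway directory tree and recovers the step numbers from it. It uses only the standard library. `list_step_dirs` is a hypothetical helper, not an Orbax function, and it assumes step directories are named with plain integers as in the diagram.

```python
import tempfile
from pathlib import Path


def list_step_dirs(root: Path) -> list[int]:
    """Return the integer-named subdirectories of root, sorted ascending."""
    return sorted(
        int(child.name)
        for child in root.iterdir()
        if child.is_dir() and child.name.isdigit()
    )


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Recreate the structure from the diagram: <root>/<step>/<item>/
    for step in (0, 1, 2):
        for item in ("params", "metadata"):
            (root / str(step) / item).mkdir(parents=True)
    # A stray file at the top level is ignored by the helper.
    (root / "notes.txt").write_text("not a step")

    steps = list_step_dirs(root)
    items_at_step_0 = sorted(p.name for p in (root / "0").iterdir())
```
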

abstract property directory: Path

Returns the top-level directory containing checkpoints for all items.

Return type:

Path

abstract all_steps(read=False)

Returns all steps tracked by the manager.

Parameters:

read (bool) – If True, forces a read directly from the storage location. Otherwise, a cached result can be returned.

Return type:

Sequence[int]

Returns:

A sequence of steps (integers)
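
The effect of the read flag can be sketched with a small cache over a directory listing. This is an illustration only: `StepTracker` is a hypothetical class, and the caching policy shown (cache on first call, refresh only when read=True) is an assumption rather than a description of any concrete manager.

```python
import tempfile
from pathlib import Path
from typing import Optional, Sequence


class StepTracker:
    """Hypothetical helper that caches the step list of a directory."""

    def __init__(self, directory: Path):
        self._directory = directory
        self._cached: Optional[list[int]] = None

    def _read_steps(self) -> list[int]:
        # Scan storage for integer-named step directories.
        return sorted(
            int(p.name)
            for p in self._directory.iterdir()
            if p.is_dir() and p.name.isdigit()
        )

    def all_steps(self, read: bool = False) -> Sequence[int]:
        # Hit storage only when asked to, or when nothing is cached yet.
        if read or self._cached is None:
            self._cached = self._read_steps()
        return list(self._cached)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "0").mkdir()
    tracker = StepTracker(root)
    first = tracker.all_steps()
    (root / "1").mkdir()  # created without the tracker's knowledge
    cached = tracker.all_steps()
    fresh = tracker.all_steps(read=True)
```

Here the cached call misses the new step directory, while the call with read=True picks it up.
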

abstract latest_step()

Returns the latest step saved.

Returns None if no steps have been saved.

Return type:

Optional[int]

Returns:

A step (int) or None if no steps are present.

abstract best_step()

Returns the best step saved, as defined by options.best_fn.

Returns None if no steps have been saved.

Return type:

Optional[int]

Returns:

A step (int) or None if no steps are present.
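
One way to picture the selection is as a ranking of recorded metrics by a scoring function. The sketch below is not Orbax code: `pick_best_step` and its `higher_is_better` flag are hypothetical, and the shape of the metrics dictionaries is assumed for illustration.

```python
from typing import Any, Callable, Mapping, Optional


def pick_best_step(
    metrics_by_step: Mapping[int, Mapping[str, Any]],
    best_fn: Callable[[Mapping[str, Any]], float],
    higher_is_better: bool = True,
) -> Optional[int]:
    """Return the step whose metrics score best, or None if there are none."""
    if not metrics_by_step:
        return None
    choose = max if higher_is_better else min
    # Rank steps by the scalar that best_fn extracts from their metrics.
    return choose(metrics_by_step, key=lambda step: best_fn(metrics_by_step[step]))


recorded = {
    100: {"accuracy": 0.71, "loss": 0.90},
    200: {"accuracy": 0.78, "loss": 0.65},
    300: {"accuracy": 0.75, "loss": 0.70},
}
best_by_accuracy = pick_best_step(recorded, lambda m: m["accuracy"])
best_by_loss = pick_best_step(recorded, lambda m: m["loss"], higher_is_better=False)
nothing_saved = pick_best_step({}, lambda m: m["accuracy"])
```
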

abstract reload()

Performs disk reads to ensure internal properties are up to date.

abstract reached_preemption(step)

Returns True if a preemption sync point has been reached.

Return type:

bool

abstract should_save(step)

Returns True if a checkpoint should be saved for the current step.

This depends on the previous step and the save interval.

Parameters:

step (int) – The step to check.

Return type:

bool

Returns:

True if the checkpoint should be saved.
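
A decision of this kind might look like the following. This is a sketch under stated assumptions, not the logic of any concrete manager: `should_save_sketch` is a hypothetical function, and the rules it applies (skip steps at or before the last saved step, always save listed steps, otherwise save on multiples of the interval) are a guess at reasonable behavior.

```python
from typing import Collection, Optional


def should_save_sketch(
    step: int,
    last_saved_step: Optional[int],
    save_interval_steps: int = 1,
    save_on_steps: Collection[int] = (),
) -> bool:
    # Assumption: never save a step at or before the most recent checkpoint.
    if last_saved_step is not None and step <= last_saved_step:
        return False
    # Assumption: explicitly listed steps are saved regardless of the interval.
    if step in save_on_steps:
        return True
    return step % save_interval_steps == 0


decisions = [
    s
    for s in range(0, 11)
    if should_save_sketch(s, None, save_interval_steps=5, save_on_steps={7})
]
stale = should_save_sketch(5, last_saved_step=10, save_interval_steps=5)
```
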

abstract delete(step)

Deletes a step checkpoint.

abstract save(step, items=None, save_kwargs=None, metrics=None, force=False, args=None)

Saves the provided items.

This method should be called by all hosts - process synchronization and actions that need to be performed on only one host are managed internally.

NOTE: The items and save_kwargs arguments are deprecated, use args instead. Make sure to configure CheckpointManager with item_names.

args should be an instance of a subclass of orbax.checkpoint.args.CheckpointArgs; its specific type indicates which logic is used to save the object. For a typical PyTree of arrays, use StandardSave/StandardRestore.

When constructing the CheckpointManager, if no item_names were provided, it is assumed that we are managing a single object. If item_names were provided, it is assumed that we are managing multiple objects, and args must be orbax.checkpoint.args.CompositeArgs. See below for details.

Example:

# Single item
mngr = ocp.CheckpointManager(directory)
mngr.save(step, args=ocp.args.StandardSave(my_train_state))

# Multiple items
mngr = ocp.CheckpointManager(directory, item_names=('state', 'meta'))
mngr.save(step, args=ocp.args.Composite(
    state=ocp.args.StandardSave(my_train_state),
    meta=ocp.args.JsonSave(my_metadata)
))
Parameters:
  • step (int) – the current step.

  • items (Union[Any, Mapping[str, Any], None]) – a savable object, or a dictionary of object name to savable object.

  • save_kwargs (Union[Mapping[str, Any], Mapping[str, Mapping[str, Any]], None]) – save kwargs for a single Checkpointer, or a dictionary of object name to kwargs needed by the Checkpointer implementation to save the object.

  • metrics (Optional[Any]) – a dictionary of metric name (string) to numeric value to be tracked along with this checkpoint. Required if options.best_fn is set. Allows users to specify a metric value to determine which checkpoints are best and should be kept (in conjunction with options.max_to_keep).

  • force (Optional[bool]) – if True, this method will attempt to save a checkpoint regardless of the result of AbstractCheckpointManager.should_save(step). By default, save will only write a checkpoint to disk when the options permit, e.g. when step is a multiple of options.save_interval_steps or is listed in options.save_on_steps. Setting force=True will not overwrite existing checkpoints.

  • args (Optional[CheckpointArgs]) – CheckpointArgs which is used to save checkpointable objects with the appropriate logic.

Return type:

bool

Returns:

bool indicating whether a save operation was performed.

Raises:
  • ValueError – if track_best was indicated but metrics is not provided.

  • ValueError – directory creation failed.

  • ValueError – if an item is provided for which no Checkpointer is found.

  • ValueError – if the checkpoint already exists.
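
The interaction between the return value, force, and the "already exists" error can be sketched with a toy in-memory stand-in. `InMemoryManager` is hypothetical and implements only the behavior described above; it does not use Orbax, and its interval rule is an assumption.

```python
from typing import Any, Dict, Optional


class InMemoryManager:
    """Toy stand-in for a checkpoint manager; stores items in a dict."""

    def __init__(self, save_interval_steps: int = 1):
        self._interval = save_interval_steps
        self._checkpoints: Dict[int, Any] = {}

    def should_save(self, step: int) -> bool:
        # Assumption: save on multiples of the interval.
        return step % self._interval == 0

    def latest_step(self) -> Optional[int]:
        return max(self._checkpoints) if self._checkpoints else None

    def save(self, step: int, item: Any, force: bool = False) -> bool:
        # Without force, the save is skipped when the options do not permit it.
        if not force and not self.should_save(step):
            return False
        # force=True bypasses should_save but never overwrites a checkpoint.
        if step in self._checkpoints:
            raise ValueError(f"Checkpoint for step {step} already exists.")
        self._checkpoints[step] = item
        return True


mngr = InMemoryManager(save_interval_steps=10)
saved_on_interval = mngr.save(10, {"w": 1.0})
skipped = mngr.save(11, {"w": 1.1})
forced = mngr.save(12, {"w": 1.2}, force=True)
try:
    mngr.save(12, {"w": 9.9}, force=True)
    overwrite_rejected = False
except ValueError:
    overwrite_rejected = True
```

Callers that need to know whether a checkpoint was actually written should check the returned bool rather than assume every call saves.
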

abstract restore(step, items=None, restore_kwargs=None, directory=None, args=None)

Restores from the given step and provided items.

This method should be called by all hosts - process synchronization and actions that need to be performed on only one host are managed internally.

NOTE: The items and restore_kwargs arguments are deprecated, use args instead. Make sure to configure CheckpointManager with item_names. See save docstring for additional details.

Example:

# Single item
mngr = ocp.CheckpointManager(directory)
mngr.restore(step, args=ocp.args.StandardRestore(abstract_train_state))

# Multiple items
mngr = ocp.CheckpointManager(directory, item_names=('state', 'meta'))
mngr.restore(step, args=ocp.args.Composite(
    state=ocp.args.StandardRestore(abstract_train_state),
    meta=ocp.args.JsonRestore(),
))
# If it is acceptable to restore without providing additional arguments,
# and if a save has already been performed, it is ok to do the following:
mngr.restore(step, args=ocp.args.Composite(state=None, meta=None))
# If a save has not already been performed, there is no way for Orbax to
# know how to restore the objects. If a save has already been performed,
# it remembers the logic used to save the objects.
Parameters:
  • step (int) – the step to restore.

  • items (Union[Any, Mapping[str, Any], None]) – a restorable object, or a dictionary of object name to restorable object.

  • restore_kwargs (Union[Mapping[str, Any], Mapping[str, Mapping[str, Any]], None]) – restore kwargs for a single Checkpointer, or a dictionary of object name to kwargs needed by the Checkpointer implementation to restore the object.

  • directory (Union[str, PathLike, None]) – if provided, uses the given directory rather than the directory property of this class. Can be used to restore checkpoints from an independent location.

  • args (Optional[CheckpointArgs]) – CheckpointArgs which is used to restore checkpointable objects with the appropriate logic.

Return type:

Union[Any, Mapping[str, Any], CompositeArgs]

Returns:

If managing a single item, returns a single checkpointable object. If managing multiple items, returns ocp.args.Composite, where the keys are item names, and values are checkpointable objects.

abstract item_metadata(step)

For all Checkpointers, returns any metadata associated with the item.

Calls the metadata method for each Checkpointer and returns a mapping of each item name to the restored metadata. If the manager only manages a single item, a single metadata will be returned instead.

To avoid errors due to missing CheckpointHandlers, a concrete CheckpointManager constructor must accept a mapping from item names to their CheckpointHandlers, so that handlers can be registered without first calling save() or restore(). Note that save() and restore() calls automatically map CheckpointHandlers to their item names and retain that mapping for the lifetime of the CheckpointManager instance.

Example:

# Single item
mngr = ocp.CheckpointManager(directory)
# No calls to save() or restore() before calling item_metadata().
mngr.item_metadata(step)  # Raises error.

# Pass a handler instance so the manager knows how to read the item.
mngr = ocp.CheckpointManager(directory,
    item_handlers=ocp.StandardCheckpointHandler())
# No calls to save() or restore() before calling item_metadata().
metadata = mngr.item_metadata(step)  # Successful.

# Multiple items
mngr = ocp.CheckpointManager(directory, item_names=('state', 'extra'))
# No calls to save() or restore() before calling item_metadata().
mngr.item_metadata(step)  # Raises error.

# Pass one handler instance per item name.
mngr = ocp.CheckpointManager(directory,
    item_names=('state', 'extra'),
    item_handlers={
        'state': ocp.StandardCheckpointHandler(),
        'extra': ocp.PyTreeCheckpointHandler(),
    }
)
# No calls to save() or restore() before calling item_metadata().
metadata = mngr.item_metadata(step)  # Successful.

Metadata may be None for an individual item.

Parameters:

step (int) – Step for which to retrieve metadata.

Return type:

Union[Any, Mapping[str, Any], CompositeArgs]

Returns:

A dictionary mapping name to item metadata, or a single item metadata.

abstract metadata()

Returns CheckpointManager level metadata if present, empty otherwise.

Return type:

Mapping[str, Any]

abstract metrics(step)

Returns metrics for step, if present.

Return type:

Optional[Any]

abstract wait_until_finished()

Blocks until any incomplete save operations are completed.

Note that this method will typically be a no-op if all checkpointers are synchronous, since old checkpoints are already cleaned up immediately after completing save, and there is no background thread to wait for.

If some checkpointers are of type AsyncCheckpointer, however, this method will wait until each of these checkpointers is finished.

abstract check_for_errors()

Checks for any outstanding errors in completed asynchronous save operations.

Delegates to underlying Checkpointer.
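
The division of labor between wait_until_finished() and check_for_errors() can be pictured with a toy asynchronous saver built on the standard library. `AsyncSaver` is hypothetical and unrelated to Orbax's AsyncCheckpointer; it only illustrates the pattern of blocking on background writes and then surfacing their errors.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict, List


class AsyncSaver:
    """Toy asynchronous saver; not an Orbax class."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: List[Future] = []
        self.store: Dict[int, Any] = {}

    def _write(self, step: int, item: Any) -> None:
        if item is None:
            raise ValueError(f"Nothing to write for step {step}.")
        self.store[step] = item

    def save(self, step: int, item: Any) -> None:
        # Returns immediately; the write happens on the worker thread.
        self._pending.append(self._pool.submit(self._write, step, item))

    def wait_until_finished(self) -> None:
        # Block until every submitted write has completed or failed.
        for future in self._pending:
            future.exception()  # waits for the result without raising it

    def check_for_errors(self) -> None:
        # Split the futures in one pass so none is lost if it finishes mid-scan.
        finished: List[Future] = []
        still_running: List[Future] = []
        for future in self._pending:
            (finished if future.done() else still_running).append(future)
        self._pending = still_running
        for future in finished:
            error = future.exception()
            if error is not None:
                raise error

    def close(self) -> None:
        self.wait_until_finished()
        self._pool.shutdown()


saver = AsyncSaver()
saver.save(1, {"w": 1.0})
saver.save(2, None)  # this write fails in the background
saver.wait_until_finished()
try:
    saver.check_for_errors()
    error_message = ""
except ValueError as e:
    error_message = str(e)
saver.close()
```

In this sketch the failure stays silent until check_for_errors() is called, which is why the two calls are usually paired.
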

abstract close()

Waits for outstanding operations to finish and closes Checkpointers.

__init__(*args, **kwargs)

__subclasshook__()

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
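
As a generic illustration of this standard-library hook (unrelated to checkpointing), the sketch below defines an abstract class that treats any class with a close method as a subclass. `SupportsClose`, `Resource`, and `Plain` are hypothetical names.

```python
from abc import ABC


class SupportsClose(ABC):
    @classmethod
    def __subclasshook__(cls, subclass):
        # Only customize the check for SupportsClose itself, not its children.
        if cls is SupportsClose:
            if any("close" in vars(base) for base in subclass.__mro__):
                return True
        # Fall back to the normal issubclass() algorithm.
        return NotImplemented


class Resource:
    def close(self) -> None:
        pass


class Plain:
    pass


resource_matches = issubclass(Resource, SupportsClose)
plain_matches = issubclass(Plain, SupportsClose)
```
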