CheckpointManager

A class providing functionalities for managing a series of checkpoints.

CheckpointManager

class orbax.checkpoint.checkpoint_manager.CheckpointManager(directory, checkpointers=None, options=None, metadata=None, item_names=None, item_handlers=None, logger=None)

A generic, synchronous AbstractCheckpointManager implementation.

__init__(directory, checkpointers=None, options=None, metadata=None, item_names=None, item_handlers=None, logger=None)

CheckpointManager constructor.

IMPORTANT: CheckpointManager has been refactored to provide a new API. Please ensure you have migrated all existing use cases to the newer style by August 1st, 2024. Please see https://orbax.readthedocs.io/en/latest/api_refactor.html for technical details.

The CheckpointManager is ultimately backed by a single Checkpointer, to which saving and restoring is delegated. Apart from step management options, metrics-related logic, and other conveniences, saving and restoring with CheckpointManager is quite similar to using Checkpointer(CompositeCheckpointHandler).

Example:

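import jax
import jax.numpy as jnp
from orbax.checkpoint import args
from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions

# Placeholder inputs; substitute your own train state and metadata.
train_state = {'params': jnp.ones((2, 3))}
custom_metadata = {'notes': 'example run'}
# Shape/dtype-only description of train_state, used as a restore target.
abstract_train_state = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), train_state)

with CheckpointManager(
    'path/to/dir/',
    # Multiple items.
    item_names=('train_state', 'custom_metadata'),
    metadata={'version': 1.1, 'lang': 'en'},
) as mngr:
  mngr.save(0, args=args.Composite(
      train_state=args.StandardSave(train_state),
      custom_metadata=args.JsonSave(custom_metadata),
  ))
  mngr.wait_until_finished()  # Saves are asynchronous by default.
  restored = mngr.restore(0)
  print(restored.train_state)
  print(restored.custom_metadata)
  restored = mngr.restore(0, args=args.Composite(
      train_state=args.StandardRestore(abstract_train_state),
  ))
  print(restored.train_state)
  print(restored.custom_metadata)  # Error, not restored

# Single item, no need to specify `item_names`.
# Uses a separate directory so step 0 does not collide with the example above.
with CheckpointManager(
    'path/to/other_dir/',
    options=CheckpointManagerOptions(max_to_keep=5),  # Add other options as needed.
) as mngr:
  mngr.save(0, args=args.StandardSave(train_state))
  mngr.wait_until_finished()
  train_state = mngr.restore(0)
  train_state = mngr.restore(0, args=args.StandardRestore(abstract_train_state))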

IMPORTANT: Don’t forget to use the keyword argument args=… for save and restore! Otherwise you will get the legacy API. This requirement will go away once the legacy API is removed.

IMPORTANT: The CheckpointManager is designed to be used as a context manager. Use the with CheckpointManager(...) as mngr: pattern for automatic cleanup. If you cannot use a context manager, always call close() to release resources properly. Otherwise, background operations such as deleting old checkpoints might not finish before your program exits.

CheckpointManager:
  • is NOT thread-safe.

  • IS multi-process-safe.

  • is NOT multi-job-safe.

This means that CheckpointManager is intended to be created and called across all processes within a single job, where each process is single-threaded, but it is not safe to use when multiple jobs each have CheckpointManager instances pointing to the same root directory. For example, if you have a trainer job and one or more evaluator jobs, the CheckpointManager should be created and called across all processes in the trainer, but it should not be used in the evaluators. Instead, utilities for use during evaluation can be found in checkpoint_utils (see https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_utils.html).

Parameters:
  • directory (Union[str, PathLike]) – the top level directory in which to save all files.

  • checkpointers (Union[AbstractCheckpointer, Mapping[str, AbstractCheckpointer], None]) – a mapping of object name to Checkpointer object. Items provided to save should have keys matching the keys in this argument. Alternatively, a single Checkpointer may be provided, in which case save and restore should always be called with a single item rather than a dictionary of items. See below for more details. item_names and checkpointers are mutually exclusive and must not be used together. Also, do not use checkpointers and item_handlers together.

  • options (Optional[CheckpointManagerOptions]) – CheckpointManagerOptions. May be provided to specify additional arguments. If None, uses default values of CheckpointManagerOptions.

  • metadata (Optional[Mapping[str, Any]]) – High-level metadata that does not depend on step number. If directory is writable, the given metadata is saved only once: a new CheckpointManager instance pointing at the same directory does not overwrite the existing metadata and ignores the metadata passed to it. If directory is read-only, the given metadata is not saved; such an instance uses the existing metadata if present, and otherwise uses the given metadata.

  • item_names (Optional[Sequence[str]]) – Names of distinct items that may be saved/restored with this CheckpointManager. item_names and checkpointers are mutually exclusive and must not be used together. Also see item_handlers below.

  • item_handlers (Union[CheckpointHandler, Mapping[str, CheckpointHandler], None]) – A mapping of item name to CheckpointHandler. The mapped CheckpointHandler must be registered against the CheckpointArgs input in save/restore operations. Please don’t use checkpointers and item_handlers together. It can be used with or without item_names. The item name key may or may not be present in item_names. Alternatively, a single CheckpointHandler may be provided, in which case save and restore should always be called in a single item context.

  • logger (Optional[AbstractLogger]) – A logger to log checkpointing events.

property directory: Path

See superclass documentation.

Return type:

Path

all_steps(read=False)

See superclass documentation.

Return type:

Sequence[int]

latest_step()

See superclass documentation.

Return type:

Optional[int]

best_step()

See superclass documentation.

Return type:

Optional[int]

reload()

Reloads internal properties.

Resets internal cache of checkpoint steps, in case the directory managed by this object has been updated externally.

reached_preemption(step)

See superclass documentation.

Return type:

bool

should_save(step)

See superclass documentation.

Return type:

bool

delete(step)

See superclass documentation.

Delete can be run asynchronously if CheckpointManagerOptions.enable_background_delete is set to True.

Parameters:

step (int) – The step to delete.

save(step, items=None, save_kwargs=None, metrics=None, force=False, args=None)

See superclass documentation.

Return type:

bool

restore(step, items=None, restore_kwargs=None, directory=None, args=None)

See superclass documentation.

Return type:

Union[Any, Mapping[str, Any]]

item_metadata(step)

See superclass documentation.

Return type:

Union[Any, CompositeArgs]

metrics(step)

Returns metrics for step, if present.

Return type:

Optional[Any]

metadata()

See superclass documentation.

Return type:

Mapping[str, Any]

wait_until_finished()

See superclass documentation.

check_for_errors()

See superclass documentation.

__subclasshook__()

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).

close()

See superclass documentation.

CheckpointManagerOptions

class orbax.checkpoint.checkpoint_manager.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=None, keep_time_interval=None, keep_period=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=None, single_host_load_and_broadcast=False, todelete_subdir=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=<factory>)

Optional arguments for CheckpointManager.

save_interval_steps:

The interval at which checkpoints should be saved: a checkpoint is saved only every save_interval_steps steps. Defaults to 1.

max_to_keep:

If provided, specifies the maximum number of checkpoints to keep. Older checkpoints are removed. By default, does not remove any old checkpoints. Must be None or non-negative. When set, checkpoints may be considered for deletion when there are more than max_to_keep checkpoints present. Checkpoints are kept if they meet any of the conditions below, such as keep_time_interval, keep_period, etc. Any remaining checkpoints that do not meet these conditions are garbage-collected.

keep_time_interval:

When more than max_to_keep checkpoints are present, an older checkpoint that would ordinarily be deleted will be preserved if it has been at least keep_time_interval since the previous preserved checkpoint. The default setting of None does not preserve any checkpoints in this way. For example, this may be used to ensure checkpoints are retained at a frequency of approximately one per hour.

keep_period:

If set, any existing checkpoints matching checkpoint_step % keep_period == 0 will not be deleted.

best_fn:

If set, maintains checkpoints based on the quality of given metrics rather than recency. The function should accept a PyTree of metrics, and return a scalar value that can be used to determine the quality score of the checkpoint. If max_to_keep is also set, then the retained checkpoints will be kept based on their quality, as measured by this function.

best_mode:

One of ['max', 'min']. Determines whether higher ('max') or lower ('min') values returned by best_fn indicate a better checkpoint.

keep_checkpoints_without_metrics:

If False, checkpoints without metrics present are eligible for cleanup. Otherwise, they will never be deleted.

step_prefix:

If provided, step directories will take the form f'{step_prefix}_<step>'. Otherwise, they will simply be an integer <step>.

step_format_fixed_length:

If set to n, formats the step with n digits, padded with leading zeros. This makes sorting steps easier. Otherwise, the step has no leading zeros.
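The naming rule described by step_prefix and step_format_fixed_length can be approximated in plain Python. step_dir_name is a hypothetical helper written for illustration only; Orbax builds these names internally (and step_name_format overrides both options), so treat this as a sketch of the documented behavior, not the library's code.

```python
from typing import Optional


def step_dir_name(step: int, step_prefix: Optional[str] = None,
                  step_format_fixed_length: Optional[int] = None) -> str:
  """Hypothetical helper mirroring the documented step-directory naming."""
  # Zero-pad to the fixed length if one is given; otherwise use the bare integer.
  if step_format_fixed_length is not None:
    step_str = f'{step:0{step_format_fixed_length}d}'
  else:
    step_str = str(step)
  # Join the prefix and the step with an underscore, as documented.
  return f'{step_prefix}_{step_str}' if step_prefix else step_str


print(step_dir_name(42))             # prints 42
print(step_dir_name(42, 'ckpt', 6))  # prints ckpt_000042
```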

step_name_format:

A NameFormat used to build or find step directories under the root directory. If provided, step_prefix and step_format_fixed_length are ignored.

create:

If True, creates the top-level directory if it does not already exist.

cleanup_tmp_directories:

If True, cleans up any existing temporary directories on CheckpointManager creation.

save_on_steps:

Optional set of steps at which checkpoints should be saved. Useful for saving checkpoints at a fixed set of steps that are not multiples of save_interval_steps.
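The scheduling rule formed by save_interval_steps and save_on_steps can be sketched as a plain predicate. is_scheduled_step is a hypothetical helper, not part of Orbax; the real CheckpointManager.should_save may also consider other factors, such as preemption (see reached_preemption) and which steps have already been saved.

```python
from typing import Optional, Set


def is_scheduled_step(step: int, save_interval_steps: int = 1,
                      save_on_steps: Optional[Set[int]] = None) -> bool:
  """Hypothetical helper: True if `step` falls on the save schedule."""
  # A step qualifies if it is a multiple of the interval...
  if step % save_interval_steps == 0:
    return True
  # ...or if it is explicitly listed in save_on_steps.
  return save_on_steps is not None and step in save_on_steps


print(is_scheduled_step(10, save_interval_steps=5))                   # prints True
print(is_scheduled_step(7, save_interval_steps=5))                    # prints False
print(is_scheduled_step(7, save_interval_steps=5, save_on_steps={7})) # prints True
```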

single_host_load_and_broadcast:

If True, calling all_steps(read=True) reads the list of steps on only a single host and then broadcasts it to the other hosts. Otherwise, I/O is performed on every host. This can help reduce QPS to the filesystem when there are a large number of hosts.

todelete_subdir:

If set, checkpoints to be deleted are only renamed into a subdirectory with the provided name. Otherwise (the default), they are deleted directly from the file system. Useful if checkpoint deletion is time-consuming. Ignored if the file system is Google Cloud Storage (that is, the directory is prefixed with gs://).

enable_background_delete:

If True, old checkpoints are deleted in a background thread; otherwise, they are deleted at the end of each save. When enabled, call CheckpointManager.close() or use the manager as a context manager so that all old steps are deleted before the program exits.

read_only:

If True, saving and deleting checkpoints are skipped. Restoring checkpoints works as usual.

enable_async_checkpointing:

If True, enables async checkpointing.

async_options:

Used to configure properties of async behavior. See AsyncOptions.

__eq__(other)

Return self==value.

__hash__ = None

__init__(save_interval_steps=1, max_to_keep=None, keep_time_interval=None, keep_period=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=None, single_host_load_and_broadcast=False, todelete_subdir=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=<factory>)

AsyncOptions

class orbax.checkpoint.checkpoint_manager.AsyncOptions(timeout_secs=300, barrier_sync_fn=None)

Options used to configure async behavior.

See AsyncCheckpointer for details.

__eq__(other)

Return self==value.

__hash__ = None

__init__(timeout_secs=300, barrier_sync_fn=None)

MultiprocessingOptions

class orbax.checkpoint.checkpoint_manager.MultiprocessingOptions(primary_host=0, active_processes=None)

Options used to configure multiprocessing behavior.

primary_host:

The host ID of the primary host. Defaults to 0. If set to None, all hosts are considered primary, which is useful when all hosts work only with local storage.

active_processes:

A set of process indices (corresponding to multihost.process_index()) over which CheckpointManager is expected to be called. This makes it possible to have a CheckpointManager instance that runs over a subset of processes, rather than over all processes as it is normally expected to do. If specified, primary_host must belong to active_processes.

__eq__(other)

Return self==value.

__hash__ = None

__init__(primary_host=0, active_processes=None)

Utility functions

orbax.checkpoint.checkpoint_manager.is_async_checkpointer(checkpointer)