CheckpointManager
A class providing functionality for managing a series of checkpoints.
CheckpointManager
- class orbax.checkpoint.CheckpointManager(directory, checkpointers=None, options=None, metadata=None, item_names=None, item_handlers=None, logger=None, handler_registry=None)
A generic, synchronous AbstractCheckpointManager implementation.
- __init__(directory, checkpointers=None, options=None, metadata=None, item_names=None, item_handlers=None, logger=None, handler_registry=None)
CheckpointManager constructor.
IMPORTANT: CheckpointManager has been refactored to provide a new API. Please ensure you have migrated all existing use cases to the newer style by August 1st, 2024. Please see https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html for technical details.
The CheckpointManager is ultimately backed by a single Checkpointer, to which saving and restoring are delegated. Apart from step management options, metrics-related logic, and other conveniences, saving and restoring with CheckpointManager is quite similar to using Checkpointer(CompositeCheckpointHandler).
Example:
    from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions, args
    from orbax.checkpoint.args import StandardRestore, StandardSave

    # train_state, json_states and abstract_train_state are assumed to exist.

    # Multiple items.
    with CheckpointManager(
        'path/to/dir/',
        # Global metadata.
        metadata={'version': 1.1, 'lang': 'en'},
    ) as mngr:
        mngr.save(
            0,
            args=args.Composite(
                train_state=args.StandardSave(train_state),
                json_states=args.JsonSave(json_states),
            ),
            # Metadata varying by step.
            custom_metadata={'learning_rate': 0.001},
        )
        restored = mngr.restore(0)
        print(restored.train_state)
        print(restored.json_states)
        restored = mngr.restore(
            0,
            args=args.Composite(
                train_state=args.StandardRestore(abstract_train_state),
            ),
        )
        print(restored.train_state)
        print(restored.json_states)  # Error, not restored
        global_metadata = mngr.metadata()
        step_metadata = mngr.metadata(0)  # custom_metadata in here

    # Single, unnamed (default) item.
    with CheckpointManager(
        'path/to/dir/',
        options=CheckpointManagerOptions(max_to_keep=5, ...),
    ) as mngr:
        mngr.save(0, args=StandardSave(train_state))
        train_state = mngr.restore(0)
        train_state = mngr.restore(0, args=StandardRestore(abstract_train_state))
IMPORTANT: Don't forget to pass the keyword argument args=… to save and restore! Otherwise you will get the legacy API. This will only be necessary until the legacy API is removed.
IMPORTANT: The CheckpointManager is designed to be used as a context manager. Create it in a with statement, as in the examples above, for automatic cleanup. If you can't use a context manager, always call close() to release resources properly. Otherwise, background operations such as deleting old checkpoints might not finish before your program exits.
- CheckpointManager:
  - is NOT thread-safe.
  - IS multi-process-safe.
  - is NOT multi-job-safe.
This means that CheckpointManager is intended to be created and called across all processes within a single job, where each process is single-threaded, but is not safe to use when multiple jobs each have CheckpointManager instances pointing to the same root directory. Concretely, this means that if you have a trainer job and one or more evaluator jobs, the CheckpointManager should be created and called across all processes in the trainer, but a CheckpointManager cannot be used in the evaluators. Instead, utilities used during evaluation can be found in checkpoint_utils (see https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_utils.html).
- Parameters:
directory (Union[str, PathLike]) – the top-level directory in which to save all files.
checkpointers (Union[AbstractCheckpointer, Mapping[str, AbstractCheckpointer], None]) – deprecated, do not use; use handler_registry instead.
options (Optional[CheckpointManagerOptions]) – CheckpointManagerOptions specifying additional arguments. If None, the default values of CheckpointManagerOptions are used.
metadata (Optional[Mapping[str, Any]]) – High-level metadata that does not depend on the step number. If directory is write-enabled, the given metadata is saved only once: a new CheckpointManager instance on the same directory does not overwrite the existing metadata and ignores the newly given metadata. If directory is read-only, the given metadata is not saved; such an instance uses the existing metadata if present, and otherwise uses the given metadata.
item_names (Optional[Sequence[str]]) – deprecated, do not use; use handler_registry instead.
item_handlers (Union[CheckpointHandler, Mapping[str, CheckpointHandler], None]) – deprecated, do not use; use handler_registry instead.
logger (Optional[AbstractLogger]) – a logger for checkpointing events.
handler_registry (Optional[CheckpointHandlerRegistry]) – a registry of handlers to use for checkpointing. This option is mutually exclusive with checkpointers, item_handlers, and item_names. See CheckpointHandlerRegistry for more details.
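The write-once rule for the metadata parameter can be captured as a small pure function. This is a hypothetical sketch of the documented behavior only, not the Orbax implementation; resolve_root_metadata and its arguments are illustrative names.

```python
from typing import Any, Mapping, Optional, Tuple


def resolve_root_metadata(
    existing: Optional[Mapping[str, Any]],
    given: Optional[Mapping[str, Any]],
    read_only: bool,
) -> Tuple[Optional[Mapping[str, Any]], bool]:
    """Hypothetical model: returns (metadata_in_effect, should_write)."""
    if existing is not None:
        # Metadata already on disk wins; the newly given metadata is ignored.
        return existing, False
    if read_only:
        # Nothing on disk and the directory is read-only: use the given
        # metadata in memory, but never write it.
        return given, False
    # Writable directory with no saved metadata: save the given metadata once.
    return given, given is not None
```

Under this model, a second manager created on the same writable directory with different metadata keeps the first manager's metadata and writes nothing.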
- property directory: Path
Returns the top-level directory containing checkpoints for all items.
- Return type:
Path
- all_steps(read=False)
Returns all steps tracked by the manager.
- Parameters:
read (bool) – If True, forces a read directly from the storage location. Otherwise, a cached result can be returned.
- Return type:
Sequence[int]
- Returns:
A sequence of steps (integers).
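The interplay between the cached result, the forced read (read=True), and the cache reset performed by reload() below can be modeled with a small class. StepIndex is a hypothetical stand-in, not an Orbax API, and it assumes the default layout in which each step directory is named by its bare integer step (no step_prefix).

```python
import os
from typing import List, Optional


class StepIndex:
    """Hypothetical model of all_steps(read=...) caching; not the Orbax code."""

    def __init__(self, directory: str) -> None:
        self._directory = directory
        self._cached: Optional[List[int]] = None

    def _scan(self) -> List[int]:
        # Treat every subdirectory whose name is a non-negative integer as a step.
        return sorted(
            int(name)
            for name in os.listdir(self._directory)
            if name.isdigit()
            and os.path.isdir(os.path.join(self._directory, name))
        )

    def all_steps(self, read: bool = False) -> List[int]:
        # read=True bypasses the cache and goes to storage.
        if read or self._cached is None:
            self._cached = self._scan()
        return list(self._cached)

    def reload(self) -> None:
        # Drop the cache so the next call re-reads the directory.
        self._cached = None
```

Steps written by another process after the first call stay invisible until all_steps(read=True) or reload() is used.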
- latest_step()
Returns the latest step saved.
Returns None if no steps have been saved.
- Return type:
Optional[int]
- Returns:
A step (int), or None if no steps are present.
- best_step()
Returns the best step saved, as defined by options.best_fn.
Returns None if no steps have been saved.
- Return type:
Optional[int]
- Returns:
A step (int), or None if no steps are present.
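A minimal sketch of the contracts of latest_step() and best_step(), written against plain Python data rather than a real checkpoint directory. The free functions latest_step and best_step here are hypothetical; they assume metrics are kept in a dict keyed by step and that best_fn maps a metrics PyTree to a scalar, as described for CheckpointManagerOptions.best_fn.

```python
from typing import Any, Callable, Mapping, Optional, Sequence


def latest_step(steps: Sequence[int]) -> Optional[int]:
    # None when nothing has been saved yet, mirroring the documented contract.
    return max(steps) if steps else None


def best_step(
    metrics_by_step: Mapping[int, Any],
    best_fn: Callable[[Any], float],
    best_mode: str = "max",
) -> Optional[int]:
    """Hypothetical model of best_step(): scores each step with best_fn."""
    if best_mode not in ("max", "min"):
        raise ValueError(f"best_mode must be 'max' or 'min', got {best_mode!r}")
    if not metrics_by_step:
        return None
    pick = max if best_mode == "max" else min
    # Compare steps by the scalar score that best_fn assigns to their metrics.
    return pick(metrics_by_step, key=lambda s: best_fn(metrics_by_step[s]))
```

For a loss metric, best_mode="min" selects the step with the lowest loss; for an accuracy metric, the default "max" is the natural choice.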
- reload()
Reloads internal properties.
Resets internal cache of checkpoint steps, in case the directory managed by this object has been updated externally.
- reached_preemption(step)
Returns True if a preemption sync point has been reached.
- Return type:
bool
- should_save(step)
Returns True if a checkpoint should be saved for the current step.
This depends on the previous step and the SaveDecisionPolicy.
- Parameters:
step (int) – the step to check.
- Return type:
bool
- Returns:
True if the checkpoint should be saved.
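A simplified, hypothetical model of the decision should_save() makes using the legacy CheckpointManagerOptions fields save_interval_steps, save_on_steps, and should_save_fn. It ignores SaveDecisionPolicy, which overrides all of these when set, and it assumes (as an approximation of "depends on the previous step") that a step at or before the latest saved step is never saved again.

```python
from typing import Callable, Collection, Optional


def should_save(
    step: int,
    latest_step: Optional[int],
    save_interval_steps: int = 1,
    save_on_steps: Collection[int] = (),
    should_save_fn: Optional[Callable[[int, Optional[int]], bool]] = None,
) -> bool:
    """Hypothetical, simplified model of the save decision (not Orbax code)."""
    # Assumption: a step at or before the latest saved step is never re-saved.
    if latest_step is not None and step <= latest_step:
        return False
    # A custom predicate, when given, replaces the interval and step-set options.
    if should_save_fn is not None:
        return should_save_fn(step, latest_step)
    return step % save_interval_steps == 0 or step in save_on_steps
```

With save_interval_steps=10 and save_on_steps={7}, steps 0, 7, 10, 20, … are saved, while step 5 is skipped.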
- delete(step)
See superclass documentation.
Deletion can run asynchronously if CheckpointManagerOptions.enable_background_delete is set to True.
- Parameters:
step (int) – The step to delete.
- Raises:
FileNotFoundError – If the step does not exist.
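The benefit of enable_background_delete, and the reason close() matters, can be seen in a small threading sketch. BackgroundDeleter is hypothetical and only illustrates the pattern: deletions run on worker threads, and close() (called automatically by the with statement) joins them so none are lost at exit.

```python
import os
import shutil
import threading
from typing import List


class BackgroundDeleter:
    """Hypothetical sketch of background deletion; not Orbax's implementation."""

    def __init__(self) -> None:
        self._threads: List[threading.Thread] = []

    def delete(self, path: str) -> None:
        if not os.path.exists(path):
            # Mirrors the documented FileNotFoundError for a missing step.
            raise FileNotFoundError(path)
        # daemon=True: the thread will not keep the process alive, so an
        # unjoined deletion can be cut short when the interpreter exits.
        t = threading.Thread(target=shutil.rmtree, args=(path,), daemon=True)
        t.start()
        self._threads.append(t)

    def close(self) -> None:
        # Wait for every pending deletion before returning.
        for t in self._threads:
            t.join()
        self._threads.clear()

    def __enter__(self) -> "BackgroundDeleter":
        return self

    def __exit__(self, *exc) -> None:
        self.close()
```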
- save(step, items=None, save_kwargs=None, metrics=None, force=False, args=None, custom_metadata=None)
Saves the provided items.
This method should be called by all hosts; process synchronization and actions that need to be performed on only one host are managed internally.
NOTE: The items and save_kwargs arguments are deprecated, use args instead. Make sure to configure CheckpointManager with item_names.
args should be a subclass of orbax.checkpoint.args.CheckpointArgs; its specific type indicates which logic is used to save the object. For a typical PyTree of arrays, use StandardSave/StandardRestore.
When constructing the CheckpointManager, if no item_names were provided, it is assumed that we are managing a single object. If item_names were provided, it is assumed that we are managing multiple objects, and args must be orbax.checkpoint.args.CompositeArgs. See below for details.
Example:
    import orbax.checkpoint as ocp

    # Single item
    mngr = ocp.CheckpointManager(directory)
    mngr.save(step, args=ocp.args.StandardSave(my_train_state))

    # Multiple items
    mngr = ocp.CheckpointManager(directory, item_names=('state', 'meta'))
    mngr.save(step, args=ocp.args.Composite(
        state=ocp.args.StandardSave(my_train_state),
        meta=ocp.args.JsonSave(my_metadata),
    ))
- Parameters:
step (int) – the current step.
items (Union[Any, Mapping[str, Any], None]) – deprecated; a savable object, or a dictionary of object name to savable object.
save_kwargs (Union[Mapping[str, Any], Mapping[str, Mapping[str, Any]], None]) – deprecated; save kwargs for a single Checkpointer, or a dictionary of object name to kwargs needed by the Checkpointer implementation to save the object.
metrics (Optional[Any]) – a dictionary of metric name (string) to numeric value to be tracked along with this checkpoint. Required if options.best_fn is set. Allows users to specify a metric value to determine which checkpoints are best and should be kept (in conjunction with options.max_to_keep).
force (Optional[bool]) – if True, this method will attempt to save a checkpoint regardless of the result of AbstractCheckpointManager.should_save(step). By default, save only writes a checkpoint to disk when the options permit, e.g. when step is a multiple of options.save_interval_steps or is listed in options.save_on_steps. Setting force=True will not overwrite existing checkpoints.
args (Optional[CheckpointArgs]) – CheckpointArgs used to save checkpointable objects with the appropriate logic.
custom_metadata (dict[str, Any] | None) – a dictionary of custom metadata to be written to the checkpoint directory via StepMetadata.
- Return type:
bool
- Returns:
bool indicating whether a save operation was performed.
- Raises:
ValueError – if track_best was indicated but metrics is not provided.
ValueError – if directory creation fails.
ValueError – if an item is provided for which no Checkpointer is found.
ValueError – if the checkpoint already exists.
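Metrics passed to save drive metric-based retention together with options.best_fn, options.best_mode, options.max_to_keep, and options.keep_checkpoints_without_metrics. The following is a hypothetical model of that selection, not Orbax's algorithm; steps_to_delete is an illustrative name, and max_to_keep itself is marked deprecated in favor of preservation_policy.

```python
from typing import Any, Callable, Dict, List, Optional


def steps_to_delete(
    metrics_by_step: Dict[int, Optional[Any]],
    best_fn: Callable[[Any], float],
    max_to_keep: int,
    best_mode: str = "max",
    keep_checkpoints_without_metrics: bool = True,
) -> List[int]:
    """Hypothetical model of metric-based retention; not Orbax's algorithm."""
    with_metrics = [s for s, m in metrics_by_step.items() if m is not None]
    without_metrics = [s for s, m in metrics_by_step.items() if m is None]
    # Rank scored steps from best to worst according to best_mode.
    ranked = sorted(
        with_metrics,
        key=lambda s: best_fn(metrics_by_step[s]),
        reverse=(best_mode == "max"),
    )
    doomed = ranked[max_to_keep:]
    if not keep_checkpoints_without_metrics:
        # Simplification: unscored steps are always eligible for cleanup.
        doomed += without_metrics
    return sorted(doomed)
```

With accuracies {1: 0.7, 2: 0.9, 3: 0.8} and max_to_keep=2, only step 1 is deleted; a step saved without metrics survives unless keep_checkpoints_without_metrics is False.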
- restore(step, items=None, restore_kwargs=None, directory=None, args=None)
Restores from the given step and provided items.
This method should be called by all hosts; process synchronization and actions that need to be performed on only one host are managed internally.
NOTE: The items and restore_kwargs arguments are deprecated, use args instead. Make sure to configure CheckpointManager with item_names. See save docstring for additional details.
Example:
    import orbax.checkpoint as ocp

    # Single item
    mngr = ocp.CheckpointManager(directory)
    mngr.restore(step, args=ocp.args.StandardRestore(abstract_train_state))

    # Multiple items
    mngr = ocp.CheckpointManager(directory, item_names=('state', 'meta'))
    mngr.restore(step, args=ocp.args.Composite(
        state=ocp.args.StandardRestore(abstract_train_state),
        meta=ocp.args.JsonRestore(),
    ))
    # If it is acceptable to restore without providing additional arguments,
    # and if a save has already been performed, it is ok to do the following:
    mngr.restore(step, args=ocp.args.Composite(state=None, meta=None))
    # If a save has not already been performed, there is no way for Orbax to
    # know how to restore the objects. If a save has already been performed,
    # it remembers the logic used to save the objects.
- Parameters:
step (Optional[int]) – the step to restore.
items (Union[Any, Mapping[str, Any], None]) – deprecated; a restorable object, or a dictionary of object name to restorable object.
restore_kwargs (Union[Mapping[str, Any], Mapping[str, Mapping[str, Any]], None]) – deprecated; restore kwargs for a single Checkpointer, or a dictionary of object name to kwargs needed by the Checkpointer implementation to restore the object.
directory (Union[str, PathLike, None]) – if provided, uses the given directory rather than the directory property of this class. Can be used to restore checkpoints from an independent location.
args (Optional[CheckpointArgs]) – CheckpointArgs used to restore checkpointable objects with the appropriate logic.
- Return type:
Union[Any, Mapping[str, Any]]
- Returns:
If managing a single item, returns a single checkpointable object. If managing multiple items, returns ocp.args.Composite, where the keys are item names and the values are checkpointable objects.
- Raises:
FileNotFoundError – If no steps are found in the directory.
- item_metadata(step)
Retrieves metadata for all known items.
Important note: This method will soon be deprecated in favor of metadata(step).item_metadata. Please use that instead.
Note that metadata will only be returned for items that can actually be interpreted. If an item is present in the checkpoint but not registered (using a prior save or restore, or with handler_registry at init), the item will not be returned.
- Parameters:
step (int) – The step to retrieve metadata for.
- Return type:
Union[Any, CompositeArgs, Composite]
- Returns:
Either metadata for the item itself, if in default-item mode, or a Composite of metadata for each item.
- metrics(step)
Returns metrics for step, if present.
- Return type:
Optional[Any]
- metadata(step: None = None) -> RootMetadata
- metadata(step: int) -> StepMetadata
See superclass documentation. Called without a step, returns the root (global) metadata; called with a step, returns the StepMetadata for that step.
- Return type:
RootMetadata | StepMetadata
- wait_until_finished()
Blocks until any incomplete save operations are completed.
Note that this method will typically be a no-op if all checkpointers are synchronous, since old checkpoints are already cleaned up immediately after completing save, and there is no background thread to wait for.
If some checkpointers are of type AsyncCheckpointer, however, this method will wait until each of these checkpointers is finished.
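The contract of wait_until_finished(), and of is_saving_in_progress() documented below, can be illustrated with concurrent.futures. AsyncSaveTracker is a hypothetical class that models async saves as futures; it is not how AsyncCheckpointer is implemented.

```python
import concurrent.futures
from typing import Callable, List


class AsyncSaveTracker:
    """Hypothetical model of wait_until_finished / is_saving_in_progress."""

    def __init__(self) -> None:
        self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self._pending: List[concurrent.futures.Future] = []

    def save_async(self, write_fn: Callable[[], object]) -> None:
        # The caller returns immediately; the write runs on a worker thread.
        self._pending.append(self._pool.submit(write_fn))

    def is_saving_in_progress(self) -> bool:
        return any(not f.done() for f in self._pending)

    def wait_until_finished(self) -> None:
        # Block on every outstanding save; result() re-raises worker errors.
        for f in self._pending:
            f.result()
        self._pending.clear()

    def close(self) -> None:
        self.wait_until_finished()
        self._pool.shutdown()
```

If every save were synchronous, the pending list would always be empty by the time save returns, which is why the method is typically a no-op in that case.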
- classmethod __subclasshook__(other)
Abstract classes can override this to customize issubclass().
This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
- is_saving_in_progress()
Returns whether a checkpoint save is in progress.
- Return type:
bool
CheckpointManagerOptions
- class orbax.checkpoint.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=None, single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=<factory>, should_save_fn=None, file_options=<factory>, save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False, lightweight_initialize=False)
Optional arguments for CheckpointManager.
- save_interval_steps:
The interval at which checkpoints should be saved. Ensures checkpoints will only be saved every n steps. Defaults to 1.
- max_to_keep:
deprecated, do not use. use preservation_policy instead.
- keep_time_interval:
deprecated, do not use. use preservation_policy instead.
- keep_period:
deprecated, do not use. use preservation_policy instead.
- should_keep_fn:
deprecated, do not use. use preservation_policy instead.
- best_fn:
If set, maintains checkpoints based on the quality of given metrics rather than recency. The function should accept a PyTree of metrics, and return a scalar value that can be used to determine the quality score of the checkpoint. If max_to_keep is also set, then the retained checkpoints will be kept based on their quality, as measured by this function.
- best_mode:
One of ['max', 'min']. With 'max', higher best_fn values indicate better checkpoints; with 'min', lower values do.
- keep_checkpoints_without_metrics:
If False, checkpoints without metrics present are eligible for cleanup. Otherwise, they will never be deleted.
- step_prefix:
If provided, step directories will take the form f'{step_prefix}_<step>'. Otherwise, they will simply be an integer <step>.
- step_format_fixed_length:
If set, formats step with n digits (leading zeros). This makes sorting steps easier. Otherwise, step has no leading zeros.
- step_name_format:
NameFormat used to build or find step directories under the root directory. If provided, step_prefix and step_format_fixed_length are ignored.
- create:
If True, creates the top-level directory if it does not already exist.
- cleanup_tmp_directories:
If True, cleans up any existing temporary directories on CheckpointManager creation.
- save_on_steps:
Optional set of steps at which checkpoints should be saved. Useful to save checkpoints on a fixed set of steps that are not multiples of save_interval_steps.
- single_host_load_and_broadcast:
If True, calling all_steps(read=True) will load on only a single host, and will then be broadcast to other hosts. Otherwise, I/O will be performed on every host. This can be helpful to reduce QPS to the filesystem if there are a large number of hosts.
- todelete_subdir:
If set, checkpoints to be deleted are only renamed into a subdirectory with the provided name instead of being deleted from the file system. Useful if checkpoint deletion is time-consuming. By default, checkpoint assets are deleted. Ignored if the file system is Google Cloud Storage (directory is prefixed with gs://).
- todelete_full_path:
Specifies a path relative to the bucket root for "soft-deleting" checkpoints on Google Cloud Storage (GCS). Instead of being permanently removed, checkpoints are moved to this location within the same bucket. For instance, if a checkpoint is in gs://my-bucket/experiments/run1/, providing the value trash/ will move a deleted step to gs://my-bucket/trash/<step_id>. Useful when direct deletion is time-consuming, since all deleted items are gathered in a central path for later cleanup.
- enable_background_delete:
If True, old checkpoints are deleted in a background thread; otherwise, they are deleted at the end of each save. When enabled, make sure to call CheckpointManager.close() or use the manager as a context manager so that all old steps are deleted before exit.
- read_only:
If True, checkpoint saves and deletions are skipped. Restoring works as usual.
- enable_async_checkpointing:
If True, enables async checkpointing.
- async_options:
Used to configure properties of async behavior (see AsyncOptions).
- multiprocessing_options:
A MultiprocessingOptions instance used to configure multiprocessing behavior.
- should_save_fn:
Predicate callable that checks whether the given step can be saved. It accepts the step number and an optional latest step number, and returns a bool. If present, the save_interval_steps and save_on_steps options are ignored.
- file_options:
Options to configure checkpoint directories and files. Defaults to FileOptions().
- save_root_metadata:
If True, saves root-level metadata about checkpoints. This metadata is not step-specific and is written only once.
- temporary_path_class:
Optional. The concrete atomicity_types.TemporaryPath class to be used by the underlying Checkpointer.
- save_decision_policy:
An object used to determine when a checkpoint should be saved. If provided, it overrides all other options dealing with this subject, including save_interval_steps, save_on_steps, and should_save_fn, and is the sole means of determining when a checkpoint should be saved. If not provided, these other options are used instead. Prefer this option over the others.
- preservation_policy:
An object used to determine which checkpoints to preserve. If provided, it overrides all other options dealing with this subject, including max_to_keep, keep_time_interval, keep_period, should_keep_fn, and best_fn, and is the sole means of determining which checkpoints to preserve. If not provided, these other options are used instead. Prefer this option over the others.
- prevent_write_metrics:
False by default. If True, metrics will not be written.
- enable_should_save_is_saving_in_progress_check:
True by default. If False, should_save_fn will not check is_saving_in_progress and will assume that no save is in progress. This only affects users of ContinuousCheckpointingPolicy; otherwise the value is ignored. This is an interim workaround for b/428061876. Do not use without explicit approval.
- enable_per_process_directory_creation:
Signifies whether directories are supposed to be created per process. This is used to support async directory creation. If True, multiprocessing_options.primary_host must be None.
- lightweight_initialize:
If True, checkpoint step metadata is not read while loading checkpoint info during CheckpointManager initialization. This is useful to improve initialization performance when there are O(1k) or more existing checkpoint steps and checkpoint info properties like time and metrics are not needed.
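The step_prefix and step_format_fixed_length options above determine step directory names. The sketch below is a hypothetical model of the documented naming scheme and shows why fixed-length formatting makes lexicographic sorting match numeric order; step_dir_name and parse_step are illustrative helpers, and step_name_format (which overrides both options) is not modeled.

```python
import re
from typing import Optional


def step_dir_name(
    step: int,
    step_prefix: Optional[str] = None,
    step_format_fixed_length: Optional[int] = None,
) -> str:
    """Hypothetical model of the default step-directory naming."""
    if step_format_fixed_length:
        # Zero-pad to n digits, e.g. 42 with n=8 becomes "00000042".
        digits = f"{step:0{step_format_fixed_length}d}"
    else:
        digits = str(step)
    return f"{step_prefix}_{digits}" if step_prefix else digits


def parse_step(name: str, step_prefix: Optional[str] = None) -> Optional[int]:
    """Inverse of step_dir_name; returns None for non-step names."""
    if step_prefix:
        pattern = re.escape(step_prefix) + r"_(\d+)"
    else:
        pattern = r"(\d+)"
    match = re.fullmatch(pattern, name)
    return int(match.group(1)) if match else None
```

Unpadded names sort as "10", "100", "9", while names padded to three digits sort as "009", "010", "100", matching numeric order.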
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(save_interval_steps=1, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=None, single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=<factory>, should_save_fn=None, file_options=<factory>, save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False, lightweight_initialize=False)
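The todelete_full_path option resolves the soft-delete destination relative to the bucket root rather than the checkpoint directory. Below is a hypothetical helper that reproduces the documented example; soft_delete_destination is an illustrative name, and rejecting non-gs:// paths is a choice of this sketch, not documented behavior.

```python
from typing import Union


def soft_delete_destination(
    checkpoint_root: str, todelete_full_path: str, step_id: Union[int, str]
) -> str:
    """Hypothetical model of the documented GCS soft-delete path mapping."""
    prefix = "gs://"
    if not checkpoint_root.startswith(prefix):
        # This sketch only models the gs:// case described for the option.
        raise ValueError("todelete_full_path applies to gs:// directories")
    bucket = checkpoint_root[len(prefix):].split("/", 1)[0]
    # The destination is relative to the bucket root, not to checkpoint_root.
    return f"{prefix}{bucket}/{todelete_full_path.strip('/')}/{step_id}"
```

For the documented example, soft_delete_destination("gs://my-bucket/experiments/run1/", "trash/", 100) gives gs://my-bucket/trash/100.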