Checkpointing Utilities#

High-level checkpoint utils provided for user convenience.

orbax.checkpoint.checkpoint_utils.wait_for_new_checkpoint(checkpoint_dir, *, until_step=None, seconds_to_sleep=1, timeout=None, timeout_fn=None, step_prefix=None, step_format_fixed_length=None, step_name_format=None)[source]#

Waits until a new checkpoint file is found.

Automatically locks any checkpoint that is returned, and unlocks it when control returns to this function.

Parameters:
  • checkpoint_dir (Path) – The directory in which checkpoints are saved.

  • until_step (Optional[int]) – If specified, waits until a step greater than or equal to until_step has been found. If set to None (default), returns the first step found.

  • seconds_to_sleep (int) – The number of seconds to sleep before checking for a new checkpoint.

  • timeout (Optional[int]) – The maximum number of seconds to wait. If left as None, then the process will wait indefinitely.

  • timeout_fn (Optional[Callable[[], bool]]) – Optional function to call after a timeout. If the function returns True, then it means that no new checkpoints will be generated and the iterator will exit. The function is called with no arguments.

  • step_prefix (Optional[str]) – A prefix applied to step numbers (e.g. <prefix>_42).

  • step_format_fixed_length (Optional[int]) – Expects to find checkpoint step directories with exactly this number of digits (leading zeros if necessary).

  • step_name_format (Optional[NameFormat]) – NameFormat used to locate steps under the given root directory. If provided, step_prefix and step_format_fixed_length are ignored.

Yields:

a new checkpoint step, or -1 if the timeout was reached.
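As a minimal sketch, the -1 timeout sentinel can be translated into an optional step before use. The commented call assumes orbax-checkpoint is installed and that the function is used as a context manager (it yields the step and holds the lock for the duration of the block); `ckpt_dir`, `last_step`, and `restore` are hypothetical names.

```python
# Minimal sketch. `step_or_none` is a small hypothetical helper; the
# commented call below assumes orbax-checkpoint is installed.

def step_or_none(step):
    """Translate wait_for_new_checkpoint's -1 timeout sentinel into None."""
    return None if step == -1 else step

# from orbax.checkpoint import checkpoint_utils
#
# with checkpoint_utils.wait_for_new_checkpoint(
#         ckpt_dir, until_step=last_step + 1, timeout=600) as step:
#     if step_or_none(step) is not None:
#         restore(step)  # the step stays locked while this block runs
```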

orbax.checkpoint.checkpoint_utils.unlock_existing_checkpoints(checkpoint_dir, step_prefix=None, step_format_fixed_length=None, step_name_format=None)[source]#

Removes the LOCKED file for all existing steps, if present.

Parameters:
  • checkpoint_dir (Path) – The directory containing StepDirs.

  • step_prefix (Optional[str]) – A prefix applied to step numbers (e.g. <prefix>_42).

  • step_format_fixed_length (Optional[int]) – Expects to find checkpoint step directories with exactly this number of digits (leading zeros if necessary).

  • step_name_format (Optional[NameFormat]) – NameFormat used to locate steps under the given root directory. If provided, step_prefix and step_format_fixed_length are ignored.

orbax.checkpoint.checkpoint_utils.checkpoints_iterator(checkpoint_dir, *, min_interval_secs=0, seconds_to_sleep=1, timeout=None, timeout_fn=None, step_prefix=None, step_format_fixed_length=None, step_name_format=None)[source]#

Continuously yield new checkpoint files as they appear.

Based on the equivalent TF method.

The iterator only checks for new checkpoints when control has returned to it. This means it can miss checkpoints if your code takes longer to run between iterations than min_interval_secs or the interval at which new checkpoints are written.

Warning: If a CheckpointManager running in a separate training process is cleaning up old checkpoints (via the max_to_keep argument), steps returned by this function may become invalid after being cleaned up by that process. In this case, increase max_to_keep (a value of 5 is suggested).

The timeout argument is the maximum number of seconds to block waiting for a new checkpoint. It is used in combination with the timeout_fn as follows:

  • If the timeout expires and no timeout_fn was specified, the iterator stops yielding.

  • If a timeout_fn was specified, that function is called and if it returns a true boolean value the iterator stops yielding.

  • If the function returns a false boolean value then the iterator resumes the wait for new checkpoints. At this point the timeout logic applies again.

This behavior gives control to callers on what to do if checkpoints do not come fast enough or stop being generated. For example, if callers have a way to detect that the training has stopped and know that no new checkpoints will be generated, they can provide a timeout_fn that returns True when the training has stopped. If they know that the training is still going on they return False instead.
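A hedged sketch of such a timeout_fn: any zero-argument callable returning a bool works. Here `training_done` is a hypothetical flag that the training job would set when it exits.

```python
import threading

# `training_done` is a hypothetical signal set by the training job on exit;
# a threading.Event is one convenient way to share it with the evaluator.
training_done = threading.Event()

def training_finished():
    # True  -> no more checkpoints are coming, so the iterator stops.
    # False -> keep waiting; the timeout window restarts.
    return training_done.is_set()
```

A caller would then pass `timeout_fn=training_finished` to checkpoints_iterator.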

Parameters:
  • checkpoint_dir (Union[str, PathLike]) – The directory in which checkpoints are saved.

  • min_interval_secs (int) – The minimum number of seconds between yielding checkpoints.

  • seconds_to_sleep (int) – Seconds to sleep when no checkpoint is found. Note the difference from min_interval_secs, which puts a lower bound on when a new checkpoint will be looked for after one has been yielded; seconds_to_sleep instead specifies how long to sleep when no new checkpoint is found. Note also that the timeout is only checked between sleeps, so a seconds_to_sleep longer than the timeout would result in timing out after seconds_to_sleep seconds rather than timeout seconds.

  • timeout (Optional[int]) – The maximum number of seconds to wait between checkpoints. The function will time out if timeout seconds have passed since a new checkpoint step was found. If left as None, then the process will wait indefinitely.

  • timeout_fn (Optional[Callable[[], bool]]) – Optional function called after a timeout. If the function returns True, then it means that no new checkpoints will be generated and the iterator will exit. The function is called with no arguments.

  • step_prefix (Optional[str]) – A prefix applied to step numbers (e.g. <prefix>_42).

  • step_format_fixed_length (Optional[int]) – Expects to find checkpoint step directories with exactly this number of digits (leading zeros if necessary).

  • step_name_format (Optional[NameFormat]) – NameFormat used to locate steps under the given root directory. If provided, step_prefix and step_format_fixed_length are ignored.

Yields:

Integer step numbers of the latest checkpoints as they arrive.

Return type:

Iterator[int]
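The loop above can be sketched as a small driver. The step source is injected so the loop logic stands on its own; in real use one would pass `checkpoint_utils.checkpoints_iterator(checkpoint_dir, timeout=..., timeout_fn=...)`. `evaluate_fn` is a hypothetical per-step evaluation callback.

```python
# Hedged sketch of an evaluation driver around checkpoints_iterator. The
# iterator is injected, so this runs against any iterable of step numbers;
# `evaluate_fn` is a hypothetical callback applied to each new step.

def run_evaluation(step_iterator, evaluate_fn):
    """Evaluate every yielded checkpoint step; return the steps processed."""
    processed = []
    for step in step_iterator:
        evaluate_fn(step)
        processed.append(step)
    return processed
```

The iterator exhausts (and the driver returns) only when the timeout logic described above stops the underlying checkpoints_iterator.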

orbax.checkpoint.checkpoint_utils.construct_restore_args(target, sharding_tree=None, set_global_shape=True)[source]#

Creates restore_args given a target PyTree.

This method should be used in conjunction with a CheckpointManager or Checkpointer that wraps a PyTreeCheckpointHandler.

For example:

mngr = CheckpointManager(path, Checkpointer(PyTreeCheckpointHandler()))
restore_args = construct_restore_args(train_state, train_state_sharding)
restore_kwargs = {'restore_args': restore_args}
mngr.restore(..., restore_kwargs=restore_kwargs)

OR:

mngr = CheckpointManager(path, {
    'train_state': Checkpointer(PyTreeCheckpointHandler())
})
restore_args = construct_restore_args(train_state, train_state_sharding)
restore_kwargs = {'train_state': {'restore_args': restore_args} }
mngr.restore(..., restore_kwargs=restore_kwargs)

OR:

ckptr = Checkpointer(PyTreeCheckpointHandler())
restore_args = construct_restore_args(train_state, train_state_sharding)
ckptr.restore(..., restore_args=restore_args)

If a leaf in target is, for example, a np.ndarray, int, or string, the corresponding entry in sharding_tree (if provided) must still be present, but will be ignored.

Parameters:
  • target (Any) – The returned PyTree will match the structure of target. target may contain value_metadata.Metadata, real scalar or array values, or may contain jax.ShapeDtypeStruct.

  • sharding_tree (Optional[Any]) – A PyTree matching target which will be used to set the restoration sharding. If not provided, sharding will default to the shardings specified by target.

  • set_global_shape (bool) – If True, sets the global_shape field of ArrayRestoreArgs.

Return type:

Any

Returns:

A PyTree matching target of RestoreArgs (or ArrayRestoreArgs) objects.