Source code for orbax.checkpoint.checkpoint_utils

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""High-level checkpoint utils provided for user convenience."""

import contextlib
import time
from typing import Any, Callable, Iterator, Optional

from absl import logging
from etils import epath
import jax
from jax.experimental import layout
import numpy as np
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.arrays import sharding as arrays_sharding_lib
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.metadata import value as value_metadata
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import step as step_lib
from orbax.checkpoint._src.path.snapshot import snapshot as snapshot_lib
from orbax.checkpoint._src.serialization import type_handlers



PyTree = Any
STANDARD_ARRAY_TYPES = (int, float, np.ndarray, jax.Array)
_SNAPSHOTS = '_SNAPSHOTS'
if jax.__version_info__ >= (0, 6, 2):
  Format = layout.Format
else:
  Format = layout.Layout
PLACEHOLDER = type_handlers.PLACEHOLDER


def _init_step_name_format(
    step_name_format: Optional[step_lib.NameFormat[step_lib.Metadata]] = None,
    step_prefix: Optional[str] = None,
    step_format_fixed_length: Optional[int] = None,
):
  return step_name_format or step_lib.standard_name_format(
      step_prefix=step_prefix,
      step_format_fixed_length=step_format_fixed_length,
  )


def get_snapshot_dir_from_step_dir(
    step_dir: epath.Path, snapshot_dir: Optional[epath.Path] = None
) -> epath.Path:
  """Returns the snapshot directory from the step directory."""
  if snapshot_dir is None:
    snapshot_dir = step_dir.parent / _SNAPSHOTS
  new_path = snapshot_dir / step_dir.name
  return new_path


@contextlib.contextmanager
def _manage_snapshot_file_not_found(
    ignore_errors: bool, formatted_message: str
) -> Iterator[None]:
  """Context manager to optionally suppress FileNotFoundError."""
  try:
    yield
  except FileNotFoundError:
    if ignore_errors:
      logging.warning(formatted_message, exc_info=True)
    else:
      raise


def _snapshot_checkpoint(
    checkpoint_dir: epath.Path,
    step: int,
    step_name_format: step_lib.NameFormat[step_lib.Metadata],
    snapshot_dir: Optional[epath.Path] = None,
    *,
    set_immutable: bool | None = None,
    ignore_file_not_found_error: bool | None = None,
):
  """Uses `Snapshot` class to create a cheap "copy" of the checkpoint."""
  if multihost.process_index() != 0:
    return False

  logging.info('Snapshotting step: %d.', step)
  step_dir = step_name_format.find_step(checkpoint_dir, step).path
  if not step_dir.exists():
    raise ValueError(f'Step directory {step_dir} does not exist.')

  if snapshot_dir is None:
    snapshot_dir = checkpoint_dir / _SNAPSHOTS
  if not snapshot_dir.exists():
    try:
      snapshot_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
      raise ValueError(
          f'Failed to create snapshot directory {snapshot_dir}. Please'
          ' provide a snapshot directory instead.'
      ) from e

  snapshot_path = get_snapshot_dir_from_step_dir(step_dir, snapshot_dir)
  if epath.Path(snapshot_path).exists():
    return True
  snapshot_impl = snapshot_lib.create_instance(
      step_dir, snapshot_path, set_immutable=set_immutable
  )
  with _manage_snapshot_file_not_found(
      ignore_errors=ignore_file_not_found_error,
      formatted_message=(
          f'Ignoring error when snapshotting checkpoint for step: {step}'
      ),
  ):
    asyncio_utils.run_sync(snapshot_impl.create_snapshot())
  return True


def _release_snapshot(
    checkpoint_dir: epath.Path,
    step: int,
    step_name_format: step_lib.NameFormat[step_lib.Metadata],
    snapshot_dir: Optional[epath.Path] = None,
    *,
    ignore_file_not_found_error: bool | None = None,
):
  """Releases snapshot by deleting the snapshot of the checkpoint."""
  if multihost.process_index() == 0:
    logging.info('Releasing snapshot at step: %d.', step)
    if snapshot_dir is None:
      snapshot_dir = checkpoint_dir / _SNAPSHOTS
    snapshot_path = snapshot_dir / step_name_format.build_name(step)
    snapshot_impl = snapshot_lib.create_instance(checkpoint_dir, snapshot_path)
    with _manage_snapshot_file_not_found(
        ignore_errors=ignore_file_not_found_error,
        formatted_message=(
            f'Ignoring error when releasing snapshot for step: {step}'
        ),
    ):
      asyncio_utils.run_sync(snapshot_impl.release_snapshot())


def _reached_desired_step(step: int, until_step: Optional[int]) -> bool:
  if step is None:
    return False
  elif until_step is None:
    return True
  elif step >= until_step:
    return True
  return False


def _wait_for_new_checkpoint(
    checkpoint_dir: epath.Path,
    *,
    step_name_format: step_lib.NameFormat[step_lib.Metadata],
    until_step: Optional[int] = None,
    seconds_to_sleep: int = 1,
    timeout: Optional[int] = None,
    timeout_fn: Optional[Callable[[], bool]] = None,
    snapshot_dir: Optional[epath.Path] = None,
    set_immutable: bool | None = None,
    ignore_snapshot_errors: bool | None = None,
) -> int:
  """See documentation for wait_for_new_checkpoint."""
  start = time.time()
  stop_time = start + timeout if timeout is not None else None

  def _sleep_and_maybe_exit():
    if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
      if timeout_fn is None:
        return True
      elif timeout_fn():  # Only exit when timeout_fn indicates completion.
        return True
    logging.info('Sleeping for %d seconds.', seconds_to_sleep)
    time.sleep(seconds_to_sleep)
    return False

  log_str = f'Waiting for new checkpoint at {checkpoint_dir}. '
  if until_step is not None:
    log_str += f'Waiting until step {until_step} is reached. '
  if timeout is not None:
    log_str += f'Will time out after {timeout} seconds. '
  logging.info(log_str)

  result = -1
  if multihost.process_index() == 0:
    while True:
      if not checkpoint_dir.exists():
        if _sleep_and_maybe_exit():
          break
        continue  # continue waiting until directory creation or timeout.

      steps = utils.checkpoint_steps(checkpoint_dir)
      checkpoint_step = max(steps) if steps else None
      if _reached_desired_step(checkpoint_step, until_step):
        if not _snapshot_checkpoint(
            checkpoint_dir,
            checkpoint_step,
            step_name_format,
            snapshot_dir,
            set_immutable=set_immutable,
            ignore_file_not_found_error=ignore_snapshot_errors,
        ):
          continue
        result = checkpoint_step
        break
      elif _sleep_and_maybe_exit():
        break

  result = multihost.broadcast_one_to_all(np.int32(result)).item()
  wait_duration = time.time() - start
  jax.monitoring.record_event_duration_secs(
      '/jax/orbax/checkpoint_utils/wait_duration', wait_duration
  )

  if result == -1:
    logging.info('Timed out waiting for new checkpoint. Returning -1.')
  else:
    logging.info('Found new checkpoint step: %d.', result)
  return result


[docs] @contextlib.contextmanager def wait_for_new_checkpoint( checkpoint_dir: epath.Path, *, until_step: Optional[int] = None, seconds_to_sleep: int = 1, timeout: Optional[int] = None, timeout_fn: Optional[Callable[[], bool]] = None, step_prefix: Optional[str] = None, step_format_fixed_length: Optional[int] = None, step_name_format: Optional[step_lib.NameFormat[step_lib.Metadata]] = None, snapshot_dir: Optional[epath.Path] = None, set_immutable: bool | None = None, ignore_snapshot_errors: bool | None = True, ): """Waits until a new checkpoint file is found. Automatically snapshots any checkpoint that is returned, and releases the snapshot of the checkpoint when execution returns to this function. Args: checkpoint_dir: The directory in which checkpoints are saved. until_step: If specified, waits until a step greater than or equal to `until_step` has been found. If set to None (default), returns the first step found. seconds_to_sleep: The number of seconds to sleep for before looking for a new checkpoint. timeout: The maximum number of seconds to wait. If left as `None`, then the process will wait indefinitely. timeout_fn: Optional function to call after a timeout. If the function returns True, then it means that no new checkpoints will be generated and the iterator will exit. The function is called with no arguments. step_prefix: A prefix applied to step numbers (e.g. <prefix>_42). step_format_fixed_length: Expects to find checkpoint step directories with exactly this number of digits (leading zeros if necessary). step_name_format: Step NameFormat used to find step under given root directory. If provided, `step_prefix` and `step_format_fixed_length` are ignored. snapshot_dir: The directory in which snapshots are saved. If not provided, the snapshot directory is created as `checkpoint_dir / _SNAPSHOTS`. set_immutable: If True and applicable, sets the files in the snapshot to be immutable. ignore_snapshot_errors: If True, ignore errors when creating/releasing snapshots. Yields: a new checkpoint step, or -1 if the timeout was reached. """ step_name_format = _init_step_name_format( step_name_format=step_name_format, step_prefix=step_prefix, step_format_fixed_length=step_format_fixed_length, ) step = _wait_for_new_checkpoint( checkpoint_dir, step_name_format=step_name_format, until_step=until_step, seconds_to_sleep=seconds_to_sleep, timeout=timeout, timeout_fn=timeout_fn, snapshot_dir=snapshot_dir, set_immutable=set_immutable, ignore_snapshot_errors=ignore_snapshot_errors, ) try: yield step finally: # Release snapshot on the checkpoint step. if step != -1: logging.info( 'Releasing snapshot for step: %d after releasing control.', step ) _release_snapshot( checkpoint_dir, step, step_name_format, snapshot_dir, ignore_file_not_found_error=ignore_snapshot_errors, )
[docs] def checkpoints_iterator( checkpoint_dir: epath.PathLike, *, min_interval_secs: int = 0, seconds_to_sleep: int = 1, timeout: Optional[int] = None, timeout_fn: Optional[Callable[[], bool]] = None, step_prefix: Optional[str] = None, step_format_fixed_length: Optional[int] = None, step_name_format: Optional[step_lib.NameFormat[step_lib.Metadata]] = None, snapshot_dir: Optional[epath.Path] = None, set_immutable: bool | None = None, ignore_snapshot_errors: bool | None = True, ) -> Iterator[int]: """Continuously yield new checkpoint files as they appear. Based on the equivalent TF method. The iterator only checks for new checkpoints when control flow has been reverted to it. This means it can miss checkpoints if your code takes longer to run between iterations than `min_interval_secs` or the interval at which new checkpoints are written. Warning: If CheckpointManager is running in a different process for training and is cleaning up old checkpoints (via the `max_to_keep` argument), steps returned by this function may not be valid after being clean up by another process. In this case, `max_to_keep` should be increased (suggested value: 5) The `timeout` argument is the maximum number of seconds to block waiting for a new checkpoint. It is used in combination with the `timeout_fn` as follows: * If the timeout expires and no `timeout_fn` was specified, the iterator stops yielding. * If a `timeout_fn` was specified, that function is called and if it returns a true boolean value the iterator stops yielding. * If the function returns a false boolean value then the iterator resumes the wait for new checkpoints. At this point the timeout logic applies again. This behavior gives control to callers on what to do if checkpoints do not come fast enough or stop being generated. For example, if callers have a way to detect that the training has stopped and know that no new checkpoints will be generated, they can provide a `timeout_fn` that returns `True` when the training has stopped. If they know that the training is still going on they return `False` instead. Args: checkpoint_dir: The directory in which checkpoints are saved. min_interval_secs: The minimum number of seconds between yielding checkpoints. seconds_to_sleep: Seconds to sleep if a checkpoint is not found. Note the difference with min_interval_secs, which puts a lower bound on how when a new checkpoint will be looked for after yielding one checkpoint. seconds_to_sleep instead specifies how we should sleep for if no new checkpoints are found. Note that the timeout is only checked when not sleeping, so a `seconds_to_sleep` longer than the timeout would result in timing out after `seconds_to_sleep` seconds rather than `timeout` seconds. timeout: The maximum number of seconds to wait between checkpoints. The function will time out if `timeout` seconds have passed since a new checkpoint step was found. If left as `None`, then the process will wait indefinitely. timeout_fn: Optional function called after a timeout. If the function returns True, then it means that no new checkpoints will be generated and the iterator will exit. The function is called with no arguments. step_prefix: A prefix applied to step numbers (e.g. <prefix>_42). step_format_fixed_length: Expects to find checkpoint step directories with exactly this number of digits (leading zeros if necessary). step_name_format: Step NameFormat used to find step under given root directory. If provided, `step_prefix` and `step_format_fixed_length` are ignored. snapshot_dir: The directory in which snapshots are saved. If not provided, the snapshot directory is created as `checkpoint_dir / _SNAPSHOTS`. set_immutable: If True and applicable, sets the files in the snapshot to be immutable. ignore_snapshot_errors: If True, ignore errors when creating/releasing snapshots. Yields: Integer step numbers of the latest checkpoints as they arrive. """ checkpoint_dir = epath.Path(checkpoint_dir) step_name_format = _init_step_name_format( step_name_format=step_name_format, step_prefix=step_prefix, step_format_fixed_length=step_format_fixed_length, ) if snapshot_dir is None: snapshot_dir = checkpoint_dir / _SNAPSHOTS if snapshot_dir.exists(): for step_dir in snapshot_dir.iterdir(): snapshot_impl = snapshot_lib.create_instance(checkpoint_dir, step_dir) with _manage_snapshot_file_not_found( ignore_errors=ignore_snapshot_errors, formatted_message=( 'Ignoring error when cleaning up leftover snapshot:' f' {step_dir.name}' ), ): asyncio_utils.run_sync(snapshot_impl.release_snapshot()) checkpoint_step = None while True: until_step = checkpoint_step + 1 if checkpoint_step is not None else None with wait_for_new_checkpoint( checkpoint_dir, until_step=until_step, seconds_to_sleep=seconds_to_sleep, timeout=timeout, timeout_fn=timeout_fn, step_name_format=step_name_format, snapshot_dir=snapshot_dir, set_immutable=set_immutable, ignore_snapshot_errors=ignore_snapshot_errors, ) as new_checkpoint_step: if new_checkpoint_step == -1: if not timeout_fn: logging.info('Timed-out waiting for a checkpoint.') return if timeout_fn(): # The timeout_fn indicated that we are truly done. return else: # The timeout_fn indicated that more checkpoints may come. continue start = time.time() checkpoint_step = new_checkpoint_step yield checkpoint_step time_to_next_eval = start + min_interval_secs - time.time() if time_to_next_eval > 0: time.sleep(time_to_next_eval)
def _python_type_from_dtype(dtype): if hasattr(dtype, 'type'): return type(dtype.type(0).item()) else: return type(dtype(0).item())
[docs] def construct_restore_args( target: PyTree, sharding_tree: Optional[PyTree] = None, set_global_shape: bool = True, support_layout: bool = False, strict: bool = True, ) -> PyTree: """Creates restore_args given a target PyTree. This method should be used in conjunction with a CheckpointManager or Checkpointer that wraps a PyTreeCheckpointHandler. For example:: mngr = CheckpointManager(path, Checkpointer(PyTreeCheckpointHandler())) restore_args = construct_restore_args(train_state, train_state_sharding) restore_kwargs = {'restore_args': restore_args} mngr.restore(..., restore_kwargs=restore_kwargs) OR:: mngr = CheckpointManager(path, { 'train_state': Checkpointer(PyTreeCheckpointHandler()) }) restore_args = construct_restore_args(train_state, train_state_sharding) restore_kwargs = {'train_state': {'restore_args': restore_args} } mngr.restore(..., restore_kwargs=restore_kwargs) OR:: ckptr = Checkpointer(PyTreeCheckpointHandler()) restore_args = construct_restore_args(train_state, train_state_sharding) ckptr.restore(..., restore_args=restore_args) If a leaf in target is a np.ndarray, or int, or string, for example, a corresponding value for that leaf must be provided in axes_tree, but will be ignored. Args: target: The returned PyTree will match the structure of `target`. `target` may contain `value_metadata.Metadata`, real scalar or array values, or may contain jax.ShapeDtypeStruct. sharding_tree: A PyTree matching `target` which will be used to set the restoration sharding. If not provided, sharding will default to the shardings specified by `target`. set_global_shape: If true, set the `global_shape` field of ArrayRestoreArgs. support_layout: If true, layout is extracted from jax.Array or jax.ShapeDtypeStruct. strict: If False, allow padding/slicing for uneven sharding. Returns: A PyTree matching target of RestoreArgs (or ArrayRestoreArgs) objects. """ def _array_restore_args( value: Any, sharding: Optional[jax.sharding.Sharding | Format], # pytype: disable=unsupported-operands dtype: Optional[np.dtype] = None, ) -> type_handlers.ArrayRestoreArgs: global_shape = None # For random keys, we only allow overriding the sharding. if set_global_shape and not jax.dtypes.issubdtype( value.dtype, jax.dtypes.prng_key ): global_shape = value.shape return type_handlers.ArrayRestoreArgs( restore_type=jax.Array, sharding=sharding, global_shape=global_shape, dtype=dtype, strict=strict, ) def _restore_args( value: Any, sharding: Optional[jax.sharding.Sharding], ) -> type_handlers.RestoreArgs: if isinstance(value, jax.ShapeDtypeStruct): if sharding is None: return type_handlers.RestoreArgs( restore_type=np.ndarray, dtype=value.dtype ) else: return _array_restore_args(value, sharding, value.dtype) elif isinstance(value, value_metadata.Metadata): if isinstance(value, value_metadata.StringMetadata): return type_handlers.RestoreArgs(restore_type=str) elif isinstance(value, value_metadata.ScalarMetadata): return type_handlers.RestoreArgs( restore_type=_python_type_from_dtype(value.dtype), dtype=value.dtype ) elif isinstance(value, value_metadata.ArrayMetadata): if sharding is None: return type_handlers.RestoreArgs( restore_type=np.ndarray, dtype=value.dtype ) else: return _array_restore_args(value, sharding, value.dtype) else: raise ValueError(f'Unsupported value_metadata class: {type(value)}.') elif isinstance(value, STANDARD_ARRAY_TYPES): if isinstance(value, np.ndarray): return type_handlers.RestoreArgs( restore_type=type(value), dtype=value.dtype ) elif isinstance(value, jax.Array): return _array_restore_args(value, sharding, value.dtype) else: return type_handlers.RestoreArgs(restore_type=type(value)) elif isinstance(value, str): return type_handlers.RestoreArgs(restore_type=str) elif type_handlers.is_placeholder(value): return type_handlers.RestoreArgs(restore_type=type(PLACEHOLDER)) else: raise ValueError(f'Unsupported type: {type(value)}') def _get_sharding_or_layout(value): return arrays_sharding_lib.get_sharding_or_format( value, support_format=support_layout ) def _return_key_data(value): # replace jax.random.key with underneath jax.Array if isinstance(value, jax.Array) and jax.dtypes.issubdtype( value.dtype, jax.dtypes.prng_key ): # For random keys, extract the dtype and shape as a regular Jax array. # Stored metadata will help restoring the original random key. return jax.random.key_data(value) return value target = jax.tree.map(_return_key_data, target) if sharding_tree is None: sharding_tree = jax.tree.map(_get_sharding_or_layout, target) if isinstance(target, tree_metadata.TreeMetadata): return tree_metadata.build_default_tree_metadata( jax.tree.map(_restore_args, target.tree, sharding_tree.tree), custom_metadata=target.custom_metadata, ) else: return jax.tree.map(_restore_args, target, sharding_tree)