# Source code for orbax.checkpoint.experimental.v1._src.training.checkpointer

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines a class for managing a sequence of checkpoints in a training loop."""

from __future__ import annotations

import typing
from typing import Any, Callable, Iterable, Sequence

from absl import logging
from etils import epy
from orbax.checkpoint import checkpoint_manager
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
import orbax.checkpoint.experimental.v1._src.handlers.global_registration  # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.loading import loading
from orbax.checkpoint.experimental.v1._src.metadata import loading as metadata_loading
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import step as path_step_lib
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.saving import saving
from orbax.checkpoint.experimental.v1._src.saving import validation
from orbax.checkpoint.experimental.v1._src.synchronization import thread_utils
from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types
from orbax.checkpoint.experimental.v1._src.training import errors
from orbax.checkpoint.experimental.v1._src.training import preservation_policies
from orbax.checkpoint.experimental.v1._src.training import save_decision_policies
from orbax.checkpoint.experimental.v1._src.training.metadata import types as training_metadata_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
from typing_extensions import deprecated  # pytype: disable=not-supported-yet

CheckpointMetadata = training_metadata_types.CheckpointMetadata
RootMetadata = training_metadata_types.RootMetadata


STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY


class _AsyncSaveResponse(async_types.AsyncResponse[bool]):
  """Handle to a checkpoint save that is still being finalized.

  A coroutine that calls the manager's `wait_until_finished()` is handed to a
  `BackgroundThreadRunner`; the response resolves to `True` once that call
  returns.
  """

  def __init__(
      self, manager: checkpoint_manager.CheckpointManager
  ):

    async def _block_until_saved() -> bool:
      # wait_until_finished() re-raises the exception of a failed background
      # operation, handing the failure back to the caller.
      manager.wait_until_finished()
      return True

    pending = _block_until_saved()
    self._thread_runner = thread_utils.BackgroundThreadRunner[bool](pending)

  def result(self, timeout: float | None = None) -> bool:
    """Returns the runner's result, forwarding `timeout` to it."""
    runner = self._thread_runner
    return runner.result(timeout=timeout)

  def on_complete(self, callback: Callable[[bool], None]) -> None:
    """Registers `callback` with the underlying runner."""
    runner = self._thread_runner
    runner.on_complete(callback)


def _resolve_integer_step(
    step: int | CheckpointMetadata,
) -> int:
  """Returns `step` itself when it is an int, otherwise its `step` attribute."""
  return step if isinstance(step, int) else step.step


@typing.final
class Checkpointer(epy.ContextManager):
  """An object that manages a sequence of checkpoints in a training loop."""

  def __init__(
      self,
      directory: path_types.PathLike,
      *,
      context: context_lib.Context | None = None,
      save_decision_policy: (
          save_decision_policies.SaveDecisionPolicy | None
      ) = None,
      preservation_policy: (
          preservation_policies.PreservationPolicy | None
      ) = None,
      step_name_format: (
          path_step_lib.NameFormat[CheckpointMetadata[None]] | None
      ) = None,
      custom_metadata: tree_types.JsonType | None = None,
      cleanup_tmp_directories: bool = False,
      lightweight_initialize: bool = False,
  ):
    """Initializes a Checkpointer.

    IMPORTANT: This class is not thread safe. All APIs should be called across
    all available processes, from the main thread.

    The Checkpointer is intended for use in a training loop, where a sequence
    of checkpoints are saved at regular intervals.

    Example usage::

      # Configure the frequency at which checkpoints are saved.
      save_decision_policies = ocp.training.save_decision_policies
      # Save every 1000 steps, or when a preemption is detected.
      save_decision_policy = save_decision_policies.AnySavePolicy([
          save_decision_policies.FixedIntervalPolicy(1000),
          save_decision_policies.PreemptionPolicy(),
      ])

      # Configure the checkpoints to preserve (avoid garbage collection).
      preservation_policies = ocp.training.preservation_policies
      # Avoid garbage collection on the latest 10, or every 10000 steps.
      preservation_policy = preservation_policies.AnyPreservationPolicy([
          preservation_policies.LatestN(10),
          preservation_policies.EveryNSteps(10000),
      ])

      with ocp.training.Checkpointer(
          directory,
          save_decision_policy=save_decision_policy,
          preservation_policy=preservation_policy,
      ) as ckptr:
        if ckptr.latest is None:
          model_state = init_from_scratch(rng)
        else:
          model_state = ckptr.load()  # Loads latest checkpoint.
          # Note: prefer to specify the abstract tree if available.
          model_state = ckptr.load(
              ckptr.latest, abstract_state=abstract_model_state)
        start_step = ckptr.latest.step if ckptr.latest else 0
        for step in range(start_step, num_steps):
          model_state = train_step(model_state)
          # Saves a checkpoint if needed (according to `save_decision_policy`).
          ckptr.save(step, model_state)

    Prefer to use the context manager style as shown above, which ensures that
    the Checkpointer is closed properly and any outstanding async operations
    are completed.

    Args:
      directory: The root directory where checkpoints are stored. The directory
        will be created if it does not exist.
      context: A :py:class:`~orbax.checkpoint.v1.Context` object that will be
        used to wrap all function calls for this `Checkpointer`.
      save_decision_policy: A policy used to determine when a checkpoint should
        be saved. If not provided, the `Checkpointer` saves as often as
        possible by default (assuming no checkpoint is currently being saved),
        and saves when a preemption is detected by the JAX distributed system.
      preservation_policy: A policy used to determine when a checkpoint should
        be preserved. Any checkpoints not preserved are garbage collected. If
        not provided, all checkpoints are preserved.
      step_name_format: An object used to specify the format for step paths. By
        default, steps are rendered as simple integers, like
        `/root/directory/<step>`.
      custom_metadata: A JSON dictionary representing user-specified custom
        metadata. This should be information that is relevant to the entire
        sequence of checkpoints, rather than to any single checkpoint.
      cleanup_tmp_directories: If True, cleans up any existing temporary
        directories on Checkpointer creation.
      lightweight_initialize: If True, checkpoint step metadata is not read on
        Checkpointer initialization during checkpoint info loading. This is
        useful to improve init performance when there are O(1k) or more
        existing checkpoint steps present and checkpoint info properties like
        `time` and `metrics` are not needed.
    """
    self._context = context_lib.Context(context or context_lib.get_context())

    default_save_decision_policy = save_decision_policies.AnySavePolicy([
        save_decision_policies.InitialSavePolicy(),
        save_decision_policies.ContinuousCheckpointingPolicy(),
        save_decision_policies.PreemptionCheckpointingPolicy(),
    ])
    save_decision_policy = save_decision_policy or default_save_decision_policy

    default_preservation_policy = preservation_policies.PreserveAll()
    preservation_policy = preservation_policy or default_preservation_policy

    self._step_name_format = (
        step_name_format or path_step_lib.standard_name_format()
    )
    options = checkpoint_manager.CheckpointManagerOptions(
        save_decision_policy=save_decision_policy,
        preservation_policy=preservation_policy,
        step_name_format=step_name_format,
        cleanup_tmp_directories=cleanup_tmp_directories,
        lightweight_initialize=lightweight_initialize,
        max_to_keep=None,  # Unlimited.
        todelete_full_path=self._context.deletion_options.gcs_deletion_options.todelete_full_path,
        async_options=self._context.async_options.v0(),
        file_options=self._context.file_options.v0(),
        multiprocessing_options=self._context.multiprocessing_options.v0(),
        temporary_path_class=self._context.file_options.temporary_path_class,
        # Prevent the checkpoint manager from writing metrics on its own. This
        # class will take responsibility for writing metrics.
        prevent_write_metrics=True,
    )
    self._manager = checkpoint_manager.CheckpointManager(
        directory,
        options=options,
        metadata=custom_metadata,
    )

  @property
  def directory(self) -> path_types.Path:
    """The root directory where checkpoint steps are located."""
    return self._manager.directory

  @property
  def latest(self) -> CheckpointMetadata[None] | None:
    """Returns the latest :py:class:`.CheckpointMetadata`, or None if no checkpoints exist.

    See `checkpoints` documentation below.

    Returns:
      The latest checkpoint, or None if no checkpoints exist.
    """
    # Read the property once: every access re-sorts and rebuilds the list.
    checkpoints = self.checkpoints
    if not checkpoints:
      return None
    return checkpoints[-1]

  @property
  def checkpoints(self) -> Sequence[CheckpointMetadata[None]]:
    """Returns a list of :py:class:`.CheckpointMetadata`, sorted ascending by step.

    The method returns a list of :py:class:`.CheckpointMetadata` objects, which
    contain selected properties describing the checkpoint. Contrast this with
    the methods :py:func:`.metadata` and :py:func:`.checkpointables_metadata`,
    which may perform a more expensive disk read to retrieve additional
    information. This method only returns cheap cacheable properties like step
    and timestamp. The return value is annotated as
    :py:class:`.CheckpointMetadata[None]` because the core `metadata` property
    is not retrieved, and is therefore `None`.

    The property is cached to avoid repeated disk reads. This is not a problem
    unless checkpoints are manually deleted, or deleted by some other job or
    class that `Checkpointer` is unaware of. Note that doing this is
    discouraged.

    Returns:
      A list of checkpoints, sorted ascending by step.
    """
    infos = sorted(self._manager._checkpoints, key=lambda info: info.step)  # pylint: disable=protected-access
    return [
        CheckpointMetadata[None](
            info.step,
            path=self.directory / self._step_name_format.build_name(info.step),
            metadata=None,
            metrics=info.metrics,
            commit_timestamp_nsecs=int(info.time.timestamp() * 1e9),
        )
        for info in infos
    ]

  def _resolve_existing_checkpoint(
      self, step: int | CheckpointMetadata | None
  ) -> CheckpointMetadata[None]:
    """Resolves `step` to the metadata of an existing checkpoint.

    Args:
      step: A step number, a :py:class:`.CheckpointMetadata` whose `step` is
        used, or None to select the latest checkpoint.

    Returns:
      The matching entry from `checkpoints`.

    Raises:
      StepNotFoundError: If `step` is None and no checkpoints exist, or if no
        checkpoint exists at the requested step.
    """
    if step is None:
      latest = self.latest
      if latest is None:
        raise errors.StepNotFoundError(
            'Specified `step=None`, but no checkpoints were found.'
        )
      return latest
    step = _resolve_integer_step(step)
    for checkpoint in self.checkpoints:
      if checkpoint.step == step:
        return checkpoint
    raise errors.StepNotFoundError(f'No checkpoint found at step {step}.')

  def should_save(self, step: int) -> bool:
    """Returns whether a checkpoint should be saved at the given step.

    Args:
      step: The step number to check.

    Returns:
      The underlying manager's decision for `step`.
    """
    with context_lib.get_context(self._context):
      step = _resolve_integer_step(step)
      return self._manager.should_save(step)

  def save(
      self,
      step: int,
      state: tree_types.PyTreeOf[tree_types.Leaf],
      *,
      checkpointable_name: str = STATE_CHECKPOINTABLE_KEY,
      force: bool = False,
      overwrite: bool = False,
      metrics: tree_types.JsonType | None = None,
      custom_metadata: tree_types.JsonType | None = None,
  ) -> bool:
    """Saves a checkpoint, if dictated by :py:class:`.SaveDecisionPolicy`.

    This method behaves similarly to the standalone free function
    :py:func:`~orbax.checkpoint.v1.save` (see documentation), but performs
    additional tasks related to managing a sequence of checkpoint steps.

    It consists roughly of the following steps:

    - Check whether a checkpoint should be saved at the given step.
    - Check whether a save is already in progress. If so, wait for it to
      finish.
    - Save to a directory given by `root_directory / <step_format>`.
    - Perform garbage collection if necessary.
    - Return whether a checkpoint was saved or not.

    It is important to note that the `Checkpointer` never allows saving more
    than one checkpoint at a time. Depending on the
    :py:class:`.SaveDecisionPolicy`, a checkpoint may be saved or skipped at a
    given step, but if a save is initiated, as dictated by the policy, then it
    will proceed as normal as long as no other save is currently in progress.
    If a save is already in progress, the function will block until the
    previous save has finished.

    Example usage:

    1. Basic Usage:
    Save a PyTree at a specific training step. The checkpointer automatically
    manages the step-based directory structure inside your root folder::

      from orbax.checkpoint.v1 import training

      # Initialize the checkpointer for a directory
      ckptr = training.Checkpointer(directory)

      # Save the tree at step 0.
      saved = ckptr.save(step=0, state=tree)

      # Clean up background threads gracefully when the training loop ends
      ckptr.close()

    2. Advanced Saving with Metrics and Metadata:
    Attach JSON-serializable metrics (like loss/accuracy) and custom metadata
    to a specific step for thorough experiment tracking::

      from orbax.checkpoint.v1 import training

      ckptr = training.Checkpointer(directory)

      ckptr.save(
          step=1,
          state=tree,
          metrics={'loss': 0.12, 'accuracy': 0.95},
          custom_metadata={'description': 'Model after epoch 1'},
      )

      ckptr.close()

    Args:
      step: The step number to save.
      state: The PyTree to save.
      checkpointable_name: The name of the checkpointable to save a pytree
        under. Defaults to `STATE_CHECKPOINTABLE_KEY`.
      force: If True, ignores all :py:class:`.SaveDecisionPolicy` checks, and
        always decides to save a checkpoint.
      overwrite: If True, deletes any existing checkpoint at the given step
        before saving. Otherwise, raises an error if the checkpoint already
        exists.
      metrics: A PyTree of metrics to be saved with the checkpoint.
      custom_metadata: A JSON dictionary representing user-specified custom
        metadata. This should be information that is relevant to the
        checkpoint at the given step, rather than to the entire sequence of
        checkpoints.

    Returns:
      Whether a checkpoint was saved or not.
    """
    response = self.save_async(
        step,
        state,
        checkpointable_name=checkpointable_name,
        force=force,
        overwrite=overwrite,
        metrics=metrics,
        custom_metadata=custom_metadata,
    )
    # A None response means the save was skipped.
    if response is None:
      return False
    return response.result()

  def save_checkpointables(
      self,
      step: int,
      checkpointables: dict[str, Any],
      *,
      force: bool = False,
      overwrite: bool = False,
      metrics: tree_types.JsonType | None = None,
      custom_metadata: tree_types.JsonType | None = None,
  ) -> bool:
    """Saves a dictionary of checkpointable objects at the given step.

    This method saves a dictionary of checkpointable objects, mapping string
    names to values. See `the guide on Checkpointables
    <https://orbax.readthedocs.io/en/latest/guides/checkpoint/v1/checkpointables.html>`_
    for more details on checkpointables.

    Also see documentation for :py:func:`~orbax.checkpoint.v1.save`.

    Example:

    1. Basic Usage:
    Save multiple named items (checkpointables) at a specific step. The
    dictionary keys define the names of the saved components::

      from orbax.checkpoint.v1 import training

      # Initialize the checkpointer for a directory
      ckptr = training.Checkpointer(directory)

      # Save multiple items, such as model weights and optimizer state
      items_to_save = {
          'model': my_model_state,
          'optimizer': my_opt_state,
      }
      saved = ckptr.save_checkpointables(
          step=0, checkpointables=items_to_save
      )

      # Clean up background threads gracefully when the training loop ends
      ckptr.close()

    2. Advanced Saving with Metrics and Metadata:
    Attach JSON-serializable metrics and custom metadata to a specific step
    for thorough experiment tracking::

      from orbax.checkpoint.v1 import training

      ckptr = training.Checkpointer(directory)
      items_to_save = {'model': my_model_state}

      ckptr.save_checkpointables(
          step=1,
          checkpointables=items_to_save,
          metrics={'loss': 0.12, 'accuracy': 0.95},
          custom_metadata={'description': 'Model after epoch 1'},
      )

      ckptr.close()

    Args:
      step: The step number to save.
      checkpointables: A dictionary mapping string names to the corresponding
        objects (checkpointables) that need to be saved.
      force: If True, ignores all policy checks and always decides to save a
        checkpoint.
      overwrite: If True, deletes any existing checkpoint at the given step
        before saving. Otherwise, raises an error if the checkpoint already
        exists.
      metrics: A dictionary of metrics to be saved with the checkpoint. Must be
        JSON-serializable.
      custom_metadata: A JSON dictionary representing user-specified custom
        metadata relevant to the checkpoint at this specific step.

    Returns:
      bool: True if the checkpoint was successfully saved, False otherwise.
    """
    response = self.save_checkpointables_async(
        step,
        checkpointables,
        force=force,
        overwrite=overwrite,
        metrics=metrics,
        custom_metadata=custom_metadata,
    )
    # A None response means the save was skipped.
    if response is None:
      return False
    return response.result()

  def save_async(
      self,
      step: int,
      state: tree_types.PyTreeOf[tree_types.Leaf],
      *,
      checkpointable_name: str = STATE_CHECKPOINTABLE_KEY,
      force: bool = False,
      overwrite: bool = False,
      metrics: tree_types.JsonType | None = None,
      custom_metadata: tree_types.JsonType | None = None,
  ) -> async_types.AsyncResponse[bool] | None:
    """Saves a checkpoint asynchronously.

    This function is the asynchronous equivalent of :py:meth:`~.save`. It
    accepts the exact same arguments; please refer to that method for detailed
    descriptions. This method executes mostly in the background, blocking the
    main thread for as little time as possible.

    Example:
    ::

      async_response = ckptr.save_async(step=0, state=tree)
      if async_response is not None:
        saved = async_response.result()

    Args:
      step: The step number to save.
      state: The PyTree to save.
      checkpointable_name: The name of the checkpointable to save a pytree
        under. Defaults to `STATE_CHECKPOINTABLE_KEY`.
      force: See `save`.
      overwrite: See `save`.
      metrics: See `save`.
      custom_metadata: See `save`.

    Returns:
      An `AsyncResponse`, which can be awaited via `result()`, which returns a
      bool indicating whether a checkpoint was saved or not, or None if the
      save was skipped by policy.
    """
    return self.save_checkpointables_async(
        step,
        {checkpointable_name: state},
        force=force,
        overwrite=overwrite,
        metrics=metrics,
        custom_metadata=custom_metadata,
    )

  def save_checkpointables_async(
      self,
      step: int,
      checkpointables: dict[str, Any],
      *,
      force: bool = False,
      overwrite: bool = False,
      metrics: tree_types.JsonType | None = None,
      custom_metadata: tree_types.JsonType | None = None,
  ) -> async_types.AsyncResponse[bool] | None:
    """Saves checkpointable objects asynchronously.

    This function is the asynchronous equivalent of
    :py:meth:`~.save_checkpointables`. Please refer to that method for detailed
    instructions and argument descriptions.

    Example:
    Save checkpointable objects asynchronously::

      async_response = ckptr.save_checkpointables_async(
          step=0, checkpointables=items_to_save
      )
      if async_response is not None:
        saved = async_response.result()

    Args:
      step: The step number to save.
      checkpointables: A dictionary mapping string names to objects to save.
      force: See `save_checkpointables`.
      overwrite: See `save_checkpointables`.
      metrics: See `save_checkpointables`.
      custom_metadata: See `save_checkpointables`.

    Returns:
      An object representing the background operation, or None if the save was
      skipped by policy. Call `.result()` on it to block and return a boolean
      indicating whether the checkpoint was successfully saved.

    Raises:
      StepAlreadyExistsError: If `overwrite` is False and a checkpoint at the
        target `step` already exists.
    """
    context = context_lib.get_context(self._context)
    validation.validate_save_checkpointables(checkpointables)
    if overwrite:
      logging.info(
          'Specified `overwrite`: deleting existing checkpoint %d if it'
          ' exists.',
          step,
      )
      try:
        self._manager.delete(step)
      except FileNotFoundError:
        # Nothing to delete at this step; overwrite is best-effort here.
        pass
    elif any(c.step == step for c in self.checkpoints):
      raise errors.StepAlreadyExistsError(f'Step {step} already exists.')

    with context:
      checkpointer, args = saving.get_v0_checkpointer_and_args(
          checkpointables, metrics=metrics
      )
      # Swap in the checkpointer built for this particular set of
      # checkpointables before delegating the save to the manager.
      self._manager._checkpointer = checkpointer  # pylint: disable=protected-access
      save_initiated = self._manager.save(
          step,
          args=args,
          metrics=metrics,
          force=force,
          custom_metadata=custom_metadata,
      )
      if not save_initiated:
        return None
      return _AsyncSaveResponse(self._manager)

  def load(
      self,
      step: int | CheckpointMetadata | None = None,
      abstract_state: (
          tree_types.PyTreeOf[tree_types.AbstractLeaf] | None
      ) = None,
      *,
      checkpointable_name: str = STATE_CHECKPOINTABLE_KEY,
  ) -> tree_types.PyTreeOf[tree_types.Leaf]:
    """Loads a PyTree checkpoint at the given step.

    This method behaves similarly to the standalone free function
    :py:func:`~orbax.checkpoint.v1.load`.

    **Note:** Loading a PyTree without providing an `abstract_state` is
    provided purely for convenience. For serious or production use cases, it is
    STRONGLY recommended to always provide an `abstract_state` to ensure the
    restored PyTree strictly matches the expected shapes, dtypes, and sharding.

    Example:

    1. Basic Loading:
    Load a PyTree without providing an abstract structure. By passing
    `step=None` (or omitting it), it automatically loads the latest step::

      from orbax.checkpoint.v1 import training

      # Initialize the checkpointer for the directory
      ckptr = training.Checkpointer(directory)

      # Load the saved PyTree from latest step
      restored_tree = ckptr.load(step=None)

    2. Loading with an Abstract PyTree:
    Provide an abstract structure (such as target shapes and dtypes) to ensure
    the restored PyTree is safely and correctly formatted::

      import jax
      import jax.numpy as jnp
      from orbax.checkpoint.v1 import training

      ckptr = training.Checkpointer(directory)

      # Define the expected structure (shapes and dtypes) to restore into
      target_structure = {
          'weights': jax.ShapeDtypeStruct((128, 128), dtype=jnp.float32),
          'bias': jax.ShapeDtypeStruct((128,), dtype=jnp.float32)
      }

      # Restore exactly matching the target structure
      restored_tree = ckptr.load(
          step=1, abstract_state=target_structure
      )

    Args:
      step: The step number or :py:class:`.CheckpointMetadata` to load. If
        None, the checkpointer will attempt to resolve and load the latest
        existing checkpoint.
      abstract_state: The abstract PyTree to load.
      checkpointable_name: The name of the checkpointable to load a pytree
        under. Defaults to `STATE_CHECKPOINTABLE_KEY`.

    Returns:
      The loaded PyTree.
    """
    return self.load_checkpointables(
        step, {checkpointable_name: abstract_state}
    )[checkpointable_name]

  def load_checkpointables(
      self,
      step: int | CheckpointMetadata | None = None,
      abstract_checkpointables: dict[str, Any] | None = None,
  ) -> dict[str, Any]:
    """Loads a set of checkpointables at the given step.

    This method behaves similarly to the standalone free function
    :py:func:`~orbax.checkpoint.v1.load_checkpointables`.

    This function retrieves multiple named items (such as model weights or
    optimizer states) from a specific checkpoint directory. If no step is
    provided, it automatically resolves to and loads the most recently saved
    checkpoint.

    **Note:** Loading without providing an `abstract_checkpointables`
    dictionary is provided purely for convenience. For serious or production
    use cases, it is STRONGLY recommended to always provide
    `abstract_checkpointables` to ensure the restored items strictly match the
    exact nested structures, shapes, and data types expected.

    Example:

    1. Basic Loading:
    Load multiple named items (such as a model and optimizer) from a specific
    step. If step is omitted, it resolves to the latest available checkpoint::

      from orbax.checkpoint.v1 import training

      # Initialize the checkpointer for the directory
      ckptr = training.Checkpointer(directory)

      # Load all checkpointables saved at the latest step
      restored_items = ckptr.load_checkpointables(step=None)

      # Access the individual components by their original string keys
      my_model = restored_items["model"]
      my_opt = restored_items["optimizer"]

    2. Loading with Abstract Checkpointables (Recommended):
    Provide a dictionary of abstract structures to ensure the restored items
    strictly match your expected shapes and data types::

      import jax
      import jax.numpy as jnp
      from orbax.checkpoint.v1 import training

      ckptr = training.Checkpointer(directory)

      # Define the expected structure for each named item using JAX arrays
      target_items = {
          "model": {
              'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
              'bias': jax.ShapeDtypeStruct((128,), jnp.float32)
          },
          "optimizer": {
              'momentum': jax.ShapeDtypeStruct((128, 128), jnp.float32)
          }
      }

      # Restore exactly matching the target structures
      restored_items = ckptr.load_checkpointables(
          step=1, abstract_checkpointables=target_items
      )

    3. Partial Loading:
    If you only need to load a subset of checkpointables (e.g., loading model
    weights but omitting optimizer state), you can provide an
    `abstract_checkpointables` dictionary containing only the keys for the
    items you wish to restore::

      import jax
      import jax.numpy as jnp
      from orbax.checkpoint.v1 import training

      ckptr = training.Checkpointer(directory)

      # Define abstract structure for ONLY the items to load
      target_items = {
          "model": {
              'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
              'bias': jax.ShapeDtypeStruct((128,), jnp.float32)
          },
      }

      # Load only "model", omitting "optimizer"
      restored_items = ckptr.load_checkpointables(
          step=1, abstract_checkpointables=target_items
      )

      my_model = restored_items["model"]
      # my_opt = restored_items["optimizer"]

    Args:
      step: The step number or :py:class:`.CheckpointMetadata` to load. If
        None, the checkpointer will attempt to resolve and load the latest
        existing checkpoint.
      abstract_checkpointables: A dictionary mapping string names to their
        corresponding abstract structures (e.g., target PyTrees). This guides
        the loading process to ensure shape and type compliance. If provided,
        it can be used to load only a subset of checkpointables by providing
        only a subset of keys.

    Returns:
      dict[str, Any]: A dictionary containing the loaded checkpointable
      objects, keyed by string names. If `abstract_checkpointables` was
      specified, returns only the keys specified in that dict, otherwise
      returns all keys saved with `save_checkpointables`.
    """
    with context_lib.get_context(self._context):
      step = self._resolve_existing_checkpoint(step).step
      return loading.load_checkpointables(
          self.directory / self._step_name_format.build_name(step),
          abstract_checkpointables,
      )

  def load_async(
      self,
      step: int | CheckpointMetadata | None = None,
      abstract_state: (
          tree_types.PyTreeOf[tree_types.AbstractLeaf] | None
      ) = None,
  ) -> async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]:
    """Not yet supported."""
    raise NotImplementedError()

  def load_checkpointables_async(
      self,
      step: int | CheckpointMetadata | None = None,
      abstract_checkpointables: dict[str, Any] | None = None,
  ) -> async_types.AsyncResponse[dict[str, Any]]:
    """Loads a set of checkpointables asynchronously at the given step.

    Not yet supported: this method always raises `NotImplementedError`.
    """
    raise NotImplementedError()

  def metadata(
      self, step: int | CheckpointMetadata | None = None
  ) -> training_metadata_types.CheckpointMetadata[
      metadata_types.PyTreeMetadata
  ]:
    """Returns checkpoint metadata for the given step.

    Retrieves metadata describing the structure of the PyTree stored at the
    given step. If no step is provided, the method resolves to the latest
    available checkpoint.

    Args:
      step: The step number to retrieve metadata for. If `None`, the latest
        step is used. Can also be a :py:class:`.CheckpointMetadata` object,
        from which the step is extracted.

    Returns:
      A :py:class:`.CheckpointMetadata` object containing
      :py:class:`.PyTreeMetadata`, along with checkpoint timestamp and metrics
      information.
    """
    with context_lib.get_context(self._context):
      checkpoint = self._resolve_existing_checkpoint(step)
      # Only the resolved checkpoint should be used from here on.
      del step
      checkpoint_metadata = metadata_loading.metadata(
          self._manager.directory
          / self._step_name_format.build_name(checkpoint.step)
      )
      return training_metadata_types.CheckpointMetadata[
          metadata_types.PyTreeMetadata
      ](
          step=checkpoint.step,
          path=checkpoint_metadata.path,
          metadata=checkpoint_metadata.metadata,
          init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
          commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
          custom_metadata=checkpoint_metadata.custom_metadata,
          metrics=checkpoint.metrics,
      )

  def checkpointables_metadata(
      self, step: int | CheckpointMetadata | None = None
  ) -> training_metadata_types.CheckpointMetadata[dict[str, Any]]:
    """Returns checkpoint metadata for the given step.

    Retrieves metadata describing the structure of the checkpointables stored
    at the given step. If no step is provided, the method resolves to the
    latest available checkpoint.

    Args:
      step: The step number to retrieve metadata for. If `None`, the latest
        step is used. Can also be a :py:class:`.CheckpointMetadata` object,
        from which the step is extracted.

    Returns:
      A :py:class:`.CheckpointMetadata` object containing a `dict[str, Any]`
      describing the checkpointables, along with checkpoint timestamp and
      metrics information.
    """
    with context_lib.get_context(self._context):
      checkpoint = self._resolve_existing_checkpoint(step)
      # Only the resolved checkpoint should be used from here on.
      del step
      checkpoint_metadata = metadata_loading.checkpointables_metadata(
          self._manager.directory
          / self._step_name_format.build_name(checkpoint.step)
      )
      return training_metadata_types.CheckpointMetadata[dict[str, Any]](
          step=checkpoint.step,
          path=checkpoint_metadata.path,
          metadata=checkpoint_metadata.metadata,
          init_timestamp_nsecs=checkpoint_metadata.init_timestamp_nsecs,
          commit_timestamp_nsecs=checkpoint_metadata.commit_timestamp_nsecs,
          custom_metadata=checkpoint_metadata.custom_metadata,
          metrics=checkpoint.metrics,
      )

  def root_metadata(
      self,
  ) -> training_metadata_types.RootMetadata:
    """Returns the :py:class:`.RootMetadata` for the root directory.

    Returns:
      A :py:class:`.RootMetadata` carrying the root `directory` and the custom
      metadata reported by the underlying manager.
    """
    with context_lib.get_context(self._context):
      metadata = self._manager.metadata(None)
      return RootMetadata(
          directory=self.directory, custom_metadata=metadata.custom_metadata
      )

  def reload(self):
    """Reloads internal properties from the root directory.

    Updates the list of available checkpoints by rescanning the storage
    location. Use this method to sync the checkpointer with the file system if
    checkpoints have been added or removed externally.
    """
    self._manager.reload()

  def is_saving_in_progress(self) -> bool:
    """Returns whether a checkpoint save operation is currently in progress.

    Checks if there are any background persistence operations currently
    active.

    Returns:
      `True` if a save operation is in progress, `False` otherwise.
    """
    return self._manager.is_saving_in_progress()

  def wait(self):
    """Waits for any outstanding async operations to complete.

    This method blocks until all background tasks, such as asynchronous saves,
    have finished. Use this method to ensure that all operations are finalized
    before proceeding with dependent actions.
    """
    self._manager.wait_until_finished()

  def close(self):
    """Waits for pending async operations to complete and releases resources.

    This method blocks until all background tasks, such as asynchronous saves,
    have finished. It also performs necessary cleanup, such as closing file
    handles.
    """
    self._manager.close()

  def __contextmanager__(
      self,
  ) -> Iterable[Checkpointer]:
    """Yields this `Checkpointer` and closes it when the `with` block exits."""
    try:
      yield self
    finally:
      self.close()

  @deprecated('Use `save` instead.')
  def save_pytree(self, *args, **kwargs):
    """Deprecated alias that forwards all arguments to `save`."""
    return self.save(*args, **kwargs)

  @deprecated('Use `save_async` instead.')
  def save_pytree_async(self, *args, **kwargs):
    """Deprecated alias that forwards all arguments to `save_async`."""
    return self.save_async(*args, **kwargs)

  @deprecated('Use `load` instead.')
  def load_pytree(self, *args, **kwargs):
    """Deprecated alias that forwards all arguments to `load`."""
    return self.load(*args, **kwargs)

  @deprecated('Use `load_async` instead.')
  def load_pytree_async(self, *args, **kwargs):
    """Deprecated alias that forwards all arguments to `load_async`."""
    return self.load_async(*args, **kwargs)

  @deprecated('Use `metadata` instead.')
  def pytree_metadata(self, *args, **kwargs):
    """Deprecated alias that forwards all arguments to `metadata`."""
    return self.metadata(*args, **kwargs)