Source code for orbax.checkpoint._src.checkpointers.checkpointer

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Synchronous Checkpointer implementation."""

import time
from typing import Any, Iterable, Optional, Type

from absl import logging
from etils import epath
from etils import epy
import jax
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import options as options_lib
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.checkpointers import abstract_checkpointer
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.logging import event_tracking
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_defaults
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.path import step as step_lib
from typing_extensions import Self  # for Python version < 3.11



CheckpointArgs = checkpoint_args.CheckpointArgs
register_with_handler = checkpoint_args.register_with_handler
get_legacy_handler_wrapper = (
    composite_checkpoint_handler.get_legacy_handler_wrapper
)
StepMetadata = checkpoint.StepMetadata


def construct_checkpoint_args(
    handler: checkpoint_handler.CheckpointHandler,
    for_save: bool,
    *args,
    **kwargs,
) -> checkpoint_args.CheckpointArgs:
  """Constructs `CheckpointArgs` for save or restore for the handler."""
  for arg in args:
    if isinstance(arg, checkpoint_args.CheckpointArgs):
      return arg
  for arg in kwargs.values():
    if isinstance(arg, checkpoint_args.CheckpointArgs):
      return arg
  jax.monitoring.record_event('/jax/orbax/deprecation/checkpointer_legacy_args')
  save_arg_cls, restore_arg_cls = checkpoint_args.get_registered_args_cls(
      handler
  )
  if for_save:
    return save_arg_cls(*args, **kwargs)
  else:
    return restore_arg_cls(*args, **kwargs)


[docs] class Checkpointer( abstract_checkpointer.AbstractCheckpointer, epy.ContextManager ): """A synchronous implementation of AbstractCheckpointer. This class saves synchronously to a given directory using an underlying :py:class:`.CheckpointHandler`. Atomicity of the operation is guaranteed. IMPORTANT: Async checkpointing can often be faster for saving. Strongly consider using :py:class:`.AsyncCheckpointer` instead. IMPORTANT: Remember that to save and restore a checkpoint, one should always use an :py:class:`.AbstractCheckpointer` coupled with a `CheckpointHandler`. The specific `CheckpointHandler` to use depends on the object being saved or restored. Basic example:: ckptr = Checkpointer(StandardCheckpointHandler()) args = ocp.args.StandardSave(item=pytree_of_arrays) ckptr.save(path, args=args) args = ocp.args.StandardRestore(item=abstract_pytree_target) ckptr.restore(path, args=args) Each handler includes `...SaveArgs` and `...RestoreArgs` classes that document what arguments are expected. When using `Checkpointer`, you can either use this dataclass directly, or you can provide the arguments in keyword form. For example:: ckptr = Checkpointer(StandardCheckpointHandler()) ckptr.save(path, state=pytree_of_arays) ckptr.restore(path, state=abstract_pytree_target) """
[docs] def __init__( self, handler: checkpoint_handler.CheckpointHandler, *, multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(), file_options: options_lib.FileOptions = options_lib.FileOptions(), checkpoint_metadata_store: Optional[checkpoint.MetadataStore] = None, temporary_path_class: Optional[ Type[atomicity_types.TemporaryPath] ] = None, ): if not checkpoint_args.has_registered_args(handler): logging.warning( 'No registered CheckpointArgs found for handler type: %s', type(handler), ) handler = get_legacy_handler_wrapper(handler) self._handler = handler self._primary_host = multiprocessing_options.primary_host self._active_processes = multiprocessing_options.active_processes self._barrier_sync_key_prefix = ( multiprocessing_options.barrier_sync_key_prefix ) self._multiprocessing_options = options_lib.MultiprocessingOptions( primary_host=self._primary_host, active_processes=self._active_processes, barrier_sync_key_prefix=self._barrier_sync_key_prefix, ) self._file_options = file_options self._temporary_path_class = temporary_path_class # If not provided then use checkpoint_metadata_store with blocking writes. self._metadata_store = ( checkpoint_metadata_store or checkpoint.metadata_store(enable_write=True, blocking_write=True) ) if not self._metadata_store.is_blocking_writer(): raise ValueError('Checkpoint metadata store must be blocking writer.') jax.monitoring.record_event('/jax/orbax/checkpointer/init')
def get_temporary_path( self, directory: epath.Path ) -> atomicity_types.TemporaryPath: temporary_path_class = ( self._temporary_path_class or atomicity_defaults.get_default_temporary_path_class(directory) ) tmpdir = temporary_path_class.from_final( directory, checkpoint_metadata_store=self._metadata_store, file_options=self._file_options, ) return tmpdir async def create_temporary_path( self, temporary_path: atomicity_types.TemporaryPath ): await atomicity.create_all( [temporary_path], multiprocessing_options=self._multiprocessing_options, ) def synchronize_next_awaitable_signal_operation_id(self): synchronization.OperationIdGenerator.next_operation_id() multihost.sync_global_processes( multihost.unique_barrier_key( 'next_awaitable_signal_operation_id:sync', prefix=self._barrier_sync_key_prefix, ), timeout=multihost.coordination_timeout(), processes=self._active_processes, ) logging.vlog( 1, '[process=%s] Synchronized next awaitable signal operation id to %s', multihost.process_index(), synchronization.OperationIdGenerator.get_current_operation_id(), )
[docs] def save( self, directory: epath.PathLike, *args, force: bool = False, custom_metadata: dict[str, Any] | None = None, **kwargs, ): """Saves the given item to the provided directory. Delegates to the underlying CheckpointHandler. Ensures save operation atomicity. This method should be called by all hosts - process synchronization and actions that need to be performed on only one host are managed internally. Args: directory: a path to which to save. *args: additional args to provide to the CheckpointHandler's save method. force: if True, allows overwriting an existing directory. May add overhead due to the need to delete any existing files. custom_metadata: a dictionary of custom metadata to be written to the checkpoint directory via StepMetadata. **kwargs: additional keyword args to provide to the CheckpointHandler's save method. Raises: ValueError if the provided directory already exists. """ checkpoint_start_time = time.time() multihost.sync_global_processes( multihost.unique_barrier_key( 'Checkpointer:save_start', prefix=self._barrier_sync_key_prefix, ), processes=self._active_processes, record_event_name=( '/jax/orbax/write/checkpoint_start_sync_duration_secs' ), ) directory = epath.Path(directory) operation_recorder = event_tracking.OperationRecorder( directory, operation_type=event_tracking.OperationType.SAVE, async_origin=False, primary_host=self._primary_host, ) operation_recorder.record_start() self.synchronize_next_awaitable_signal_operation_id() if directory.exists(): if force: if multihost.is_primary_host(self._primary_host): logging.info('Specified `force`: removing existing directory.') # TODO(b/493110683): Cleanup with refactoring of HNS GCS logic into # StorageBackend. if gcs_utils.is_gcs_path(directory): gcs_utils.rmtree(directory) else: directory.rmtree() # Post-sync handled by create_tmp_directory. else: raise ValueError(f'Destination {directory} already exists.') ckpt_args = construct_checkpoint_args(self._handler, True, *args, **kwargs) tmpdir = self.get_temporary_path(directory) # tmpdir creation also does an initial StepMetadata save. asyncio_utils.run_sync(self.create_temporary_path(tmpdir)) self._handler.save(tmpdir.get(), args=ckpt_args) if multihost.is_primary_host(self._primary_host): # Update StepMetadata after the handler save is complete. (blocking write) self._save_step_metadata(tmpdir.get(), custom_metadata=custom_metadata) multihost.sync_global_processes( multihost.unique_barrier_key( 'Checkpointer:save', prefix=self._barrier_sync_key_prefix, ), processes=self._active_processes, ) operation_recorder.record_blocking_completion( time.time() - checkpoint_start_time ) # Ensure save operation atomicity and record time saved by checkpoint. if multihost.is_primary_host(self._primary_host): # finalize does a final StepMetadata update. self._handler.finalize(tmpdir.get()) asyncio_utils.run_sync( atomicity.on_commit_callback( tmpdir, checkpoint_start_time=checkpoint_start_time, ) ) multihost.sync_global_processes( multihost.unique_barrier_key( 'Checkpointer:finalize', prefix=self._barrier_sync_key_prefix, # suffix=directory.name, ), processes=self._active_processes, ) operation_recorder.record_completion(time.time() - checkpoint_start_time)
[docs] def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any: """See superclass documentation.""" restore_start_time = time.time() directory = epath.Path(directory) operation_recorder = event_tracking.OperationRecorder( directory, operation_type=event_tracking.OperationType.LOAD, async_origin=False, primary_host=self._primary_host, ) operation_recorder.record_start() if not directory.exists(): raise FileNotFoundError(f'Checkpoint at {directory} not found.') if not step_lib.is_path_finalized(directory): raise ValueError(f'Found incomplete checkpoint at {directory}.') logging.info('Restoring checkpoint from %s.', directory) ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs) operation_recorder.record_blocking_completion( time.time() - restore_start_time ) restored = self._restore(directory, args=ckpt_args) multihost.sync_global_processes( multihost.unique_barrier_key( 'Checkpointer:restore', prefix=self._barrier_sync_key_prefix, ), processes=self._active_processes, ) operation_recorder.record_completion(time.time() - restore_start_time) return restored
def _restore( self, directory: epath.PathLike, args: checkpoint_args.CheckpointArgs ) -> Any: return self._handler.restore(directory, args=args)
[docs] def metadata(self, directory: epath.PathLike) -> StepMetadata: """See superclass documentation.""" directory = epath.Path(directory) if isinstance( self._handler, composite_checkpoint_handler.CompositeCheckpointHandler ): # CompositeHandler contains metadata for all items in this step. return self._handler.metadata(directory) serialized_metadata = self._metadata_store.read( checkpoint.step_metadata_file_path(directory) ) if serialized_metadata is None: step_metadata = StepMetadata() else: step_metadata = step_metadata_serialization.deserialize( serialized_metadata ) try: step_metadata.item_metadata = self._handler.metadata(directory) except (FileNotFoundError, NotImplementedError, ValueError, TypeError): logging.warning( 'Failed to get item metadata from directory %s. Either it was not' ' present in the checkpoint, or the handler does not support it.', directory, ) return step_metadata
def _save_step_metadata( self, directory: epath.Path, custom_metadata: dict[str, Any] | None ): """Saves StepMetadata to the checkpoint directory.""" update_dict = { 'custom_metadata': custom_metadata, } if isinstance( self._handler, composite_checkpoint_handler.CompositeCheckpointHandler ): try: # get item_handlers from handler partial_metadata: StepMetadata = ( self._handler.metadata_from_temporary_paths(directory) ) except (FileNotFoundError, NotImplementedError, ValueError, TypeError): logging.warning( 'Failed to get per-item metadata from directory %s. Handler types ' 'will not be saved.', directory, ) else: update_dict['item_handlers'] = partial_metadata.item_handlers else: try: item_handler = self._handler.typestr() except (NotImplementedError, AttributeError): logging.warning( 'Failed to get item handler typestr from directory %s. Backup ' 'handler type will be saved.', directory, ) item_handler = ( f'{self._handler.__module__}.{self._handler.__class__.__qualname__}' ) update_dict['item_handlers'] = item_handler self._metadata_store.update( file_path=checkpoint.step_metadata_file_path(directory), **step_metadata_serialization.serialize_for_update(**update_dict), )
[docs] def close(self): """Closes the underlying CheckpointHandler.""" self._handler.close() self._metadata_store.close()
@property def handler(self) -> checkpoint_handler.CheckpointHandler: return self._handler def __contextmanager__( self, ) -> Iterable[Self]: try: yield self finally: self.close()