# Source code for orbax.checkpoint.experimental.v1._src.context.options

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Global configuration options."""

from __future__ import annotations

import contextvars
import dataclasses
import enum
from typing import Any, Callable, Protocol

from etils import epath
import numpy as np
from orbax.checkpoint import options as v0_options_lib
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.serialization import pathways_types
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types


# Ids (`id(obj)`) of the option objects that currently belong to an active
# `Context`. Stored in a `ContextVar` so the set is scoped to the current
# execution context; the default is an empty, immutable set. This module only
# reads the variable -- it is populated by code outside this view.
FROZEN_IDS: contextvars.ContextVar[frozenset[int]] = contextvars.ContextVar(
    'orbax_frozen_option_ids', default=frozenset()
)


def _raise_if_frozen(instance: Any) -> None:
  """Raises `RuntimeError` if `instance` is registered in `FROZEN_IDS`."""
  if id(instance) in FROZEN_IDS.get():
    raise RuntimeError(
        'Cannot mutate options of an active context. '
        'Configure before entering the `with` block.'
    )


class _ActiveContextGuard:
  """Base class to guard against mutating options attached to an active context.

  In Orbax, `Context` objects and their child option dataclasses (e.g.,
  `ArrayOptions`, `AsyncOptions`) serve two distinct roles: standalone
  configuration templates (buildable/mutable in memory) and active execution
  policies when bound to a context manager (`with ctx:`).

  This base class ensures that once a `Context` enters a `with` block, its
  entire tree of option dataclasses is frozen against mutation. By checking
  `FROZEN_IDS` (a context variable holding the memory ids of all active context
  components for the current execution context), it prevents mid-execution side
  effects while allowing independent, non-active `Context` templates to be
  configured freely.

  Both attribute assignment and attribute deletion are guarded, since either
  one changes the configuration seen by an in-flight operation.
  """

  def __setattr__(self, name: str, value: Any) -> None:
    """Intercepts attribute assignment to enforce context immutability.

    Checks whether the memory id of this option instance is currently
    registered in the context-local `FROZEN_IDS` set. A match indicates that
    this instance belongs to an active `Context`; mutating configuration values
    during active execution is strictly prohibited.

    Args:
      name: The name of the attribute being assigned.
      value: The new value to assign to the attribute.

    Raises:
      RuntimeError: If this option instance belongs to an active context.
    """
    _raise_if_frozen(self)
    # Bypass any further `__setattr__` overrides, exactly as before.
    object.__setattr__(self, name, value)

  def __delattr__(self, name: str) -> None:
    """Intercepts attribute deletion to enforce context immutability.

    Deleting a field is a mutation just like assigning one, so it is rejected
    under the same condition as `__setattr__`.

    Args:
      name: The name of the attribute being deleted.

    Raises:
      RuntimeError: If this option instance belongs to an active context.
      AttributeError: If the instance is not frozen and has no such attribute.
    """
    _raise_if_frozen(self)
    object.__delattr__(self, name)


@dataclasses.dataclass(kw_only=True)
class AsyncOptions(_ActiveContextGuard):
  """Options used to configure async behavior.

  This dataclass defines the configuration parameters for asynchronous
  checkpoint saving operations within the Orbax framework.

  Example:
    Initialize async configuration with a custom timeout and callback::

      from orbax.checkpoint.v1.options import AsyncOptions

      def my_callback():
        print("Async save successfully finalized.")

      options = AsyncOptions(
          timeout_secs=300,
          post_finalization_callback=my_callback,
          create_directories_asynchronously=False
      )

  Attributes:
    timeout_secs: The timeout in seconds for the async save operation.
    post_finalization_callback: A function that is called after the async save
      operation is complete.
    create_directories_asynchronously: If True, creates directories
      asynchronously in the background.
  """

  timeout_secs: int = 1200  # 20 minutes.
  post_finalization_callback: Callable[[], None] | None = None
  create_directories_asynchronously: bool = True

  def v0(self) -> v0_options_lib.AsyncOptions:
    """Converts these options to a v0 `AsyncOptions`, copying every field."""
    return v0_options_lib.AsyncOptions(
        timeout_secs=self.timeout_secs,
        post_finalization_callback=self.post_finalization_callback,
        create_directories_asynchronously=self.create_directories_asynchronously,
    )
@dataclasses.dataclass(kw_only=True)
class MultiprocessingOptions(_ActiveContextGuard):
  """Options used to configure multiprocessing behavior.

  NOTE: These options are generally dangerous to mess with unless you know what
  you're doing.

  This dataclass defines the configuration parameters for multiprocessing
  checkpoint saving operations within the Orbax framework.

  Example:
    Configure a multi-host setup where process 1 is designated as the primary
    host, only a subset of processes are active, and a custom barrier key is
    used to prevent collisions with other concurrent checkpointers::

      from orbax.checkpoint.v1.options import MultiprocessingOptions

      options = MultiprocessingOptions(
          primary_host=1,
          active_processes={1, 2, 3},
          barrier_sync_key_prefix="model_a_sync_"
      )

  Attributes:
    primary_host: The host id of the primary host. Defaults to 0. If it's set
      to None, then all hosts will be considered as primary. It's useful in the
      case that all hosts are only working with local storage.
    active_processes: A set of process indices (corresponding to
      :py:func:`.process_index`) over which
      :py:class:`~.v1.training.Checkpointer` is expected to be called. This
      makes it possible to have a :py:class:`~.v1.training.Checkpointer`
      instance that runs over a subset of processes, rather than all processes
      as it is normally expected to do. If specified, `primary_host` must
      belong to `active_processes`.
    barrier_sync_key_prefix: A string to be prepended to the barrier sync key
      used to synchronize processes. This is useful to avoid collisions with
      other barrier syncs if another :py:class:`~.v1.training.Checkpointer` is
      being used concurrently.
  """

  primary_host: int | None = 0
  active_processes: set[int] | None = None
  barrier_sync_key_prefix: str | None = None

  def v0(self) -> v0_options_lib.MultiprocessingOptions:
    """Converts these options to a v0 `MultiprocessingOptions`."""
    return v0_options_lib.MultiprocessingOptions(
        primary_host=self.primary_host,
        active_processes=self.active_processes,
        barrier_sync_key_prefix=self.barrier_sync_key_prefix,
    )
# pyformat: disable
@dataclasses.dataclass(kw_only=True)
class FileOptions(_ActiveContextGuard):
  """Options used to configure checkpoint directories and files.

  This dataclass defines the configuration parameters for creating and managing
  checkpoint directories and files on disk.

  Example:
    Configure checkpoint files to use strict directory permissions (e.g.,
    read/write/execute for owner, read/execute for group)::

      from orbax.checkpoint.v1.options import FileOptions

      options = FileOptions(
          path_permission_mode=0o750
      )

  Attributes:
    path_permission_mode: Path permission mode for step directories, user
      metadata files. e.g. 0o750. Please check
      https://github.com/google/etils/blob/main/etils/epath/backend.py if your
      path is supported.
    temporary_path_class: A class that is used to create and finalize temporary
      paths, and to ensure atomicity.
    path_class: The implementation of :py:class:`~.v1.path.Path` to use.
      Defaults to `etils.epath.Path`, but may be overridden to some other
      subclass of :py:class:`~.v1.path.Path`.
  """

  path_permission_mode: int | None = None
  temporary_path_class: type[atomicity_types.TemporaryPath] | None = None
  path_class: type[path_types.Path] = epath.Path

  def v0(self) -> v0_options_lib.FileOptions:
    """Converts this :py:class:`~.v1.options.FileOptions` to a v0 :py:class:`~orbax.checkpoint.options.FileOptions`."""
    # Only the permission mode is forwarded; `temporary_path_class` and
    # `path_class` are not passed to the v0 object here.
    return v0_options_lib.FileOptions(
        path_permission_mode=self.path_permission_mode,
    )
# pyformat: enable
@dataclasses.dataclass(kw_only=True)
class PyTreeOptions(_ActiveContextGuard):
  """Options used to configure PyTree saving and loading.

  This dataclass defines the configuration parameters for creating and managing
  PyTree saving and loading on disk.

  Attributes:
    saving: Options for saving PyTrees.
    loading: Options for loading PyTrees.
  """

  @dataclasses.dataclass(kw_only=True)
  class Saving(_ActiveContextGuard):
    """Options for saving PyTrees.

    Attributes:
      pytree_metadata_options: Options for managing PyTree metadata.
    """

    pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
        dataclasses.field(default_factory=tree_metadata.PyTreeMetadataOptions)
    )

  @dataclasses.dataclass(kw_only=True)
  class Loading(_ActiveContextGuard):
    """Options for loading PyTrees.

    Attributes:
      partial_load: If True, only restore the parameters that are specified in
        the abstract PyTree.
    """

    partial_load: bool = False

  # Each instance gets its own nested option objects via `default_factory`.
  saving: Saving = dataclasses.field(default_factory=Saving)
  loading: Loading = dataclasses.field(default_factory=Loading)
@dataclasses.dataclass(kw_only=True)
class ArrayOptions(_ActiveContextGuard):
  """Options used to configure array saving and loading.

  This dataclass defines the high-level configuration parameters for array
  checkpointing operations within the Orbax framework. Because it is defined as
  a keyword-only dataclass, every field must be passed by keyword when an
  instance is constructed; fields can also be adjusted afterwards by attribute
  assignment, as the examples below do.

  Example:
    To configure array options with specific saving formats and loading
    behaviors we can do so like this::

      from orbax.checkpoint import v1 as ocp

      ctx = ocp.Context()
      ctx.array.saving.use_zarr3 = True
      ctx.array.saving.use_compression = False
      ctx.array.loading.enable_padding_and_truncation = True

    To save certain leaves in float16, while others in float32, we can use
    `scoped_storage_options_creator` like so::

      import jax
      import jax.numpy as jnp
      from orbax.checkpoint import v1 as ocp

      def create_opts_fn(keypath, value):
        if 'small' in jax.tree_util.keystr(keypath):
          return ocp.options.ArrayOptions.Saving.StorageOptions(
              dtype=jnp.float16
          )
        return None  # Fall back to global `storage_options`

      ctx = ocp.Context()
      ctx.array.saving.storage_options.dtype = jnp.float32
      ctx.array.saving.scoped_storage_options_creator = create_opts_fn

  Attributes:
    saving: Options for saving arrays.
    loading: Options for loading arrays.
  """

  @dataclasses.dataclass(kw_only=True)
  class Saving(_ActiveContextGuard):
    """Options for saving arrays.

    Attributes:
      storage_options: Options used to customize array storage behavior for all
        leaves at a global level. See below.
      use_ocdbt: Enables OCDBT format.
      use_zarr3: If True, use Zarr3 format.
      use_compression: If True, use ZSTD compression.
      ocdbt_target_data_file_size: Specifies the target size (in bytes) of each
        OCDBT data file. It only applies when OCDBT is enabled and Zarr3 must
        be turned on. If left unspecified, default size is 2GB. A value of 0
        indicates no maximum file size limit. For best results, ensure
        chunk_byte_size is smaller than this value. For more details, refer to
        https://google.github.io/tensorstore/kvstore/ocdbt/index.html#json-kvstore/ocdbt.target_data_file_size
      enable_pinned_host_transfer: Whether to use pinned_host memory for the
        transfer from device to host memory. Passing None will enable
        pinned_host memory depending on the platform used (currently only
        enables it for the GPU backend).
      enable_post_merge_validation: If True, enables validation of the
        parameters after the finalize step.
      use_replica_parallel: Whether to parallelize saving across replicas.
      min_slice_bytes_for_replica_parallel: Minimum number of bytes per replica
        slice. Only uses replica-parallel when the amount of data written per
        replica is greater than or equal to this number.
      max_replicas_for_replica_parallel: Maximum number of replicas over which
        saving will be parallelized if use_replica_parallel is True.
      enable_replica_parallel_separate_folder: Whether to save replica data in
        separate folders.
      enable_write_sharding_file: Whether to write sharding file, defaults to
        True.
      array_metadata_store: Store to manage per host ArrayMetadata. To disable
        ArrayMetadata persistence, set it to None.
      scoped_storage_options_creator: A function that, when dealing with
        PyTrees, is applied to every leaf. If it returns an
        :py:class:`ArrayOptions.Saving.StorageOptions`, its fields take
        precedence when merging if they are set to non-None or non-default
        values with respect to `storage_options`. If it returns `None`,
        `storage_options` is used as a default for all fields. It is called
        similar to:
        `jax.tree.map_with_path(scoped_storage_options_creator, pytree_to_save)`.
    """

    class ScopedStorageOptionsCreator(Protocol):
      """Signature of a per-leaf `StorageOptions` chooser.

      Implementations receive a leaf's key path and value and return either
      leaf-specific `StorageOptions` or `None`.
      """

      def __call__(
          self, key: tree_types.PyTreeKeyPath, value: Any
      ) -> ArrayOptions.Saving.StorageOptions | None:
        ...

    @dataclasses.dataclass(kw_only=True)
    class StorageOptions(_ActiveContextGuard):
      """Options used to customize array storage behavior for individual leaves.

      Attributes:
        dtype: If provided, casts the parameter to the given dtype before
          saving. Note that the parameter must be compatible with the given
          type (e.g., `jnp.bfloat16` is not compatible with `np.ndarray`).
        chunk_byte_size: This is an experimental feature that automatically
          chooses the largest possible chunk shape while keeping the chunk byte
          size less than or equal to the specified `chunk_byte_size`. Both
          `write_chunk_shape` and `read_chunk_shape` are automatically set to
          the chosen shape. This uses a greedy algorithm that prioritizes
          splitting the largest dimensions first.
        shard_axes: An optional tuple of axes that should be prioritized when
          sharding an array for storage. If empty, the storage sharding
          implementation will prioritize axes which are already sharded.
      """

      dtype: np.typing.DTypeLike | None = None
      chunk_byte_size: int | None = None
      shard_axes: tuple[int, ...] | None = None

    storage_options: StorageOptions = dataclasses.field(
        default_factory=StorageOptions
    )
    use_ocdbt: bool = True
    use_zarr3: bool = True
    use_compression: bool = True
    ocdbt_target_data_file_size: int | None = None
    enable_pinned_host_transfer: bool | None = None
    enable_post_merge_validation: bool = True
    use_replica_parallel: bool | None = None
    min_slice_bytes_for_replica_parallel: int | None = None
    max_replicas_for_replica_parallel: int | None = None
    enable_replica_parallel_separate_folder: bool = False
    enable_write_sharding_file: bool = True
    # NOTE: this default `Store` is constructed once, when the class body is
    # evaluated, so every `Saving` that does not override the field refers to
    # the same `Store` instance.
    array_metadata_store: array_metadata_store_lib.Store | None = (
        array_metadata_store_lib.Store()
    )
    scoped_storage_options_creator: ScopedStorageOptionsCreator | None = None

  @dataclasses.dataclass(kw_only=True)
  class Loading(_ActiveContextGuard):
    """Options for loading arrays.

    Attributes:
      enable_padding_and_truncation: If True, restoration allows silent
        truncating/padding of arrays if the stored array shape does not match
        the target shape. Otherwise, raises an error.
      raise_array_data_missing_error: If True, raises an error if array data is
        missing. Otherwise allows returning zeros from an array range that was
        not necessarily written to.
      use_load_and_broadcast: Whether to use load-and-broadcast for
        multi-replica loading. This is useful when the model has multiple
        replicas across different sets of devices (commonly across multiple TPU
        slices, but also applies to data-parallel model replicas within a
        single slice). Array shardings must be structured so that the mesh has
        a dimension on which all model weights are replicated. The checkpoint
        will then be loaded only on the hosts and devices taken from replica
        `primary_replica_id` along the `replica_axis_index` dimension. It will
        then be broadcast to all other replicas.
      load_and_broadcast_options: Settings for load-and-broadcast; see
        `LoadAndBroadcastOptions`.
    """

    @dataclasses.dataclass(kw_only=True)
    class LoadAndBroadcastOptions(_ActiveContextGuard):
      """Used to configure load-and-broadcast behavior in multi-replica loading.

      Attributes:
        replica_axis_index: Defines the axis of the global mesh along which
          replicas are defined. E.g. all devices in
          global_mesh.devices[replica_axis_index] are part of the same replica.
        primary_replica_id: The id of the replica that is used to load and
          broadcast the checkpoint.
        broadcast_memory_limit_bytes: Specifies the memory size (in bytes) used
          for broadcasting data.
        broadcast_memory_scaling_factor: Specifies the fraction of available
          memory to use for broadcasting data.
      """

      replica_axis_index: int | None = 0
      primary_replica_id: int | None = 0
      broadcast_memory_limit_bytes: int | None = None
      broadcast_memory_scaling_factor: float | None = 0.75

    enable_padding_and_truncation: bool = False
    raise_array_data_missing_error: bool = True
    use_load_and_broadcast: bool = False
    load_and_broadcast_options: LoadAndBroadcastOptions = dataclasses.field(
        default_factory=LoadAndBroadcastOptions
    )

  saving: Saving = dataclasses.field(default_factory=Saving)
  loading: Loading = dataclasses.field(default_factory=Loading)
@dataclasses.dataclass(kw_only=True)
class CheckpointablesOptions(_ActiveContextGuard):
  """Options used to configure `checkpointables` save/load behavior.

  Primarily intended for registering custom
  :py:class:`.CheckpointableHandler` classes via direct registry binding.

  For example::

    registry = ocp.handlers.local_registry()
    registry.add(FooHandler, checkpointable_name=None)
    registry.add(BarHandler, checkpointable_name='bar')

    context = ocp.Context()
    context.checkpointables.registry = registry
    with context:
      ocp.save_checkpointables(directory, dict(foo=Foo(...), bar=Bar(...)))

  In this example, `FooHandler` is registered generically, which means that any
  checkpointable that is handleable by `FooHandler` can be saved/loaded (a
  `Foo` object in this case). In contrast, `BarHandler` is explicitly tied to
  the name `bar`, which means that only a checkpointable that is both
  handleable by `BarHandler` and has the name `bar` can be handled by this
  `BarHandler`.

  Recall that a global registry also exists, containing core handlers like
  :py:class:`.PyTreeHandler` and :py:class:`.JsonHandler`. Use
  `ocp.handlers.register_handler` to register a handler globally.

  Note that registration order matters. If multiple handlers are capable of
  handling an object, the handler registered last (most recently) takes
  precedence. For example, if saving a dict containing only strings, both
  :py:class:`.JsonHandler` and :py:class:`.PyTreeHandler` are capable of
  handling this object, but :py:class:`.PyTreeHandler` will be selected because
  it is registered after :py:class:`.JsonHandler` in the global registry.
  Similarly, any custom handlers added by the user will be registered after the
  global handlers, ensuring they take precedence.

  Attributes:
    registry: A :py:class:`.CheckpointableHandlerRegistry` that is used to
      resolve :py:class:`.CheckpointableHandler` classes for each provided
      `checkpointable` during saving and loading.
  """

  # The default is built per instance: a read-only wrapper around a fresh local
  # registry created with `include_global_registry=True`.
  registry: registration.CheckpointableHandlerRegistry = dataclasses.field(
      default_factory=lambda: registration.ReadOnlyCheckpointableHandlerRegistry(
          registration.local_registry(include_global_registry=True)
      )
  )
@dataclasses.dataclass(kw_only=True)
class PathwaysOptions(_ActiveContextGuard):
  """Options used to configure Pathways saving and loading.

  Attributes:
    checkpointing_impl: The implementation mode to use for Pathways
      checkpointing.
  """

  checkpointing_impl: pathways_types.CheckpointingImpl | None = None
@dataclasses.dataclass(kw_only=True)
class DeletionOptions(_ActiveContextGuard):
  """Options used to configure checkpoint deletion behavior.

  Attributes:
    gcs_deletion_options: Deletion options specific to GCS.
  """

  @dataclasses.dataclass(kw_only=True)
  class GcsDeletionOptions(_ActiveContextGuard):
    """Deletion options specific to GCS.

    Attributes:
      todelete_full_path: A path relative to the bucket root for
        "soft-deleting" checkpoints on Google Cloud Storage (GCS). Instead of
        being permanently removed, checkpoints are moved to this new location
        within the same bucket. This is useful if direct deletion on GCS is
        time-consuming, as it allows an external component to manage the actual
        removal. This option gathers all "deleted" items in a centralized path
        at the bucket level for future cleanup. For instance, if a checkpoint
        is in gs://my-bucket/experiments/run1/, providing the value 'trash'
        will move a deleted step to gs://my-bucket/trash/<step_id>. Useful when
        direct deletion is time consuming. It gathers all deleted items in a
        centralized path for future cleanup.
    """

    todelete_full_path: str | None = None

  gcs_deletion_options: GcsDeletionOptions = dataclasses.field(
      default_factory=GcsDeletionOptions
  )
@dataclasses.dataclass(kw_only=True)
class MemoryOptions(_ActiveContextGuard):
  """Options for configuring memory limits for save / load.

  Can help to reduce the possibility of OOMs when large checkpoints are saved
  or loaded.

  Attributes:
    write_concurrent_bytes: Max concurrent bytes that are allowed for writing
      (per host). This applies only to *additional* memory used beyond the
      baseline space required to hold the model weights in memory. E.g. if
      model weights require 100 GB of memory in RAM, a setting below 100 GB
      will not prevent 100 GB from being allocated, but will prevent additional
      memory from being used during the write to disk. `None` indicates no
      limit.
    read_concurrent_bytes: Max concurrent bytes that are allowed for reading.
      `None` indicates no limit.
    transfer_concurrent_bytes: Max concurrent bytes allowed for device-to-host
      transfers (per host). When the limit is reached, arrays must be finished
      writing to the checkpoint before a new array can start being transferred.
      E.g. if the limit is 100 GB, and model weights would require 120 GB of
      RAM, then the remaining 20 GB will not be able to start D2H transfer
      until the first 100 GB have finished being written to the checkpoint.
      Note that asynchronous saves may not be truly asynchronous with this
      option enabled, as we have to block on some array writes before beginning
      others. Also see `is_prioritized_key_fn`. `None` indicates no limit.
    is_prioritized_key_fn: A function that accepts a PyTree keypath (obtained
      using jax.tree.map_with_path) that should be scheduled for D2H transfer
      before other keys. The transfer is scheduled before returning to the
      caller, so the values will never be corrupted by a concurrent update.
      Keys that are not prioritized will not be scheduled for transfer until
      all prioritized keys have been fully written to the checkpoint. This
      means that these values may be altered if the values are updated
      concurrently. Callers should take care to call `wait_until_finished`
      before updating array values (e.g. `apply_gradients`) if some keys are
      not prioritized. Note that any "prioritized" keys are assumed to be
      lightweight, and `transfer_concurrent_bytes` will be ignored for them.
  """

  write_concurrent_bytes: int | None = None
  read_concurrent_bytes: int | None = None
  transfer_concurrent_bytes: int | None = None
  is_prioritized_key_fn: serialization_types.IsPrioritizedKeyFn | None = None
@dataclasses.dataclass(kw_only=True)
class SafetensorsOptions(_ActiveContextGuard):
  """Options for configuring Safetensors loading.

  Attributes:
    ignore_load_sharding: If True, skips sharding of the tensors across
      hosts/devices during load. Whole tensors will be present on each host,
      allowing for efficient conversion.
  """

  ignore_load_sharding: bool = False
class CheckpointLayout(enum.Enum):
  """The layout of the checkpoint.

  By default, Orbax saves and loads checkpoints with its own layout. However,
  support for other layouts is available, as a means of supporting
  interoperability with other checkpointing libraries.

  Currently supported layouts are:
    ORBAX: Orbax's own layout.
    SAFETENSORS: https://huggingface.co/docs/safetensors/en/index
  """

  ORBAX = 'Orbax'
  SAFETENSORS = 'SafeTensors'