# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Global configuration options."""
from __future__ import annotations
import contextvars
import dataclasses
import enum
from typing import Any, Callable, Protocol
from etils import epath
import numpy as np
from orbax.checkpoint import options as v0_options_lib
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.serialization import pathways_types
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
# Context-local set of `id()` values of option objects that must not be
# mutated right now. `_ActiveContextGuard.__setattr__` reads it on every
# attribute assignment. The default is an empty frozenset, i.e. nothing is
# frozen until some code sets the variable.
FROZEN_IDS: contextvars.ContextVar[frozenset[int]] = contextvars.ContextVar(
    'orbax_frozen_option_ids', default=frozenset()
)
class _ActiveContextGuard:
  """Base class guarding options of an active context against mutation.

  In Orbax, `Context` objects and their child option dataclasses (e.g.,
  `ArrayOptions`, `AsyncOptions`) serve two distinct roles: standalone
  configuration templates (buildable/mutable in memory) and active execution
  policies when bound to a context manager (`with ctx:`).

  This base class ensures that once a `Context` enters a `with` block, its
  entire tree of option dataclasses is frozen against mutation. Attribute
  assignment and attribute deletion are both checked against `FROZEN_IDS`, a
  `contextvars.ContextVar` holding the `id()`s of the option objects that are
  frozen in the current context. Because it is a context variable, the frozen
  set is scoped to the current thread / coroutine context rather than being
  process-global. Instances whose id is not registered stay freely mutable, so
  independent, non-active `Context` templates can still be configured.
  """

  def _raise_if_frozen(self) -> None:
    """Raises if this instance is registered in `FROZEN_IDS`.

    Raises:
      RuntimeError: If this option instance belongs to an active context.
    """
    if id(self) in FROZEN_IDS.get():
      raise RuntimeError(
          'Cannot mutate options of an active context. '
          'Configure before entering the `with` block.'
      )

  def __setattr__(self, name: str, value: Any) -> None:
    """Intercepts attribute assignment to enforce context immutability.

    Args:
      name: The name of the attribute being assigned.
      value: The new value to assign to the attribute.

    Raises:
      RuntimeError: If this option instance belongs to an active context.
    """
    self._raise_if_frozen()
    object.__setattr__(self, name, value)

  def __delattr__(self, name: str) -> None:
    """Intercepts attribute deletion to enforce context immutability.

    Deleting an instance attribute is also a mutation: for a dataclass field
    with a class-level default it silently reverts the option to that default.
    It is therefore rejected under the same condition as assignment.

    Args:
      name: The name of the attribute being deleted.

    Raises:
      RuntimeError: If this option instance belongs to an active context.
    """
    self._raise_if_frozen()
    object.__delattr__(self, name)
@dataclasses.dataclass(kw_only=True)
class AsyncOptions(_ActiveContextGuard):
  """Settings that control asynchronous checkpoint saving.

  Example:
    Build an async configuration with a custom timeout and callback::

      from orbax.checkpoint.v1.options import AsyncOptions

      def my_callback():
        print("Async save successfully finalized.")

      options = AsyncOptions(
          timeout_secs=300,
          post_finalization_callback=my_callback,
          create_directories_asynchronously=False
      )

  Attributes:
    timeout_secs: Timeout, in seconds, for the async save operation.
    post_finalization_callback: Function invoked once the async save operation
      is complete.
    create_directories_asynchronously: If True, directories are created
      asynchronously in the background.
  """

  timeout_secs: int = 1200  # 20 minutes.
  post_finalization_callback: Callable[[], None] | None = None
  create_directories_asynchronously: bool = True

  def v0(self) -> v0_options_lib.AsyncOptions:
    """Returns a v0 `AsyncOptions` carrying the same three field values."""
    legacy_kwargs = dict(
        timeout_secs=self.timeout_secs,
        post_finalization_callback=self.post_finalization_callback,
        create_directories_asynchronously=(
            self.create_directories_asynchronously
        ),
    )
    return v0_options_lib.AsyncOptions(**legacy_kwargs)
@dataclasses.dataclass(kw_only=True)
class MultiprocessingOptions(_ActiveContextGuard):
  """Settings that control multiprocess checkpointing behavior.

  NOTE: These options are generally dangerous to mess with unless you know what
  you're doing.

  Example:
    Configure a multi-host setup where process 1 is designated as the primary
    host, only a subset of processes are active, and a custom barrier key is
    used to prevent collisions with other concurrent checkpointers::

      from orbax.checkpoint.v1.options import MultiprocessingOptions

      options = MultiprocessingOptions(
          primary_host=1,
          active_processes={1, 2, 3},
          barrier_sync_key_prefix="model_a_sync_"
      )

  Attributes:
    primary_host: Host id of the primary host. Defaults to 0. When set to None,
      all hosts are considered primary, which is useful when every host works
      only with local storage.
    active_processes: A set of process indices (corresponding to
      :py:func:`.process_index`) over which
      :py:class:`~.v1.training.Checkpointer` is expected to be called. This
      makes it possible to have a :py:class:`~.v1.training.Checkpointer`
      instance that runs over a subset of processes, rather than all processes
      as it is normally expected to do. If specified, `primary_host` must
      belong to `active_processes`.
    barrier_sync_key_prefix: A string prepended to the barrier sync key used to
      synchronize processes. Useful to avoid collisions with other barrier
      syncs if another :py:class:`~.v1.training.Checkpointer` is being used
      concurrently.
  """

  primary_host: int | None = 0
  active_processes: set[int] | None = None
  barrier_sync_key_prefix: str | None = None

  def v0(self) -> v0_options_lib.MultiprocessingOptions:
    """Returns a v0 `MultiprocessingOptions` with the same field values."""
    legacy_kwargs = dict(
        primary_host=self.primary_host,
        active_processes=self.active_processes,
        barrier_sync_key_prefix=self.barrier_sync_key_prefix,
    )
    return v0_options_lib.MultiprocessingOptions(**legacy_kwargs)
# pyformat: disable
@dataclasses.dataclass(kw_only=True)
class FileOptions(_ActiveContextGuard):
  """Settings for creating and managing checkpoint directories and files.

  Example:
    Configure checkpoint files to use strict directory permissions (e.g.,
    read/write/execute for owner, read/execute for group)::

      from orbax.checkpoint.v1.options import FileOptions

      options = FileOptions(
          path_permission_mode=0o750
      )

  Attributes:
    path_permission_mode:
      Path permission mode for step directories, user metadata files. e.g.
      0o750. Please check
      https://github.com/google/etils/blob/main/etils/epath/backend.py if your
      path is supported.
    temporary_path_class:
      A class that is used to create and finalize temporary paths, and to
      ensure atomicity.
    path_class:
      The implementation of :py:class:`~.v1.path.Path` to use. Defaults to
      `etils.epath.Path`, but may be overridden to some other subclass of
      :py:class:`~.v1.path.Path`.
  """

  path_permission_mode: int | None = None
  temporary_path_class: type[atomicity_types.TemporaryPath] | None = None
  path_class: type[path_types.Path] = epath.Path

  def v0(self) -> v0_options_lib.FileOptions:
    """Builds a v0 :py:class:`~orbax.checkpoint.options.FileOptions`.

    Only `path_permission_mode` is forwarded to the v0 object.
    """
    permission_mode = self.path_permission_mode
    return v0_options_lib.FileOptions(path_permission_mode=permission_mode)
# pyformat: enable
@dataclasses.dataclass(kw_only=True)
class PyTreeOptions(_ActiveContextGuard):
  """Settings for saving and loading PyTrees.

  The options are split into a `saving` group and a `loading` group, each
  created with its own defaults when not supplied.

  Attributes:
    saving: Options for saving PyTrees.
    loading: Options for loading PyTrees.
  """

  @dataclasses.dataclass(kw_only=True)
  class Saving(_ActiveContextGuard):
    """Options for saving PyTrees.

    Attributes:
      pytree_metadata_options: Options for managing PyTree metadata.
    """

    pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = dataclasses.field(
        default_factory=tree_metadata.PyTreeMetadataOptions
    )

  @dataclasses.dataclass(kw_only=True)
  class Loading(_ActiveContextGuard):
    """Options for loading PyTrees.

    Attributes:
      partial_load: If True, only restore the parameters that are specified in
        the abstract PyTree.
    """

    partial_load: bool = False

  saving: Saving = dataclasses.field(default_factory=Saving)
  loading: Loading = dataclasses.field(default_factory=Loading)
@dataclasses.dataclass(kw_only=True)
class ArrayOptions(_ActiveContextGuard):
  """Options used to configure array saving and loading.

  Groups the high-level settings for array checkpointing into a `saving` and a
  `loading` section. Both sections are keyword-only dataclasses whose fields
  can be assigned in place, as the examples below show.

  Example:
    To configure array options with specific saving formats and loading
    behaviors::

      from orbax.checkpoint import v1 as ocp

      ctx = ocp.Context()
      ctx.array.saving.use_zarr3 = True
      ctx.array.saving.use_compression = False
      ctx.array.loading.enable_padding_and_truncation = True

    To save certain leaves in float16, while others in float32, use
    `scoped_storage_options_creator`::

      import jax
      import jax.numpy as jnp
      from orbax.checkpoint import v1 as ocp
      from orbax.checkpoint.v1 import options as ocp_options

      def create_opts_fn(keypath, value):
        if 'small' in jax.tree_util.keystr(keypath):
          return ocp_options.ArrayOptions.Saving.StorageOptions(
              dtype=jnp.float16
          )
        return None  # Fall back to global `storage_options`

      ctx = ocp.Context()
      ctx.array.saving.storage_options.dtype = jnp.float32
      ctx.array.saving.scoped_storage_options_creator = create_opts_fn

  Attributes:
    saving: Options for saving arrays.
    loading: Options for loading arrays.
  """

  @dataclasses.dataclass(kw_only=True)
  class Saving(_ActiveContextGuard):
    """Options for saving arrays.

    Attributes:
      storage_options: Options used to customize array storage behavior for
        all leaves at a global level. See `StorageOptions`.
      use_ocdbt: If True, use the OCDBT format.
      use_zarr3: If True, use Zarr3 format.
      use_compression: If True, use ZSTD compression.
      ocdbt_target_data_file_size: Specifies the target size (in bytes) of each
        OCDBT data file. It only applies when OCDBT is enabled and Zarr3 must
        be turned on. If left unspecified, default size is 2GB. A value of 0
        indicates no maximum file size limit. For best results, ensure
        chunk_byte_size is smaller than this value. For more details, refer to
        https://google.github.io/tensorstore/kvstore/ocdbt/index.html#json-kvstore/ocdbt.target_data_file_size
      enable_pinned_host_transfer: Whether to use pinned_host memory for the
        transfer from device to host memory. Passing None will enable
        pinned_host memory depending on the platform used (currently only
        enables it for the GPU backend).
      enable_post_merge_validation: If True, enables validation of the
        parameters after the finalize step.
      use_replica_parallel: Whether to parallelize saving across replicas.
      min_slice_bytes_for_replica_parallel: Minimum number of bytes per replica
        slice. Only uses replica-parallel when the amount of data written per
        replica is greater than or equal to this number.
      max_replicas_for_replica_parallel: Maximum number of replicas over which
        saving will be parallelized if use_replica_parallel is True.
      enable_replica_parallel_separate_folder: Whether to save replica data in
        separate folders.
      enable_write_sharding_file: Whether to write the sharding file. Defaults
        to True.
      array_metadata_store: Store to manage per host ArrayMetadata. To disable
        ArrayMetadata persistence, set it to None.
      scoped_storage_options_creator: A function that, when dealing with
        PyTrees, is applied to every leaf. If it returns an
        :py:class:`ArrayOptions.Saving.StorageOptions`, its fields take
        precedence when merging if they are set to non-None or non-default
        values with respect to `storage_options`. If it returns `None`,
        `storage_options` is used as a default for all fields. It is called
        similar to: `jax.tree.map_with_path(scoped_storage_options_creator,
        pytree_to_save)`.
    """

    class ScopedStorageOptionsCreator(Protocol):
      """Callable that maps a PyTree leaf to optional per-leaf storage options."""

      def __call__(
          self, key: tree_types.PyTreeKeyPath, value: Any
      ) -> ArrayOptions.Saving.StorageOptions | None:
        ...

    @dataclasses.dataclass(kw_only=True)
    class StorageOptions(_ActiveContextGuard):
      """Options used to customize array storage behavior for individual leaves.

      Attributes:
        dtype:
          If provided, casts the parameter to the given dtype before saving.
          Note that the parameter must be compatible with the given type (e.g.,
          `jnp.bfloat16` is not compatible with `np.ndarray`).
        chunk_byte_size:
          This is an experimental feature that automatically chooses the
          largest possible chunk shape while keeping the chunk byte size less
          than or equal to the specified `chunk_byte_size`. Both
          `write_chunk_shape` and `read_chunk_shape` are automatically set to
          the chosen shape. This uses a greedy algorithm that prioritizes
          splitting the largest dimensions first.
        shard_axes:
          An optional tuple of axes that should be prioritized when sharding an
          array for storage. If empty, the storage sharding implementation will
          prioritize axes which are already sharded.
      """

      dtype: np.typing.DTypeLike | None = None
      chunk_byte_size: int | None = None
      shard_axes: tuple[int, ...] | None = None

    storage_options: StorageOptions = dataclasses.field(
        default_factory=StorageOptions
    )
    use_ocdbt: bool = True
    use_zarr3: bool = True
    use_compression: bool = True
    ocdbt_target_data_file_size: int | None = None
    enable_pinned_host_transfer: bool | None = None
    enable_post_merge_validation: bool = True
    use_replica_parallel: bool | None = None
    min_slice_bytes_for_replica_parallel: int | None = None
    max_replicas_for_replica_parallel: int | None = None
    enable_replica_parallel_separate_folder: bool = False
    enable_write_sharding_file: bool = True
    # NOTE: this default is a single `Store` object created once, when the
    # class body runs, so every instance that keeps the default refers to that
    # same object.
    array_metadata_store: array_metadata_store_lib.Store | None = array_metadata_store_lib.Store()
    scoped_storage_options_creator: ScopedStorageOptionsCreator | None = None

  @dataclasses.dataclass(kw_only=True)
  class Loading(_ActiveContextGuard):
    """Options for loading arrays.

    Attributes:
      enable_padding_and_truncation:
        If True, restoration allows silent truncating/padding of arrays if the
        stored array shape does not match the target shape. Otherwise, raises
        an error.
      raise_array_data_missing_error:
        If True, raises an error if array data is missing. Otherwise allows
        returning zeros from an array range that was not necessarily written
        to.
      use_load_and_broadcast: Whether to use load-and-broadcast for
        multi-replica loading. This is useful when the model has multiple
        replicas across different sets of devices (commonly across multiple
        TPU slices, but also applies to data-parallel model replicas within a
        single slice). Array shardings must be structured so that the mesh has
        a dimension on which all model weights are replicated. The checkpoint
        will then be loaded only on the hosts and devices taken from replica
        `primary_replica_id` along the `replica_axis_index` dimension. It will
        then be broadcast to all other replicas.
      load_and_broadcast_options: Settings for load-and-broadcast behavior; see
        `LoadAndBroadcastOptions`.
    """

    @dataclasses.dataclass(kw_only=True)
    class LoadAndBroadcastOptions(_ActiveContextGuard):
      """Used to configure load-and-broadcast behavior in multi-replica loading.

      Attributes:
        replica_axis_index: Defines the axis of the global mesh along which
          replicas are defined. E.g. all devices in
          global_mesh.devices[replica_axis_index] are part of the same replica.
        primary_replica_id: The id of the replica that is used to load and
          broadcast the checkpoint.
        broadcast_memory_limit_bytes: Specifies the memory size (in bytes) used
          for broadcasting data.
        broadcast_memory_scaling_factor: Specifies the fraction of available
          memory to use for broadcasting data.
      """

      replica_axis_index: int | None = 0
      primary_replica_id: int | None = 0
      broadcast_memory_limit_bytes: int | None = None
      broadcast_memory_scaling_factor: float | None = 0.75

    enable_padding_and_truncation: bool = False
    raise_array_data_missing_error: bool = True
    use_load_and_broadcast: bool = False
    load_and_broadcast_options: LoadAndBroadcastOptions = dataclasses.field(
        default_factory=LoadAndBroadcastOptions
    )

  saving: Saving = dataclasses.field(default_factory=Saving)
  loading: Loading = dataclasses.field(default_factory=Loading)
@dataclasses.dataclass(kw_only=True)
class CheckpointablesOptions(_ActiveContextGuard):
  """Options used to configure `checkpointables` save/load behavior.

  Primarily intended for registering custom :py:class:`.CheckpointableHandler`
  classes via direct registry binding.

  For example::

    registry = ocp.handlers.local_registry()
    registry.add(FooHandler, checkpointable_name=None)
    registry.add(BarHandler, checkpointable_name='bar')

    context = ocp.Context()
    context.checkpointables.registry = registry
    with context:
      ocp.save_checkpointables(directory, dict(foo=Foo(...), bar=Bar(...)))

  In this example, `FooHandler` is registered generically, which means that any
  checkpointable that is handleable by `FooHandler` can be saved/loaded (a
  `Foo` object in this case). In contrast, `BarHandler` is explicitly tied to
  the name `bar`, which means that only a checkpointable that is both
  handleable by `BarHandler` and has the name `bar` can be handled by this
  `BarHandler`.

  Recall that a global registry also exists, containing core handlers like
  :py:class:`.PyTreeHandler` and :py:class:`.JsonHandler`. Use
  `ocp.handlers.register_handler` to register a handler globally.

  Note that registration order matters. If multiple handlers are capable of
  handling an object, the handler registered last (most recently) takes
  precedence. For example, if saving a dict containing only strings, both
  :py:class:`.JsonHandler` and :py:class:`.PyTreeHandler` are capable of
  handling this object, but :py:class:`.PyTreeHandler` will be selected because
  it is registered after :py:class:`.JsonHandler` in the global registry.
  Similarly, any custom handlers added by the user will be registered after the
  global handlers, ensuring they take precedence.

  Attributes:
    registry: A :py:class:`.CheckpointableHandlerRegistry` that is used to
      resolve :py:class:`.CheckpointableHandler` classes for each provided
      `checkpointable` during saving and loading. By default, a read-only
      wrapper around a local registry built with
      `include_global_registry=True`.
  """

  registry: registration.CheckpointableHandlerRegistry = dataclasses.field(
      default_factory=lambda: (
          registration.ReadOnlyCheckpointableHandlerRegistry(
              registration.local_registry(include_global_registry=True)
          )
      )
  )
@dataclasses.dataclass(kw_only=True)
class PathwaysOptions(_ActiveContextGuard):
  """Settings for saving and loading with Pathways.

  Attributes:
    checkpointing_impl: The implementation mode to use for Pathways
      checkpointing. Defaults to None.
  """

  checkpointing_impl: pathways_types.CheckpointingImpl | None = None
@dataclasses.dataclass(kw_only=True)
class DeletionOptions(_ActiveContextGuard):
  """Settings that control how checkpoints are deleted.

  Attributes:
    gcs_deletion_options: Deletion options specific to GCS.
  """

  @dataclasses.dataclass(kw_only=True)
  class GcsDeletionOptions(_ActiveContextGuard):
    """Deletion options specific to GCS.

    Attributes:
      todelete_full_path: A path relative to the bucket root for
        "soft-deleting" checkpoints on Google Cloud Storage (GCS). Instead of
        being permanently removed, checkpoints are moved to this new location
        within the same bucket. This is useful if direct deletion on GCS is
        time-consuming, as it allows an external component to manage the actual
        removal. This option gathers all "deleted" items in a centralized path
        at the bucket level for future cleanup. For instance, if a checkpoint
        is in gs://my-bucket/experiments/run1/, providing the value 'trash'
        will move a deleted step to gs://my-bucket/trash/<step_id>.
    """

    todelete_full_path: str | None = None

  gcs_deletion_options: GcsDeletionOptions = dataclasses.field(
      default_factory=GcsDeletionOptions
  )
@dataclasses.dataclass(kw_only=True)
class MemoryOptions(_ActiveContextGuard):
  """Memory limits applied while saving or loading.

  Can help to reduce the possibility of OOM's when large checkpoints are
  saved or loaded.

  Attributes:
    write_concurrent_bytes: Max concurrent bytes that are allowed for writing
      (per host). This applies only to *additional* memory used beyond the
      baseline space required to hold the model weights in memory. E.g. if
      model weights require 100 GB of memory in RAM, a setting below 100 GB
      will not prevent 100 GB from being allocated, but will prevent
      additional memory from being used during the write to disk.
      `None` indicates no limit.
    read_concurrent_bytes: Max concurrent bytes that are allowed for reading.
      `None` indicates no limit.
    transfer_concurrent_bytes: Max concurrent bytes allowed for device-to-host
      transfers (per host). When the limit is reached, arrays must be finished
      writing to the checkpoint before a new array can start being
      transferred. E.g. if the limit is 100 GB, and model weights would
      require 120 GB of RAM, then the remaining 20 GB will not be able to
      start D2H transfer until the first 100 GB have finished being written to
      the checkpoint. Note that asynchronous saves may not be truly
      asynchronous with this option enabled, as we have to block on some array
      writes before beginning others. Also see `is_prioritized_key_fn`.
      `None` indicates no limit.
    is_prioritized_key_fn: A function that accepts a PyTree keypath (obtained
      using jax.tree.map_with_path) that should be scheduled for D2H transfer
      before other keys. The transfer is scheduled before returning to the
      caller, so the values will never be corrupted by a concurrent update.
      Keys that are not prioritized will not be scheduled for transfer until
      all prioritized keys have been fully written to the checkpoint. This
      means that these values may be altered if the values are updated
      concurrently. Callers should take care to call `wait_until_finished`
      before updating array values (e.g. `apply_gradients`) if some keys are
      not prioritized. Note that any "prioritized" keys are assumed to be
      lightweight, and `transfer_concurrent_bytes` will be ignored for them.
  """

  write_concurrent_bytes: int | None = None
  read_concurrent_bytes: int | None = None
  transfer_concurrent_bytes: int | None = None
  is_prioritized_key_fn: serialization_types.IsPrioritizedKeyFn | None = None
@dataclasses.dataclass(kw_only=True)
class SafetensorsOptions(_ActiveContextGuard):
  """Settings for loading Safetensors checkpoints.

  Attributes:
    ignore_load_sharding: If True, skips sharding of the tensors across
      hosts/devices during load. Whole tensors will be present on each host,
      allowing for efficient conversion. Defaults to False.
  """

  ignore_load_sharding: bool = False
class CheckpointLayout(enum.Enum):
  """The on-disk layout of a checkpoint.

  By default, Orbax saves and loads checkpoints with its own layout. However,
  support for other layouts is available, as a means of supporting
  interoperability with other checkpointing libraries.

  Currently supported layouts are:
    ORBAX: Orbax's own layout.
    SAFETENSORS: https://huggingface.co/docs/safetensors/en/index
  """

  ORBAX = 'Orbax'
  SAFETENSORS = 'SafeTensors'