ocp.v1.options module#
Global configuration options.
AsyncOptions#
- class orbax.checkpoint.experimental.v1.options.AsyncOptions(*, timeout_secs=1200, post_finalization_callback=None, create_directories_asynchronously=True)[source]#
Options used to configure async behavior.
This dataclass defines the configuration parameters for asynchronous checkpoint saving operations within the Orbax framework.
Example
Initialize async configuration with a custom timeout and callback:
from orbax.checkpoint.v1.options import AsyncOptions

def my_callback():
    print("Async save successfully finalized.")

options = AsyncOptions(
    timeout_secs=300,
    post_finalization_callback=my_callback,
    create_directories_asynchronously=False,
)
- timeout_secs#
The timeout in seconds for the async save operation.
- Type:
int
- post_finalization_callback#
A function that is called after the async save operation is complete.
- Type:
Callable[[], None] | None
- create_directories_asynchronously#
If True, creates directories asynchronously in the background.
- Type:
bool
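The roles of timeout_secs and post_finalization_callback can be pictured with a standard-library analogy. The sketch below is not Orbax code: fake_save, on_finalized, and TIMEOUT_SECS are hypothetical names, and a thread pool stands in for the background save. It only illustrates the general pattern of waiting on background work with a timeout and running a callback once that work has finished.

```python
import concurrent.futures
import time

TIMEOUT_SECS = 5  # hypothetical stand-in for timeout_secs
events = []


def fake_save():
    """Pretends to write a checkpoint in the background."""
    time.sleep(0.05)
    return "committed"


def on_finalized(future):
    """Stand-in for post_finalization_callback; runs once the work is done."""
    events.append("finalized")


with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(fake_save)
    future.add_done_callback(on_finalized)
    # Raises concurrent.futures.TimeoutError if the work takes too long.
    result = future.result(timeout=TIMEOUT_SECS)
```

Leaving the with block joins the worker thread, so the callback is guaranteed to have run by the time events is inspected.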
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, timeout_secs=1200, post_finalization_callback=None, create_directories_asynchronously=True)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
MultiprocessingOptions#
- class orbax.checkpoint.experimental.v1.options.MultiprocessingOptions(*, primary_host=0, active_processes=None, barrier_sync_key_prefix=None)[source]#
Options used to configure multiprocessing behavior.
NOTE: These options are generally dangerous to mess with unless you know what you’re doing.
This dataclass defines the configuration parameters for multiprocessing checkpoint saving operations within the Orbax framework.
Example
Configure a multi-host setup where process 1 is designated as the primary host, only a subset of processes are active, and a custom barrier key is used to prevent collisions with other concurrent checkpointers:
from orbax.checkpoint.v1.options import MultiprocessingOptions

options = MultiprocessingOptions(
    primary_host=1,
    active_processes={1, 2, 3},
    barrier_sync_key_prefix="model_a_sync_",
)
- primary_host#
The host id of the primary host. Defaults to 0. If set to None, all hosts are considered primary, which is useful when every host works only with local storage.
- Type:
int | None
- active_processes#
A set of process indices (corresponding to process_index()) over which Checkpointer is expected to be called. This makes it possible to have a Checkpointer instance that runs over a subset of processes, rather than all processes as it is normally expected to do. If specified, primary_host must belong to active_processes.
- Type:
set[int] | None
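The constraint that primary_host must belong to active_processes can be checked before constructing the options. The helper below is a minimal sketch with a hypothetical name (check_primary_host); it is not part of Orbax, and skipping the check when either value is None is an assumption made for illustration.

```python
from typing import Optional, Set


def check_primary_host(
    primary_host: Optional[int],
    active_processes: Optional[Set[int]],
) -> None:
    """Raises ValueError if primary_host is not one of active_processes."""
    if primary_host is None or active_processes is None:
        # Assumption: there is nothing to check when either value is unset.
        return
    if primary_host not in active_processes:
        raise ValueError(
            f"primary_host={primary_host} is not in "
            f"active_processes={sorted(active_processes)}"
        )


# Matches the configuration shown in the example above: process 1 is active.
check_primary_host(1, {1, 2, 3})
```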
- barrier_sync_key_prefix#
A string to be prepended to the barrier sync key used to synchronize processes. This is useful to avoid collisions with other barrier syncs if another Checkpointer is being used concurrently.
- Type:
str | None
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, primary_host=0, active_processes=None, barrier_sync_key_prefix=None)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
FileOptions#
- class orbax.checkpoint.experimental.v1.options.FileOptions(*, path_permission_mode=None, temporary_path_class=None, path_class=<class 'etils.epath.abstract_path.Path'>)[source]#
Options used to configure checkpoint directories and files.
This dataclass defines the configuration parameters for creating and managing checkpoint directories and files on disk.
Example
Configure checkpoint files to use strict directory permissions (e.g., read/write/execute for owner, read/execute for group):
from orbax.checkpoint.v1.options import FileOptions

options = FileOptions(
    path_permission_mode=0o750,
)
- path_permission_mode#
Path permission mode for step directories and user metadata files, e.g. 0o750. Check google/etils to confirm that your path type is supported.
- Type:
int | None
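To see what an octal mode such as 0o750 grants, it can be decoded with Python's standard stat module. This snippet is independent of Orbax and only inspects the integer; the variable names are illustrative.

```python
import stat

mode = 0o750

# stat.filemode expects file-type bits too, so mark the mode as a directory.
readable = stat.filemode(stat.S_IFDIR | mode)

# Individual permission bits can be tested with bitwise AND.
owner_can_write = bool(mode & stat.S_IWUSR)
group_can_write = bool(mode & stat.S_IWGRP)
others_can_read = bool(mode & stat.S_IROTH)

print(readable)  # drwxr-x---
```

Here the owner has read, write, and execute access, the group has read and execute access, and others have none.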
- temporary_path_class#
A class that is used to create and finalize temporary paths, and to ensure atomicity.
- Type:
type[orbax.checkpoint._src.path.atomicity_types.TemporaryPath] | None
- path_class#
The implementation of Path to use. Defaults to etils.epath.Path, but may be overridden to some other subclass of Path.
- Type:
- v0()[source]#
Converts this FileOptions to a v0 FileOptions.
- Return type:
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, path_permission_mode=None, temporary_path_class=None, path_class=<class 'etils.epath.abstract_path.Path'>)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
PyTreeOptions#
- class orbax.checkpoint.experimental.v1.options.PyTreeOptions(*, saving=<factory>, loading=<factory>)[source]#
Options used to configure PyTree saving and loading.
This dataclass defines the configuration parameters for saving PyTrees to disk and loading them back.
- class Saving(*, pytree_metadata_options=<factory>)[source]#
Options for saving PyTrees.
pytree_metadata_options: Options for managing PyTree metadata.
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, pytree_metadata_options=<factory>)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
- class Loading(*, partial_load=False)[source]#
Options for loading PyTrees.
- partial_load: If True, only restore the parameters that are specified in the abstract PyTree.
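The effect of partial_load can be pictured with plain nested dictionaries. The sketch below is a conceptual illustration only: partial_restore is a hypothetical helper, nested dicts stand in for PyTrees, and None marks a leaf in the abstract tree. It does not show how Orbax performs the restore.

```python
def partial_restore(saved, abstract):
    """Returns only the entries of `saved` whose keys appear in `abstract`."""
    restored = {}
    for key, spec in abstract.items():
        if isinstance(spec, dict):
            # Recurse into subtrees that the abstract tree also describes.
            restored[key] = partial_restore(saved[key], spec)
        else:
            restored[key] = saved[key]
    return restored


saved = {
    "params": {"w": [1.0, 2.0], "b": [0.5]},
    "opt_state": {"mu": [0.1]},
    "step": 7,
}
abstract = {"params": {"w": None}, "step": None}

restored = partial_restore(saved, abstract)
print(restored)  # {'params': {'w': [1.0, 2.0]}, 'step': 7}
```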
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, partial_load=False)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, saving=<factory>, loading=<factory>)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
ArrayOptions#
- class orbax.checkpoint.experimental.v1.options.ArrayOptions(*, saving=<factory>, loading=<factory>)[source]#
Options used to configure array saving and loading.
This dataclass defines the high-level configuration parameters for array checkpointing operations within the Orbax framework. Because it is defined as a frozen, keyword-only dataclass, instances are strictly immutable once created, and all parameters must be explicitly specified by their keyword names during initialization.
Example
Configure array options with specific saving formats and loading behaviors:
from orbax.checkpoint.v1.options import ArrayOptions

options = ArrayOptions(
    saving=ArrayOptions.Saving(
        use_zarr3=True,
        use_compression=False,
    ),
    loading=ArrayOptions.Loading(
        enable_padding_and_truncation=True,
    ),
)
To save some leaves in float16 and the rest in float32, use scoped_storage_options_creator:
import jax
import jax.numpy as jnp
from orbax.checkpoint.v1 import options as ocp_options

def create_opts_fn(keypath, value):
    # Leaves whose key path contains 'small' are stored as float16.
    if 'small' in jax.tree_util.keystr(keypath):
        return ocp_options.ArrayOptions.Saving.StorageOptions(
            dtype=jnp.float16
        )
    return None  # Fall back to global `storage_options`

array_options = ocp_options.ArrayOptions(
    saving=ocp_options.ArrayOptions.Saving(
        storage_options=ocp_options.ArrayOptions.Saving.StorageOptions(
            dtype=jnp.float32
        ),
        scoped_storage_options_creator=create_opts_fn,
    )
)
- class Saving(*, storage_options=<factory>, use_ocdbt=True, use_zarr3=True, use_compression=True, ocdbt_target_data_file_size=None, enable_pinned_host_transfer=None, enable_post_merge_validation=True, use_replica_parallel=None, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_replica_parallel_separate_folder=False, enable_write_sharding_file=True, array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object>, scoped_storage_options_creator=None)[source]#
Options for saving arrays.
- storage_options#
Options used to customize array storage behavior for all leaves at a global level. See below.
- Type:
- use_ocdbt#
Enables OCDBT format.
- Type:
bool
- use_zarr3#
If True, use Zarr3 format.
- Type:
bool
- use_compression#
If True, use ZSTD compression.
- Type:
bool
- ocdbt_target_data_file_size#
Specifies the target size (in bytes) of each OCDBT data file. It applies only when both OCDBT and Zarr3 are enabled. If left unspecified, the default size is 2GB. A value of 0 indicates no maximum file size limit. For best results, ensure chunk_byte_size is smaller than this value. For more details, refer to https://google.github.io/tensorstore/kvstore/ocdbt/index.html#json-kvstore/ocdbt.target_data_file_size
- Type:
int | None
- enable_pinned_host_transfer#
Whether to use pinned_host memory for the transfer from device to host memory. Passing None will enable pinned_host memory depending on the platform used (currently only enables it for the GPU backend).
- Type:
bool | None
- enable_post_merge_validation#
If True, enables validation of the parameters after the finalize step.
- Type:
bool
- use_replica_parallel#
Whether to parallelize saving across replicas.
- Type:
bool | None
- min_slice_bytes_for_replica_parallel#
Minimum number of bytes per replica slice. Only uses replica-parallel when the amount of data written per replica is greater than or equal to this number.
- Type:
int | None
- max_replicas_for_replica_parallel#
Maximum number of replicas over which saving will be parallelized if use_replica_parallel is True.
- Type:
int | None
- enable_replica_parallel_separate_folder#
Whether to save replica data in separate folders.
- Type:
bool
- enable_write_sharding_file#
Whether to write the sharding file. Defaults to True.
- Type:
bool
- array_metadata_store#
Store to manage per host ArrayMetadata. To disable ArrayMetadata persistence, set it to None.
- Type:
array_metadata_store_lib.Store | None
- scoped_storage_options_creator#
A function that, when dealing with PyTrees, is applied to every leaf. If it returns an ArrayOptions.Saving.StorageOptions, its fields take precedence when merging if they are set to non-None or non-default values with respect to storage_options. If it returns None, storage_options is used as a default for all fields. It is called similar to: jax.tree.map_with_path(scoped_storage_options_creator, pytree_to_save).
- Type:
ScopedStorageOptionsCreator | None
- class ScopedStorageOptionsCreator(*args, **kwargs)[source]#
- __call__(key, value)[source]#
Call self as a function.
- Return type:
UnionType[StorageOptions,None]
- __init__(*args, **kwargs)#
- classmethod __subclasshook__(other)#
Abstract classes can override this to customize issubclass().
This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
- class StorageOptions(*, dtype=None, chunk_byte_size=None, shard_axes=None)[source]#
Options used to customize array storage behavior for individual leaves.
- dtype:
If provided, casts the parameter to the given dtype before saving. Note that the parameter must be compatible with the given type (e.g., jnp.bfloat16 is not compatible with np.ndarray).
- chunk_byte_size:
This is an experimental feature that automatically chooses the largest possible chunk shape while keeping the chunk byte size less than or equal to the specified chunk_byte_size. Both write_chunk_shape and read_chunk_shape are automatically set to the chosen shape. This uses a greedy algorithm that prioritizes splitting the largest dimensions first.
- shard_axes:
An optional list of axes that should be prioritized when sharding an array for storage. If empty, the storage sharding implementation will prioritize axes which are already sharded.
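The idea behind chunk_byte_size, splitting the largest dimensions first until a chunk fits the byte budget, can be sketched as follows. This is a simplified illustration and not the algorithm Orbax ships: choose_chunk_shape is a hypothetical helper, and repeatedly halving the largest dimension is an assumption made to keep the example short.

```python
import math


def choose_chunk_shape(shape, itemsize, chunk_byte_size):
    """Greedily halves the largest dimension until the chunk fits the budget."""
    chunk = list(shape)
    while chunk and math.prod(chunk) * itemsize > chunk_byte_size:
        largest = max(range(len(chunk)), key=lambda axis: chunk[axis])
        if chunk[largest] == 1:
            break  # every dimension is already 1, so nothing is left to split
        chunk[largest] = math.ceil(chunk[largest] / 2)
    return tuple(chunk)


# A 1024 x 512 float32 array (2 MiB) with a 1 MiB chunk budget.
chunk_shape = choose_chunk_shape((1024, 512), itemsize=4, chunk_byte_size=1 << 20)
print(chunk_shape)  # (512, 512)
```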
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, dtype=None, chunk_byte_size=None, shard_axes=None)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, storage_options=<factory>, use_ocdbt=True, use_zarr3=True, use_compression=True, ocdbt_target_data_file_size=None, enable_pinned_host_transfer=None, enable_post_merge_validation=True, use_replica_parallel=None, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_replica_parallel_separate_folder=False, enable_write_sharding_file=True, array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object>, scoped_storage_options_creator=None)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
- class Loading(*, enable_padding_and_truncation=False, raise_array_data_missing_error=True, use_load_and_broadcast=False, load_and_broadcast_options=<factory>)[source]#
Options for loading arrays.
- enable_padding_and_truncation:
If True, restoration allows silent truncating/padding of arrays if the stored array shape does not match the target shape. Otherwise, raises an error.
- raise_array_data_missing_error:
If True, raises an error if array data is missing. Otherwise allows returning zeros from an array range that was not necessarily written to.
- use_load_and_broadcast:
Whether to use load-and-broadcast for multi-replica loading. This is useful when the model has multiple replicas across different sets of devices (commonly across multiple TPU slices, but also applies to data-parallel model replicas within a single slice). Array shardings must be structured so that the mesh has a dimension on which all model weights are replicated. The checkpoint will then be loaded only on the hosts and devices taken from replica primary_replica_id along the replica_axis_index dimension. It will then be broadcast to all other replicas.
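What enable_padding_and_truncation permits can be pictured in one dimension with plain lists. The helper fit_to_length is hypothetical, and filling the padded region with zeros is an assumption made for illustration; the sketch does not show how Orbax reshapes arrays on restore.

```python
def fit_to_length(stored, target_len, fill=0.0):
    """Truncates or pads a 1-D list so that it has exactly target_len items."""
    if len(stored) >= target_len:
        return stored[:target_len]
    # Assumption: missing positions are filled with zeros.
    return stored + [fill] * (target_len - len(stored))


truncated = fit_to_length([1.0, 2.0, 3.0], 2)
padded = fit_to_length([1.0, 2.0], 4)
print(truncated)  # [1.0, 2.0]
print(padded)     # [1.0, 2.0, 0.0, 0.0]
```

With the option left at its default of False, a shape mismatch like either of these raises an error instead.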
- class LoadAndBroadcastOptions(*, replica_axis_index=0, primary_replica_id=0, broadcast_memory_limit_bytes=None, broadcast_memory_scaling_factor=0.75)[source]#
Used to configure load-and-broadcast behavior in multi-replica loading.
- replica_axis_index: Defines the axis of the global mesh along which replicas are defined. E.g. all devices in global_mesh.devices[replica_axis_index] are part of the same replica.
- primary_replica_id: The id of the replica that is used to load and broadcast the checkpoint.
- broadcast_memory_limit_bytes: Specifies the memory size (in bytes) used for broadcasting data.
- broadcast_memory_scaling_factor: Specifies the fraction of available memory to use for broadcasting data.
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, replica_axis_index=0, primary_replica_id=0, broadcast_memory_limit_bytes=None, broadcast_memory_scaling_factor=0.75)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, enable_padding_and_truncation=False, raise_array_data_missing_error=True, use_load_and_broadcast=False, load_and_broadcast_options=<factory>)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, saving=<factory>, loading=<factory>)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
CheckpointablesOptions#
- class orbax.checkpoint.experimental.v1.options.CheckpointablesOptions(*, registry=<factory>)[source]#
Options used to configure checkpointables save/load behavior.
Primarily intended for registering custom CheckpointableHandler classes. You can specify a registry directly, or use create_with_handlers. For example:
checkpointables_options = (
    ocp.options.CheckpointablesOptions.create_with_handlers(
        FooHandler(),
        bar=BarHandler(),
    )
)
with ocp.Context(checkpointables_options=checkpointables_options):
    ocp.save_checkpointables(directory, dict(foo=Foo(...), bar=Bar(...)))
In this example, FooHandler is registered generically, which means that any checkpointable that is handleable by FooHandler can be saved/loaded (a Foo object in this case). In contrast, BarHandler is explicitly tied to the name bar, which means that only a checkpointable that is both handleable by BarHandler and has the name bar can be handled by this BarHandler.
Recall that a global registry also exists, containing core handlers like PyTreeHandler and JsonHandler. Use ocp.handlers.register_handler to register a handler globally.
Note that registration order matters. For example, if saving a dict containing only strings, both JsonHandler and PyTreeHandler are capable of handling this object, but JsonHandler will be selected first because it is registered first.
- registry#
A CheckpointableHandlerRegistry that is used to resolve CheckpointableHandler classes for each provided checkpointable during saving and loading.
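Why registration order matters can be shown with a toy first-match registry. Everything in this sketch is hypothetical: the labels "json_like" and "pytree_like", the predicates, and the resolve function are stand-ins that mimic the described behavior, not the CheckpointableHandlerRegistry implementation.

```python
from typing import Any, Callable, List, Tuple


def is_string_dict(obj: Any) -> bool:
    """True for dicts whose values are all strings."""
    return isinstance(obj, dict) and all(isinstance(v, str) for v in obj.values())


def is_dict(obj: Any) -> bool:
    return isinstance(obj, dict)


# Entries are tried in the order they were registered.
registry: List[Tuple[str, Callable[[Any], bool]]] = [
    ("json_like", is_string_dict),
    ("pytree_like", is_dict),
]


def resolve(obj: Any) -> str:
    """Returns the label of the first registered entry that can handle obj."""
    for label, can_handle in registry:
        if can_handle(obj):
            return label
    raise LookupError("no registered handler can handle this object")


strings_only = resolve({"a": "x", "b": "y"})
mixed_values = resolve({"a": 1, "b": "y"})
```

Both entries accept the strings-only dict, so the one registered first is chosen; the mixed dict is rejected by the first entry and falls through to the second.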
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, registry=<factory>)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
PathwaysOptions#
- class orbax.checkpoint.experimental.v1.options.PathwaysOptions(*, checkpointing_impl=None)[source]#
Options used to configure Pathways saving and loading.
- checkpointing_impl#
The implementation to use for Pathways checkpointing.
- Type:
orbax.checkpoint._src.serialization.pathways_types.CheckpointingImpl | None
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, checkpointing_impl=None)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
DeletionOptions#
- class orbax.checkpoint.experimental.v1.options.DeletionOptions(*, gcs_deletion_options=<factory>)[source]#
Options used to configure checkpoint deletion behavior.
- gcs_deletion_options#
Deletion options specific to GCS.
- Type:
- class GcsDeletionOptions(*, todelete_full_path=None)[source]#
Deletion options specific to GCS.
- todelete_full_path#
A path relative to the bucket root for “soft-deleting” checkpoints on Google Cloud Storage (GCS). Instead of being permanently removed, checkpoints are moved to this new location within the same bucket. This is useful if direct deletion on GCS is time-consuming, as it allows an external component to manage the actual removal.
This option gathers all “deleted” items in a centralized path at the bucket level for future cleanup.
For instance, if a checkpoint is in gs://my-bucket/experiments/run1/, providing the value ‘trash’ will move a deleted step to gs://my-bucket/trash/<step_id>.
- Type:
str | None
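The path mapping described above, from a step directory to its soft-delete location at the bucket root, can be sketched with plain string handling. soft_delete_destination is a hypothetical helper written for this illustration, and treating the last path component as the step id is an assumption; it does not touch GCS and is not Orbax's implementation.

```python
def soft_delete_destination(step_path: str, todelete_full_path: str) -> str:
    """Maps gs://<bucket>/.../<step_id> to gs://<bucket>/<todelete>/<step_id>."""
    prefix = "gs://"
    if not step_path.startswith(prefix):
        raise ValueError(f"expected a gs:// path, got {step_path!r}")
    # Split off the bucket name; the remainder is the object path.
    bucket, _, object_path = step_path[len(prefix):].partition("/")
    # Assumption: the final path component identifies the step.
    step_id = object_path.rstrip("/").rsplit("/", 1)[-1]
    return f"{prefix}{bucket}/{todelete_full_path.strip('/')}/{step_id}"


destination = soft_delete_destination(
    "gs://my-bucket/experiments/run1/100", "trash"
)
print(destination)  # gs://my-bucket/trash/100
```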
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, todelete_full_path=None)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, gcs_deletion_options=<factory>)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
MemoryOptions#
- class orbax.checkpoint.experimental.v1.options.MemoryOptions(*, write_concurrent_bytes=None, read_concurrent_bytes=None, transfer_concurrent_bytes=None, is_prioritized_key_fn=None)[source]#
Options for configuring memory limits for save / load.
Can help reduce the likelihood of OOMs when large checkpoints are saved or loaded.
- write_concurrent_bytes#
Max concurrent bytes that are allowed for writing (per host). This applies only to additional memory used beyond the baseline space required to hold the model weights in memory. E.g. if model weights require 100 GB of memory in RAM, a setting below 100 GB will not prevent 100 GB from being allocated, but will prevent additional memory from being used during the write to disk. None indicates no limit.
- Type:
int | None
- read_concurrent_bytes#
Max concurrent bytes that are allowed for reading. None indicates no limit.
- Type:
int | None
- transfer_concurrent_bytes#
Max concurrent bytes allowed for device-to-host transfers (per host). When the limit is reached, arrays must be finished writing to the checkpoint before a new array can start being transferred. E.g. if the limit is 100 GB, and model weights would require 120 GB of RAM, then the remaining 20 GB will not be able to start D2H transfer until the first 100 GB have finished being written to the checkpoint. Note that asynchronous saves may not be truly asynchronous with this option enabled, as we have to block on some array writes before beginning others. Also see is_prioritized_key_fn. None indicates no limit.
- Type:
int | None
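The blocking effect of transfer_concurrent_bytes can be approximated by grouping arrays into "waves" that each fit under the limit. This is a deliberately coarse model: plan_transfer_waves is a hypothetical helper, sizes are in GB for readability, and the real scheduler may release budget as individual arrays finish rather than wave by wave.

```python
def plan_transfer_waves(array_sizes, limit):
    """Groups arrays, in order, into waves whose total size stays within limit."""
    waves, current, used = [], [], 0
    for name, size in array_sizes.items():
        if current and used + size > limit:
            # The budget is exhausted: later arrays wait for this wave to be written.
            waves.append(current)
            current, used = [], 0
        current.append(name)
        used += size
    if current:
        waves.append(current)
    return waves


# 120 GB of weights against a 100 GB limit, as in the description above.
waves = plan_transfer_waves({"a": 60, "b": 40, "c": 20}, limit=100)
print(waves)  # [['a', 'b'], ['c']]
```

The last 20 GB cannot begin its device-to-host transfer until the first 100 GB has been written, which is why a save under this limit may not be fully asynchronous.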
- is_prioritized_key_fn#
A function that accepts a PyTree keypath (obtained using jax.tree.map_with_path) and indicates whether that key should be scheduled for D2H transfer before other keys. The transfer of prioritized keys is scheduled before returning to the caller, so their values will never be corrupted by a concurrent update. Keys that are not prioritized will not be scheduled for transfer until all prioritized keys have been fully written to the checkpoint, which means their values may be altered if they are updated concurrently. Callers should take care to call wait_until_finished before updating array values (e.g. apply_gradients) if some keys are not prioritized. Note that any “prioritized” keys are assumed to be lightweight, and transfer_concurrent_bytes will be ignored for them.
- Type:
orbax.checkpoint._src.serialization.types.IsPrioritizedKeyFn | None
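How a prioritization function divides keys can be sketched without JAX. In this illustration keypaths are plain tuples of strings and is_prioritized is a hypothetical predicate; real keypaths come from jax.tree.map_with_path and have a different element type, so this shows only the partitioning idea.

```python
def is_prioritized(keypath):
    """Hypothetical rule: small bookkeeping values under 'step' go first."""
    return keypath[0] == "step"


keypaths = [("params", "w"), ("step",), ("params", "b")]

# Prioritized keys are transferred first; the rest wait until those are written.
transfer_first = [k for k in keypaths if is_prioritized(k)]
transfer_later = [k for k in keypaths if not is_prioritized(k)]

print(transfer_first)  # [('step',)]
print(transfer_later)  # [('params', 'w'), ('params', 'b')]
```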
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(*, write_concurrent_bytes=None, read_concurrent_bytes=None, transfer_concurrent_bytes=None, is_prioritized_key_fn=None)#
- __setattr__(name, value)#
Implement setattr(self, name, value).