# Source code for orbax.checkpoint.experimental.v1._src.saving.saving

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines free-function interface for saving."""

from orbax.checkpoint._src.checkpointers import async_checkpointer
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.handlers import handler_registration as legacy_handler_registration
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.handlers import compatibility as handler_compatibility
from orbax.checkpoint.experimental.v1._src.handlers import registration as handler_registration
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
import orbax.checkpoint.experimental.v1._src.handlers.global_registration  # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.saving import execution
from orbax.checkpoint.experimental.v1._src.saving import validation
from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
from typing_extensions import deprecated  # pytype: disable=not-supported-yet

# Default checkpointable name under which `save` / `save_async` store the
# given `PyTree`; re-exported from the checkpoint layout module.
STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY
# Alias of the handler-level `Checkpointable` type used in the signatures
# of this module.
Checkpointable = handler_types.Checkpointable


def save(
    path: path_types.PathLike,
    state: tree_types.PyTreeOf[tree_types.Leaf],
    *,
    checkpointable_name: str = STATE_CHECKPOINTABLE_KEY,
    overwrite: bool = False,
    custom_metadata: tree_types.JsonType | None = None,
) -> None:
  """Saves a `PyTree`.

  The operation blocks until complete. For improved performance, consider
  using :py:func:`.save_async` instead.

  This function should be called on all available controller processes.

  Example usage:

  Simple save of a dictionary containing JAX arrays::

    state = {
        'params': {
            'w': jnp.ones((8, 8)),
            'b': jnp.zeros(8),
        },
        'step': 100
    }
    # Saves to /tmp/my_checkpoint/
    ocp.save('/tmp/my_checkpoint', state)

  Args:
    path: The path to save the checkpoint to.
    state: The `PyTree` to save. This may be any JAX `PyTree` (including
      custom objects registered as `PyTrees`) consisting of supported leaf
      types. See `orbax.checkpoint.experimental.v1.tree` for a table of
      standard supported leaf types.
    checkpointable_name: The name of the checkpointable to save the `PyTree`
      under. Defaults to `STATE_CHECKPOINTABLE_KEY`.
    overwrite: If True, fully overwrites an existing checkpoint in `path`.
      Otherwise, raises an error if the checkpoint already exists.
    custom_metadata: User-provided custom metadata. An arbitrary
      JSON-serializable dictionary the user can use to store additional
      information. The field is treated as opaque by Orbax.

  Returns:
    None. The call returns only once the save has finished, because it waits
    on the scheduled operation via `.result()`.
  """
  # The `PyTree` is saved as a single named checkpointable. The call goes
  # through the same implementation as `save_async`, with
  # `async_origin=False`.
  execution.save_checkpointables_impl(
      path,
      {checkpointable_name: state},
      overwrite=overwrite,
      custom_metadata=custom_metadata,
      async_origin=False,
  ).result()
def save_checkpointables(
    path: path_types.PathLike,
    checkpointables: dict[str, Checkpointable],
    *,
    overwrite: bool = False,
    custom_metadata: tree_types.JsonType | None = None,
) -> None:
  """Saves a dictionary of checkpointables, blocking until done.

  A "checkpointable" is a logical piece of the checkpoint that is distinct in
  some way from the other pieces. Checkpointables are separable: they may or
  may not be loaded concurrently, and some may be omitted from the checkpoint
  entirely. They are often represented by different types and have different
  on-disk representations. The quintessential example is model params vs.
  dataset. For example, one might do::

    ocp.save_checkpointables(
        path,
        {
            'params': pytree_of_arrays,
            'dataset': pygrain.DatasetIterator(...),
        }
    )

  It is also possible to do::

    train_state = {
        'params': params_pytree_of_arrays,
        'opt_state': opt_state_pytree_of_arrays,
        'step': step,
        ...
    }
    ocp.save_checkpointables(path, train_state)

  This is not the ideal approach, because it then becomes difficult to run
  transformations that involve the entire train state (see the
  `load_and_transform` API).

  This function should be called on all available controller processes.

  Args:
    path: Destination path of the checkpoint.
    checkpointables: Mapping from checkpointable name to the checkpointable
      object to save.
    overwrite: When True, an existing checkpoint at `path` is fully
      overwritten. When False, an error is raised if the checkpoint already
      exists.
    custom_metadata: Optional user-provided, JSON-serializable dictionary for
      storing additional information. Orbax treats it as opaque.
  """
  validation.validate_save_checkpointables(checkpointables)
  pending = execution.save_checkpointables_impl(
      path,
      checkpointables,
      overwrite=overwrite,
      custom_metadata=custom_metadata,
      async_origin=False,
  )
  # Wait for the scheduled save to finish before returning.
  pending.result()
# TODO(b/396190818): Test modification of the context by the user after the
# save operation is scheduled.
def save_async(
    path: path_types.PathLike,
    state: tree_types.PyTreeOf[tree_types.Leaf],
    *,
    checkpointable_name: str = STATE_CHECKPOINTABLE_KEY,
    overwrite: bool = False,
    custom_metadata: tree_types.JsonType | None = None,
) -> async_types.AsyncResponse[None]:
  """Schedules an asynchronous save of a `PyTree`.

  In contrast to :py:func:`.save`, this returns as soon as the save operation
  has been scheduled. Some work, such as the device-to-host copy of on-device
  arrays, must still happen on the main thread before returning. The
  remaining writes continue in a background thread. The returned
  :py:class:`~.AsyncResponse` can be used to wait for completion via
  `response.result()`. Always wait for completion before loading the
  checkpoint or exiting the program.

  This function should be called on all available controller processes.

  Example usage:

  Asynchronously saving a dictionary containing JAX arrays::

    state = {
        'params': {
            'w': jnp.ones((8, 8)),
            'b': jnp.zeros(8),
        },
        'step': 100
    }
    # Saves to /tmp/my_checkpoint/
    future = ocp.save_async('/tmp/my_checkpoint', state)
    # Perform other work here...
    # Wait for completion only when necessary
    future.result()

  Args:
    path: Destination path of the checkpoint.
    state: The `PyTree` to save. Any JAX `PyTree` (including custom objects
      registered as `PyTrees`) made of supported leaf types is accepted. See
      `orbax.checkpoint.v1.tree` for a table of standard supported leaf types.
    checkpointable_name: Name of the checkpointable the `PyTree` is saved
      under. Defaults to `STATE_CHECKPOINTABLE_KEY`.
    overwrite: When True, an existing checkpoint at `path` is fully
      overwritten. When False, an error is raised if the checkpoint already
      exists.
    custom_metadata: Optional user-provided, JSON-serializable dictionary for
      storing additional information. Orbax treats it as opaque.

  Returns:
    An `AsyncResponse` whose `result()` blocks until the save is complete and
    then returns `None`.
  """
  # Wrap the single PyTree as a one-entry checkpointables mapping.
  single_checkpointable = {checkpointable_name: state}
  return execution.save_checkpointables_impl(
      path,
      single_checkpointable,
      overwrite=overwrite,
      custom_metadata=custom_metadata,
      async_origin=True,
  )
def save_checkpointables_async(
    path: path_types.PathLike,
    checkpointables: dict[str, Checkpointable],
    *,
    overwrite: bool = False,
    custom_metadata: tree_types.JsonType | None = None,
) -> async_types.AsyncResponse[None]:
  """Schedules an asynchronous save of a dictionary of checkpointables.

  See :py:func:`.save_checkpointables` for what a checkpointable is.

  In contrast to :py:func:`.save_checkpointables`, this returns as soon as
  the save operation has been scheduled. Some work, such as the
  device-to-host copy of on-device arrays, must still happen on the main
  thread before returning. The remaining writes continue in a background
  thread. The returned :py:class:`~.AsyncResponse` can be used to wait for
  completion via `response.result()`. Always wait for completion before
  loading the checkpoint or exiting the program.

  This function should be called on all available controller processes.

  Example usage:

  Asynchronously saving several distinct components (e.g. model parameters
  and a dataset iterator)::

    path = '/tmp/my_checkpoint_step_100'

    # Setup components
    params = {'w': jnp.ones((8, 8)), 'b': jnp.zeros(8)}

    # Setup Grain iterator (Stateful Checkpointable)
    import grain
    dataset_iter = iter(
        grain.MapDataset.range(30)
        .batch(3)
        .map(lambda x: x.tolist())
    )

    # Save multiple components
    checkpointables = {
        'model': params,
        'dataset': dataset_iter,
    }

    # Start the async save
    response = ocp.save_checkpointables_async(path, checkpointables)

    # Perform other operations here...

    # Wait for the save to finish
    response.result()

  Args:
    path: Destination path of the checkpoint.
    checkpointables: Mapping from checkpointable name to the checkpointable
      object to save.
    overwrite: When True, an existing checkpoint at `path` is fully
      overwritten. When False, an error is raised if the checkpoint already
      exists.
    custom_metadata: Optional user-provided, JSON-serializable dictionary for
      storing additional information. Orbax treats it as opaque.

  Returns:
    An `AsyncResponse` whose `result()` blocks until the save is complete and
    then returns `None`.
  """
  validation.validate_save_checkpointables(checkpointables)
  response = execution.save_checkpointables_impl(
      path,
      checkpointables,
      overwrite=overwrite,
      custom_metadata=custom_metadata,
      async_origin=True,
  )
  return response
def get_v0_checkpointer_and_args(
    checkpointables: dict[str, Checkpointable],
    *,
    metrics: tree_types.JsonType | None = None,
) -> tuple[
    async_checkpointer.AsyncCheckpointer,
    composite_checkpoint_handler.CompositeArgs,
]:
  """Builds a legacy (V0) checkpointer plus matching args for saving.

  The checkpointables are first extended with the internal checkpointables
  (and optional metrics). A V1 handler is resolved for each of them and
  wrapped in a V0-compatible handler. The wrapped handlers are registered in
  a legacy handler registry that backs a `CompositeCheckpointHandler`.

  Args:
    checkpointables: A dictionary of checkpointables.
    metrics: Optional metrics to add to the checkpointables.

  Returns:
    A tuple of the V0 `AsyncCheckpointer` and the `CompositeArgs` to pass to
    it.
  """
  ctx = context_lib.get_context()
  all_checkpointables = execution.add_internal_checkpointables(
      checkpointables, context=ctx, metrics=metrics
  )

  # Resolve a V1 save handler for every checkpointable.
  resolved_handlers = {}
  for ckpt_name, ckpt_value in all_checkpointables.items():
    resolved_handlers[ckpt_name] = (
        handler_registration.resolve_handler_for_save(
            ctx.checkpointables_options.registry, ckpt_value, name=ckpt_name
        )
    )

  # Wrap each V1 handler so the V0 composite handler can drive it.
  wrapped_handlers = {}
  for ckpt_name, v1_handler in resolved_handlers.items():
    wrapped_handlers[ckpt_name] = (
        handler_compatibility.get_compatibility_handler(v1_handler)
    )

  legacy_registry = (
      legacy_handler_registration.DefaultCheckpointHandlerRegistry()
  )
  for ckpt_name, v0_handler in wrapped_handlers.items():
    legacy_registry.add(ckpt_name, handler_compatibility.Args, v0_handler)

  composite_handler = composite_checkpoint_handler.CompositeCheckpointHandler(
      handler_registry=legacy_registry,
      composite_options=composite_checkpoint_handler.CompositeOptions(
          async_options=ctx.async_options.v0(),
          file_options=ctx.file_options.v0(),
          multiprocessing_options=ctx.multiprocessing_options.v0(),
          temporary_path_class=ctx.file_options.temporary_path_class,
      ),
  )
  checkpointer = async_checkpointer.AsyncCheckpointer(
      composite_handler,
      async_options=ctx.async_options.v0(),
      multiprocessing_options=ctx.multiprocessing_options.v0(),
      file_options=ctx.file_options.v0(),
      temporary_path_class=ctx.file_options.temporary_path_class,
  )

  # Every checkpointable is passed through the compatibility `Args` wrapper.
  per_item_args = {
      ckpt_name: handler_compatibility.Args(ckpt_value)
      for ckpt_name, ckpt_value in all_checkpointables.items()
  }
  composite_args = composite_checkpoint_handler.CompositeArgs(**per_item_args)
  return checkpointer, composite_args


@deprecated('Use `save` instead.')
def save_pytree(*args, **kwargs):
  """Deprecated alias of `save`; forwards all arguments unchanged."""
  return save(*args, **kwargs)


@deprecated('Use `save_async` instead.')
def save_pytree_async(*args, **kwargs):
  """Deprecated alias of `save_async`; forwards all arguments unchanged."""
  return save_async(*args, **kwargs)