# Source code for orbax.checkpoint._src.handlers.array_checkpoint_handler

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ArrayCheckpointHandler for saving and restoring individual arrays/scalars."""

from __future__ import annotations

import dataclasses
from typing import List, Optional, Union

from etils import epath
import jax
import numpy as np
from orbax.checkpoint import aggregate_handlers
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler
from orbax.checkpoint._src.serialization import type_handler_registry
from orbax.checkpoint._src.serialization import type_handlers


# Short local aliases for names defined in other Orbax modules.
CheckpointArgs = checkpoint_args.CheckpointArgs
register_with_handler = checkpoint_args.register_with_handler
BasePyTreeCheckpointHandler = (
    base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler
)
BasePyTreeSaveArgs = base_pytree_checkpoint_handler.BasePyTreeSaveArgs
BasePyTreeRestoreArgs = base_pytree_checkpoint_handler.BasePyTreeRestoreArgs

# Annotation for the single item handled here: a Python or NumPy scalar, a
# NumPy array, or a JAX array.
ArrayType = Union[int, float, np.number, np.ndarray, jax.Array]

# NOTE(review): not referenced anywhere in the visible part of this file;
# confirm whether it is still needed.
_ELEMENT_KEY = 'ELEMENT'
# OCDBT has no real benefit for a single array.
# Passed as `use_ocdbt` when the wrapped PyTree handler is constructed.
_USE_OCDBT_FOR_SAVE = False
# Name of the PyTree metadata file. `restore` checks for it inside the
# checkpoint directory to choose between the PyTree-based path and the legacy
# single-array path.
PYTREE_METADATA_FILE = base_pytree_checkpoint_handler.PYTREE_METADATA_FILE


class ArrayCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
  """Handles saving and restoring individual arrays and scalars.

  The item is wrapped in a one-entry dict keyed by the checkpoint name and
  handed to an internal `BasePyTreeCheckpointHandler`. Checkpoints that lack
  the PyTree metadata file are restored through a legacy path that calls the
  registered type handler directly.
  """

  def __init__(self, checkpoint_name: Optional[str] = None):
    """Initializes the handler.

    Args:
      checkpoint_name: Provides a name for the directory under which
        Tensorstore files will be saved. Defaults to 'checkpoint'.
    """
    if not checkpoint_name:
      checkpoint_name = 'checkpoint'
    self._checkpoint_name = checkpoint_name
    self._aggregate_handler = aggregate_handlers.MsgpackHandler()
    self._base_pytree_checkpoint_handler = BasePyTreeCheckpointHandler(
        use_ocdbt=_USE_OCDBT_FOR_SAVE
    )

  def _is_supported_type(self, item: ArrayType) -> bool:
    """Returns True if `item` is a NumPy array, a JAX array, or a scalar."""
    return isinstance(item, (np.ndarray, jax.Array)) or utils.is_scalar(item)

  async def async_save(
      self,
      directory: epath.Path,
      item: Optional[ArrayType] = None,
      save_args: Optional[type_handlers.SaveArgs] = None,
      args: Optional[ArraySaveArgs] = None,
  ) -> Optional[List[future.Future]]:
    """Saves an object asynchronously.

    Args:
      directory: Folder in which to save.
      item: Deprecated, use `args`.
      save_args: Deprecated, use `args`.
      args: An ocp.array_checkpoint_handler.ArraySaveArgs (see below). When
        provided, it takes precedence over `item` and `save_args`.

    Returns:
      A list of commit futures which can be run to complete the save.

    Raises:
      TypeError: If the item is not an array or scalar.
    """
    if args is not None:
      item = args.item
      save_args = args.save_args

    if not self._is_supported_type(item):
      raise TypeError(f'Unsupported type: {type(item)}.')

    if save_args is None:
      save_args = type_handlers.SaveArgs()

    # Wrap the single item in a one-entry tree so the PyTree handler can
    # save it under the configured checkpoint name.
    pytree_args = BasePyTreeSaveArgs(
        item={self._checkpoint_name: item},
        save_args={self._checkpoint_name: save_args},
    )
    return await self._base_pytree_checkpoint_handler.async_save(
        directory, args=pytree_args
    )

  def save(self, directory: epath.Path, *args, **kwargs):
    """Saves an array synchronously.

    Args:
      directory: Folder in which to save.
      *args: Forwarded to `async_save`.
      **kwargs: Forwarded to `async_save`.
    """

    async def async_save():
      commit_futures = await self.async_save(directory, *args, **kwargs)  # pytype: disable=bad-return-type
      # Futures are already running, so sequential waiting is equivalent to
      # concurrent waiting.
      if commit_futures:  # May be None.
        for f in commit_futures:
          f.result()  # Block on result.

    asyncio_utils.run_sync(async_save())

  def restore(
      self,
      directory: epath.Path,
      item: Optional[ArrayType] = None,
      restore_args: Optional[type_handlers.RestoreArgs] = None,
      args: Optional[ArrayRestoreArgs] = None,
  ) -> ArrayType:
    """Restores an object.

    Args:
      directory: folder from which to read.
      item: Deprecated, use `args`.
      restore_args: Deprecated, use `args`.
      args: An ocp.array_checkpoint_handler.ArrayRestoreArgs object (see
        below). When provided, `item` and `restore_args` are ignored.

    Returns:
      The restored object.
    """
    if args is None:
      args = ArrayRestoreArgs(item=item, restore_args=restore_args)

    # Checkpoints containing the PyTree metadata file are restored through
    # the PyTree handler; the entry is then unwrapped from the one-entry tree.
    if (directory / PYTREE_METADATA_FILE).exists():
      pytree_args = BasePyTreeRestoreArgs(
          {self._checkpoint_name: args.item} if args.item is not None else None,
          restore_args={self._checkpoint_name: args.restore_args},
      )
      return self._base_pytree_checkpoint_handler.restore(
          directory, args=pytree_args
      )[self._checkpoint_name]

    # TODO(nikhilbansall): Remove this logic once support for legacy
    # checkpoints lacking PYTREE_METADATA_FILE is no longer needed.
    restore_args = args.restore_args or type_handlers.RestoreArgs()
    info = type_handlers.ParamInfo(
        name=self._checkpoint_name,
        parent_dir=directory,
        skip_deserialize=False,
        is_ocdbt_checkpoint=type_handlers.is_ocdbt_checkpoint(directory),
    )
    restore_type = restore_args.restore_type
    if restore_type is None:
      restore_type = type_handlers.default_restore_type(restore_args)
    type_handler = type_handler_registry.get_type_handler(restore_type)
    # `deserialize` operates on batches; pass a single-element batch and take
    # the only result.
    result = asyncio_utils.run_sync(
        type_handler.deserialize([info], args=[restore_args])
    )[0]
    return result

  def finalize(self, directory: epath.Path):
    """Finalizes `directory` by delegating to the wrapped PyTree handler."""
    self._base_pytree_checkpoint_handler.finalize(directory)

  def close(self):
    """See superclass documentation."""
    try:
      self._aggregate_handler.close()
    finally:
      # The PyTree handler is created in `__init__` and owned by this
      # handler, so it must be closed as well, even when closing the
      # aggregate handler raises.
      self._base_pytree_checkpoint_handler.close()
@register_with_handler(ArrayCheckpointHandler, for_save=True)
@dataclasses.dataclass
class ArraySaveArgs(CheckpointArgs):
  """Arguments accepted when saving a single array or scalar.

  Attributes:
    item (required): the array or scalar object to save.
    save_args: an optional `ocp.SaveArgs` object carrying save options.
  """

  item: ArrayType
  save_args: Optional[type_handlers.SaveArgs] = None
@register_with_handler(ArrayCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class ArrayRestoreArgs(CheckpointArgs):
  """Arguments accepted when restoring a single array or scalar.

  Attributes:
    item: unused; kept only as an option for legacy compatibility.
    restore_args: an optional `ocp.RestoreArgs` object carrying restore
      options.
  """

  item: Optional[ArrayType] = None
  restore_args: Optional[type_handlers.RestoreArgs] = None