Source code for orbax.checkpoint._src.handlers.random_key_checkpoint_handler

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RandomKeyCheckpointHandlers for saving and restoring individual Jax and Numpy random keys."""

from __future__ import annotations

import abc
import dataclasses
from typing import Any, List, Mapping, Optional, Tuple, Union

from etils import epath
import jax
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import options as options_lib
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.handlers import array_checkpoint_handler
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.handlers import composite_checkpoint_handler
from orbax.checkpoint._src.handlers import json_checkpoint_handler
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.serialization import type_handlers

NumpyRandomKeyType = Union[tuple, dict]

ArrayRestoreArgs = array_checkpoint_handler.ArrayRestoreArgs
ArraySaveArgs = array_checkpoint_handler.ArraySaveArgs
CheckpointArgs = checkpoint_args.CheckpointArgs
CompositeArgs = composite_checkpoint_handler.CompositeArgs
CompositeCheckpointHandler = (
    composite_checkpoint_handler.CompositeCheckpointHandler
)
JsonRestoreArgs = json_checkpoint_handler.JsonRestoreArgs
JsonSaveArgs = json_checkpoint_handler.JsonSaveArgs
PyTreeRestoreArgs = pytree_checkpoint_handler.PyTreeRestoreArgs
PyTreeSaveArgs = pytree_checkpoint_handler.PyTreeSaveArgs
register_with_handler = checkpoint_args.register_with_handler


class BaseRandomKeyCheckpointHandler(
    async_checkpoint_handler.AsyncCheckpointHandler
):
  """Base handle saving and restoring individual Jax random key in both typed and untyped format."""

  def __init__(self, key_name: str):
    """Initializes the handler.

    Args:
      key_name: Provides a name for the directory under which Tensorstore files
        will be saved. Defaults to 'random_key'.
    """
    self._key_name = key_name
    self._key_metadata = f'{self._key_name}_metadata'
    composite_options = composite_checkpoint_handler.CompositeOptions(
        async_options=options_lib.AsyncOptions(
            create_directories_asynchronously=False
        )
    )
    self._handler = CompositeCheckpointHandler(
        composite_options=composite_options
    )

  @abc.abstractmethod
  def checkpoint_save_args(
      self, args: CheckpointArgs
  ) -> Tuple[CheckpointArgs, JsonSaveArgs]:
    """Return the `args` for saving item and metadata.

    Args:
      args: An ocp.checkpoint_args.CheckpointArgs.

    Returns:
      Return a tuple of CheckpointArgs, and JsonSaveArgs.
      The CheckpointArgs is to save the actual random key and the JsonSaveArgs
      is to save any extra metadata
    """
    pass

  @abc.abstractmethod
  def checkpoint_restore_args(
      self,
      args: CheckpointArgs,
  ) -> CheckpointArgs:
    """Return the `args` for retoring the item.

    Args:
      args: An ocp.checkpoint_args.CheckpointArgs.

    Returns:
      Return a CheckpointArgs to restore the item
    """
    pass

  @abc.abstractmethod
  def post_restore(self, item: Any, metadata: Mapping[Any, Any]) -> Any:
    """Final operations needed to be done on the returning result."""
    pass

  async def async_save(
      self,
      directory: epath.Path,
      args: CheckpointArgs,
  ) -> Optional[List[future.Future]]:
    """Saves a random key asynchronously.

    Args:
      directory: Folder in which to save.
      args: An ocp.checkpoint_args.CheckpointArgs.

    Returns:
      A list of commit futures which can be run to complete the save.
    """
    item_arg, metadata_arg = self.checkpoint_save_args(args)
    return await self._handler.async_save(
        directory,
        CompositeArgs(**{
            self._key_name: item_arg,
            self._key_metadata: metadata_arg,
        }),
    )

  def save(self, directory: epath.Path, *args, **kwargs):
    """Saves a random key synchronously."""

    async def async_save():
      commit_futures = await self.async_save(directory, *args, **kwargs)  # pytype: disable=bad-return-type
      # Futures are already running, so sequential waiting is equivalent to
      # concurrent waiting.
      if commit_futures:  # May be None.
        for f in commit_futures:
          f.result()  # Block on result.

    asyncio_utils.run_sync(async_save())

  def restore(
      self,
      directory: epath.Path,
      args: CheckpointArgs,
  ) -> Any:
    """Restores a random key.

    Args:
      directory: folder from which to read.
      args: An ocp.checkpoint_args.CheckpointArgs.

    Returns:
      The restored object.
    """
    item_arg = self.checkpoint_restore_args(args)
    result = self._handler.restore(
        directory,
        args=CompositeArgs(**{
            self._key_name: item_arg,
            self._key_metadata: JsonRestoreArgs(),
        }),
    )

    return self.post_restore(result[self._key_name], result[self._key_metadata])

  def finalize(self, directory: epath.Path):
    self._handler.finalize(directory)

  def close(self):
    self._handler.close()


[docs] class JaxRandomKeyCheckpointHandler(BaseRandomKeyCheckpointHandler): """Handles saving and restoring individual Jax random key in both typed and untyped format.""" def __init__(self, key_name: Optional[str] = None): """Initializes the handler. Args: key_name: Provides a name for the directory under which Tensorstore files will be saved. Defaults to 'jax_random_key'. """ super().__init__(key_name or 'jax_random_key') def checkpoint_save_args( self, args: JaxRandomKeySaveArgs ) -> Tuple[CheckpointArgs, JsonSaveArgs]: item = args.item save_args = args.save_args if isinstance(item, jax.Array): if is_typed := jax.dtypes.issubdtype(item.dtype, jax.dtypes.prng_key): item = jax.random.key_data(item) metadata = {'typed': is_typed} else: raise TypeError(f'Unsupported type: {type(item)}.') return ( ArraySaveArgs(jax.random.key_data(item), save_args), JsonSaveArgs(metadata), ) def checkpoint_restore_args( self, args: JaxRandomKeyRestoreArgs ) -> CheckpointArgs: return ArrayRestoreArgs(restore_args=args.restore_args) def post_restore(self, item: Any, metadata: Mapping[Any, Any]) -> Any: if metadata['typed']: return jax.random.wrap_key_data(item) else: return item
[docs] @register_with_handler(JaxRandomKeyCheckpointHandler, for_save=True) @dataclasses.dataclass class JaxRandomKeySaveArgs(CheckpointArgs): """Parameters for saving a JAX random key. Attributes: item (required): a JAX random key. save_args: a `ocp.SaveArgs` object specifying save options. """ item: jax.Array save_args: Optional[type_handlers.SaveArgs] = None
[docs] @register_with_handler(JaxRandomKeyCheckpointHandler, for_restore=True) @dataclasses.dataclass class JaxRandomKeyRestoreArgs(CheckpointArgs): """Jax random key restore args. Attributes: restore_args: a `ocp.RestoreArgs` object specifying restore options for JaxArray. """ restore_args: Optional[type_handlers.RestoreArgs] = None
[docs] class NumpyRandomKeyCheckpointHandler(BaseRandomKeyCheckpointHandler): """Saves Nnumpy random key in legacy or non-lagacy format.""" def __init__(self, key_name: Optional[str] = None): """Initializes NumpyRandomKeyCheckpointHandler. Args: key_name: Provides a name for the directory under which Tensorstore files will be saved. Defaults to 'np_random_key'. """ super().__init__(key_name or 'np_random_key') def checkpoint_save_args( self, args: NumpyRandomKeySaveArgs ) -> Tuple[CheckpointArgs, JsonSaveArgs]: item = args.item if isinstance(item, tuple): metadata = {'legacy': True} elif isinstance(item, dict): metadata = {'legacy': False} else: raise TypeError(f'Unsupported type: {type(item)}.') return (PyTreeSaveArgs(item), JsonSaveArgs(metadata)) def checkpoint_restore_args( self, args: NumpyRandomKeyRestoreArgs ) -> CheckpointArgs: return PyTreeRestoreArgs() def post_restore(self, item: Any, metadata: Mapping[Any, Any]) -> Any: if metadata['legacy']: return tuple(item) else: return item
[docs] @register_with_handler(NumpyRandomKeyCheckpointHandler, for_save=True) @dataclasses.dataclass class NumpyRandomKeySaveArgs(CheckpointArgs): """Parameters for saving a Numpy random key. Attributes: item (required): a Numpy random key in legacy or nonlegacy format """ item: NumpyRandomKeyType
[docs] @register_with_handler(NumpyRandomKeyCheckpointHandler, for_restore=True) @dataclasses.dataclass class NumpyRandomKeyRestoreArgs(CheckpointArgs): """Numpy random key restore args.""" pass