Source code for orbax.checkpoint.random_key_checkpoint_handler

# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RandomKeyCheckpointHandlers for saving and restoring individual Jax and Numpy random keys."""

import abc
import asyncio
import dataclasses
from typing import Any, List, Mapping, Optional, Tuple, Union

from etils import epath
import jax
from orbax.checkpoint import array_checkpoint_handler
from orbax.checkpoint import async_checkpoint_handler
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import composite_checkpoint_handler
from orbax.checkpoint import future
from orbax.checkpoint import json_checkpoint_handler
from orbax.checkpoint import pytree_checkpoint_handler
from orbax.checkpoint import type_handlers

# A Numpy random key is accepted either as a tuple or as a dict.
NumpyRandomKeyType = Union[tuple, dict]

# Short module-level aliases for the handler and args types used by the
# random key handlers in this module.
ArrayRestoreArgs = array_checkpoint_handler.ArrayRestoreArgs
ArraySaveArgs = array_checkpoint_handler.ArraySaveArgs
CheckpointArgs = checkpoint_args.CheckpointArgs
CompositeArgs = composite_checkpoint_handler.CompositeArgs
CompositeCheckpointHandler = (
    composite_checkpoint_handler.CompositeCheckpointHandler
)
JsonRestoreArgs = json_checkpoint_handler.JsonRestoreArgs
JsonSaveArgs = json_checkpoint_handler.JsonSaveArgs
PyTreeRestoreArgs = pytree_checkpoint_handler.PyTreeRestoreArgs
PyTreeSaveArgs = pytree_checkpoint_handler.PyTreeSaveArgs
# Decorator that associates a CheckpointArgs subclass with its handler.
register_with_handler = checkpoint_args.register_with_handler


class BaseRandomKeyCheckpointHandler(
    async_checkpoint_handler.AsyncCheckpointHandler
):
  """Base handler for saving and restoring an individual random key.

  The key and a small JSON metadata record are saved as two items of an
  internal `CompositeCheckpointHandler`, named `key_name` and
  `f'{key_name}_metadata'`. Subclasses define how the key is turned into
  save/restore args and how the restored value is post-processed.
  """

  def __init__(self, key_name: str):
    """Initializes the handler.

    Args:
      key_name: Provides a name for the directory under which Tensorstore files
        will be saved. Required here; the metadata item is named
        `f'{key_name}_metadata'`.
    """
    self._key_name = key_name
    self._key_metadata = f'{self._key_name}_metadata'
    self._handler = CompositeCheckpointHandler(
        self._key_name, self._key_metadata
    )

  @abc.abstractmethod
  def checkpoint_save_args(
      self, args: CheckpointArgs
  ) -> Tuple[CheckpointArgs, JsonSaveArgs]:
    """Returns the `args` for saving the item and its metadata.

    Args:
      args: An ocp.checkpoint_args.CheckpointArgs.

    Returns:
      A tuple of (CheckpointArgs, JsonSaveArgs). The CheckpointArgs saves the
      actual random key and the JsonSaveArgs saves any extra metadata.
    """
    pass

  @abc.abstractmethod
  def checkpoint_restore_args(
      self,
      args: CheckpointArgs,
  ) -> CheckpointArgs:
    """Returns the `args` for restoring the item.

    Args:
      args: An ocp.checkpoint_args.CheckpointArgs.

    Returns:
      A CheckpointArgs used to restore the item.
    """
    pass

  @abc.abstractmethod
  def post_restore(self, item: Any, metadata: Mapping[Any, Any]) -> Any:
    """Applies final operations to the restored item before returning it.

    Args:
      item: The value restored for the key item.
      metadata: The restored JSON metadata saved alongside the key.

    Returns:
      The object that `restore` hands back to the caller.
    """
    pass

  async def async_save(
      self,
      directory: epath.Path,
      args: CheckpointArgs,
  ) -> Optional[List[future.Future]]:
    """Saves a random key asynchronously.

    Args:
      directory: Folder in which to save.
      args: An ocp.checkpoint_args.CheckpointArgs.

    Returns:
      A list of commit futures which can be run to complete the save, or
      whatever the underlying composite handler returns (possibly None).
    """
    item_arg, metadata_arg = self.checkpoint_save_args(args)
    return await self._handler.async_save(
        directory,
        CompositeArgs(**{
            self._key_name: item_arg,
            self._key_metadata: metadata_arg,
        }),
    )

  def save(self, directory: epath.Path, *args, **kwargs):
    """Saves a random key synchronously.

    Runs `async_save` to completion and then blocks on every commit future it
    returned.

    Args:
      directory: Folder in which to save.
      *args: Forwarded to `async_save`.
      **kwargs: Forwarded to `async_save`.
    """

    async def _save_and_wait():
      commit_futures = await self.async_save(directory, *args, **kwargs)  # pytype: disable=bad-return-type
      # `async_save` is annotated as Optional, so it may return None; only
      # iterate when there is something to wait on.
      if commit_futures:
        # Futures are already running, so sequential waiting is equivalent to
        # concurrent waiting.
        for f in commit_futures:
          f.result()  # Block on result.

    asyncio.run(_save_and_wait())

  def restore(
      self,
      directory: epath.Path,
      args: CheckpointArgs,
  ) -> Any:
    """Restores a random key.

    Args:
      directory: folder from which to read.
      args: An ocp.checkpoint_args.CheckpointArgs.

    Returns:
      The restored object, after `post_restore` has been applied.
    """
    item_arg = self.checkpoint_restore_args(args)
    result = self._handler.restore(
        directory,
        args=CompositeArgs(**{
            self._key_name: item_arg,
            self._key_metadata: JsonRestoreArgs(),
        }),
    )

    return self.post_restore(
        result[self._key_name], result[self._key_metadata]
    )

  def finalize(self, directory: epath.Path):
    """Delegates finalization to the underlying composite handler."""
    self._handler.finalize(directory)

  def close(self):
    """Closes the underlying composite handler."""
    self._handler.close()


class JaxRandomKeyCheckpointHandler(BaseRandomKeyCheckpointHandler):
  """Handles saving and restoring individual Jax random key in both typed and untyped format."""

  def __init__(self, key_name: Optional[str] = None):
    """Initializes the handler.

    Args:
      key_name: Provides a name for the directory under which Tensorstore files
        will be saved. Defaults to 'jax_random_key'.
    """
    super().__init__(key_name or 'jax_random_key')

  def checkpoint_save_args(
      self, args: 'JaxRandomKeySaveArgs'
  ) -> Tuple[CheckpointArgs, JsonSaveArgs]:
    """Builds the array and metadata save args for a Jax random key.

    Args:
      args: Save args carrying the key (`item`) and optional `save_args`.

    Returns:
      A tuple of (ArraySaveArgs holding the key data, JsonSaveArgs recording
      whether the key was typed).

    Raises:
      TypeError: If `args.item` is not a `jax.Array`.
    """
    item = args.item
    if not isinstance(item, jax.Array):
      raise TypeError(f'Unsupported type: {type(item)}.')

    is_typed = jax.dtypes.issubdtype(item.dtype, jax.dtypes.prng_key)
    # Extract the key data exactly once. Unwrapping a typed key and then
    # passing the resulting raw bits through `key_data` a second time would
    # make JAX treat them as a legacy raw key, which is redundant for typed
    # keys and not what the caller supplied.
    key_data = jax.random.key_data(item)
    # NOTE(review): the metadata records only typed/untyped, and `post_restore`
    # calls `wrap_key_data` without an `impl`; presumably keys created with a
    # non-default PRNG implementation do not round-trip - confirm.
    return (
        ArraySaveArgs(key_data, args.save_args),
        JsonSaveArgs({'typed': is_typed}),
    )

  def checkpoint_restore_args(
      self, args: 'JaxRandomKeyRestoreArgs'
  ) -> CheckpointArgs:
    """Returns ArrayRestoreArgs forwarding the caller's `restore_args`."""
    return ArrayRestoreArgs(restore_args=args.restore_args)

  def post_restore(self, item: Any, metadata: Mapping[Any, Any]) -> Any:
    """Re-wraps the restored key data when the saved key was typed.

    Args:
      item: The restored key data array.
      metadata: Restored metadata; its 'typed' entry selects the output form.

    Returns:
      `jax.random.wrap_key_data(item)` if the saved key was typed, otherwise
      `item` unchanged.
    """
    if metadata['typed']:
      return jax.random.wrap_key_data(item)
    return item
@register_with_handler(JaxRandomKeyCheckpointHandler, for_save=True)
@dataclasses.dataclass
class JaxRandomKeySaveArgs(CheckpointArgs):
  """Parameters for saving a JAX random key.

  Registered as the save args of `JaxRandomKeyCheckpointHandler`.

  Attributes:
    item (required): a JAX random key.
    save_args: a `ocp.SaveArgs` object specifying save options. Optional;
      defaults to None.
  """

  item: jax.Array
  save_args: Optional[type_handlers.SaveArgs] = None


@register_with_handler(JaxRandomKeyCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class JaxRandomKeyRestoreArgs(CheckpointArgs):
  """Parameters for restoring a JAX random key.

  Registered as the restore args of `JaxRandomKeyCheckpointHandler`.

  Attributes:
    restore_args: a `ocp.RestoreArgs` object specifying restore options for
      the key's underlying array. Optional; defaults to None.
  """

  restore_args: Optional[type_handlers.RestoreArgs] = None
class NumpyRandomKeyCheckpointHandler(BaseRandomKeyCheckpointHandler):
  """Saves and restores a Numpy random key in legacy or non-legacy format.

  A tuple key is recorded as legacy and a dict key as non-legacy; the flag is
  stored in the JSON metadata and consulted again on restore.
  """

  def __init__(self, key_name: Optional[str] = None):
    """Initializes NumpyRandomKeyCheckpointHandler.

    Args:
      key_name: Provides a name for the directory under which Tensorstore files
        will be saved. Defaults to 'np_random_key'.
    """
    name = key_name or 'np_random_key'
    super().__init__(name)

  def checkpoint_save_args(
      self, args: 'NumpyRandomKeySaveArgs'
  ) -> Tuple[CheckpointArgs, JsonSaveArgs]:
    """Builds the PyTree and metadata save args for a Numpy random key.

    Args:
      args: Save args carrying the key as `item`.

    Returns:
      A tuple of (PyTreeSaveArgs holding the key, JsonSaveArgs recording
      whether the key is in legacy format).

    Raises:
      TypeError: If `args.item` is neither a tuple nor a dict.
    """
    key = args.item
    # Reject anything that is not one of the two supported containers up front.
    if not isinstance(key, (tuple, dict)):
      raise TypeError(f'Unsupported type: {type(key)}.')
    is_legacy = isinstance(key, tuple)
    return PyTreeSaveArgs(key), JsonSaveArgs({'legacy': is_legacy})

  def checkpoint_restore_args(
      self, args: 'NumpyRandomKeyRestoreArgs'
  ) -> CheckpointArgs:
    """Returns default PyTreeRestoreArgs; `args` carries no options."""
    return PyTreeRestoreArgs()

  def post_restore(self, item: Any, metadata: Mapping[Any, Any]) -> Any:
    """Converts a restored legacy key back into a tuple.

    Args:
      item: The restored key.
      metadata: Restored metadata; its 'legacy' entry selects the output form.

    Returns:
      `tuple(item)` for a legacy key, otherwise `item` unchanged.
    """
    return tuple(item) if metadata['legacy'] else item
@register_with_handler(NumpyRandomKeyCheckpointHandler, for_save=True)
@dataclasses.dataclass
class NumpyRandomKeySaveArgs(CheckpointArgs):
  """Parameters for saving a Numpy random key.

  Registered as the save args of `NumpyRandomKeyCheckpointHandler`.

  Attributes:
    item (required): a Numpy random key in legacy (tuple) or non-legacy (dict)
      format.
  """

  item: NumpyRandomKeyType


@register_with_handler(NumpyRandomKeyCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class NumpyRandomKeyRestoreArgs(CheckpointArgs):
  """Numpy random key restore args.

  Registered as the restore args of `NumpyRandomKeyCheckpointHandler`. It
  declares no fields.
  """

  pass