# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ArrayCheckpointHandler for saving and restoring individual arrays/scalars."""
import asyncio
import dataclasses
from typing import List, Optional, Union
from etils import epath
import jax
import numpy as np
from orbax.checkpoint import aggregate_handlers
from orbax.checkpoint import async_checkpoint_handler
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
CheckpointArgs = checkpoint_args.CheckpointArgs
register_with_handler = checkpoint_args.register_with_handler
ArrayType = Union[int, float, np.number, np.ndarray, jax.Array]
_ELEMENT_KEY = 'ELEMENT'
[docs]class ArrayCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
"""Handles saving and restoring individual arrays and scalars."""
[docs] def __init__(self, checkpoint_name: Optional[str] = None):
"""Initializes the handler.
Args:
checkpoint_name: Provides a name for the directory under which Tensorstore
files will be saved. Defaults to 'checkpoint'.
"""
if not checkpoint_name:
checkpoint_name = 'checkpoint'
self._checkpoint_name = checkpoint_name
self._aggregate_handler = aggregate_handlers.MsgpackHandler()
def _is_supported_type(self, item: ArrayType) -> bool:
return isinstance(item, (np.ndarray, jax.Array)) or utils.is_scalar(item)
[docs] async def async_save(
self,
directory: epath.Path,
item: Optional[ArrayType] = None,
save_args: Optional[type_handlers.SaveArgs] = None,
args: Optional['ArraySaveArgs'] = None,
) -> Optional[List[future.Future]]:
"""Saves an object asynchronously.
Args:
directory: Folder in which to save.
item: Deprecated, use `args`.
save_args: Deprecated, use `args`.
args: An ocp.array_checkpoint_handler.ArraySaveArgs (see below).
Returns:
A list of commit futures which can be run to complete the save.
"""
if args is not None:
item = args.item
save_args = args.save_args
if not self._is_supported_type(item):
raise TypeError(f'Unsupported type: {type(item)}.')
if not save_args:
save_args = type_handlers.SaveArgs()
if save_args.aggregate:
return [
await self._aggregate_handler.serialize(
directory / self._checkpoint_name, {_ELEMENT_KEY: item}
)
]
info = type_handlers.ParamInfo(
name=self._checkpoint_name,
path=directory / self._checkpoint_name,
parent_dir=directory,
is_ocdbt_checkpoint=False,
)
type_handler = type_handlers.get_type_handler(type(item))
futures = await type_handler.serialize([item], [info], args=[save_args])
return list(futures)
[docs] def save(self, directory: epath.Path, *args, **kwargs):
"""Saves an array synchronously."""
async def async_save():
commit_futures = await self.async_save(directory, *args, **kwargs) # pytype: disable=bad-return-type
# Futures are already running, so sequential waiting is equivalent to
# concurrent waiting.
for f in commit_futures:
f.result() # Block on result.
asyncio.run(async_save())
[docs] def restore(
self,
directory: epath.Path,
item: Optional[ArrayType] = None,
restore_args: Optional[type_handlers.RestoreArgs] = None,
args: Optional['ArrayRestoreArgs'] = None,
) -> ArrayType:
"""Restores an object.
Args:
directory: folder from which to read.
item: Deprecated, use `args`.
restore_args: Deprecated, use `args`.
args: An ocp.array_checkpoint_handler.ArrayRestoreArgs object (see below).
Returns:
The restored object.
"""
if args is None:
args = ArrayRestoreArgs(item=item, restore_args=restore_args)
restore_args = args.restore_args or type_handlers.RestoreArgs()
checkpoint_path = directory / self._checkpoint_name
if checkpoint_path.exists() and checkpoint_path.is_file():
result = self._aggregate_handler.deserialize(checkpoint_path)
result = result[_ELEMENT_KEY]
if not self._is_supported_type(result):
raise TypeError(f'Unsupported type: {type(result)}.')
if isinstance(restore_args, type_handlers.ArrayRestoreArgs):
result = result.reshape(restore_args.global_shape)
sharding = restore_args.sharding or jax.sharding.NamedSharding(
restore_args.mesh, restore_args.mesh_axes
)
result = jax.make_array_from_callback(
result.shape, sharding, lambda idx: result[idx]
)
else:
info = type_handlers.ParamInfo(
name=self._checkpoint_name,
path=checkpoint_path,
parent_dir=directory,
skip_deserialize=False,
is_ocdbt_checkpoint=type_handlers.is_ocdbt_checkpoint(
directory
),
)
restore_type = restore_args.restore_type
if restore_type is None:
restore_type = type_handlers.default_restore_type(restore_args)
type_handler = type_handlers.get_type_handler(restore_type)
result = asyncio.run(
type_handler.deserialize([info], args=[restore_args])
)[0]
return result
[docs] def finalize(self, directory: epath.Path):
type_handlers.merge_ocdbt_per_process_files(directory)
[docs] def close(self):
"""See superclass documentation."""
self._aggregate_handler.close()
@register_with_handler(ArrayCheckpointHandler, for_save=True)
@dataclasses.dataclass
class ArraySaveArgs(CheckpointArgs):
"""Parameters for saving an array or scalar.
Attributes:
item (required): an array or scalar object.
save_args: a `ocp.SaveArgs` object specifying save options.
"""
item: ArrayType
save_args: Optional[type_handlers.SaveArgs] = None
@register_with_handler(ArrayCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class ArrayRestoreArgs(CheckpointArgs):
"""Array restore args.
Attributes:
item: unused, but provided as an option for legacy-compatibility reasons.
restore_args: a `ocp.RestoreArgs` object specifying restore options.
"""
item: Optional[ArrayType] = None
restore_args: Optional[type_handlers.RestoreArgs] = None