# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides handlers for Jax Arrays."""
from __future__ import annotations
import asyncio
import dataclasses
import functools
import os
import time
from typing import Any, Callable, Dict, Sequence, Set, Tuple, TypeAlias, Union, cast
import warnings
from absl import logging
import humanize
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.metadata import sharding as sharding_metadata
from orbax.checkpoint._src.metadata import value as value_metadata
from orbax.checkpoint._src.multihost import dispatchers
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.multihost import multislice
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import utils as path_utils
from orbax.checkpoint._src.serialization import jax_array_restore_args
from orbax.checkpoint._src.serialization import limits
from orbax.checkpoint._src.serialization import ocdbt_utils
from orbax.checkpoint._src.serialization import replica_slices
from orbax.checkpoint._src.serialization import serialization
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.serialization import worker_memory_utils
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts
# Alias used in annotations for arbitrarily nested containers; it places no
# constraint on the value (it is simply `Any`).
Pytree: TypeAlias = Any
# Module-level shortcuts for the restore-argument classes defined in
# `jax_array_restore_args`.
ArrayRestoreArgs = jax_array_restore_args.ArrayRestoreArgs
SingleReplicaArrayRestoreArgs = (
    jax_array_restore_args.SingleReplicaArrayRestoreArgs
)
# NOTE(review): presumably the file name under which sharding metadata is
# stored; no use of it is visible in this part of the file - confirm against
# its users.
_SHARDING_FILE_NAME = '_sharding'
def check_array_values(
    values: Sequence[Union[jax.Array, np.ndarray]],
    infos: Sequence[types.ParamInfo],
    raise_error: bool = True,
):
  """Rejects, or warns about, arrays that contain no elements.

  Args:
    values: Arrays to inspect.
    infos: ParamInfo for each entry of `values`, in the same order. Only used
      to build the error / warning message.
    raise_error: If True, the first zero-size array raises. Otherwise a
      warning is logged for every zero-size array and nothing is raised.

  Raises:
    ValueError: If `raise_error` is True and a zero-size array is found.
  """
  for array, param_info in zip(values, infos):
    if array.size != 0:
      continue
    if not raise_error:
      logging.warning(
          'Saving array with zero size: ParamInfo: [name=%s,'
          ' value_typestr=%s]',
          param_info.name,
          param_info.value_typestr,
      )
      continue
    raise ValueError(
        f'Cannot save arrays with zero size: ParamInfo: [name={param_info.name},'
        f'value_typestr={param_info.value_typestr}]'
    )
# Value of `ParamInfo.value_typestr` that marks a parameter as a jax.Array.
JAX_ARRAY_TYPE_STR = 'jax.Array'


def represents_jax_array(param_info: types.ParamInfo) -> bool:
  """Tells whether `param_info` describes a jax.Array.

  Args:
    param_info: ParamInfo whose `value_typestr` is compared against
      `JAX_ARRAY_TYPE_STR`. The type string must be set.

  Returns:
    True if the type string equals `JAX_ARRAY_TYPE_STR`, False otherwise.
  """
  typestr = param_info.value_typestr
  assert (
      typestr is not None
  ), f'ParamInfo.value_typestr cannot be None: {param_info}'
  return typestr == JAX_ARRAY_TYPE_STR
def any_jax_array_param_info(param_infos: Pytree) -> types.ParamInfo | None:
  """Searches a PyTree of ParamInfo for one that describes a jax.Array.

  Args:
    param_infos: PyTree whose leaves are passed to `represents_jax_array`.

  Returns:
    A leaf for which `represents_jax_array` is True, or None when the
    reduction finds no such leaf.
  """

  def _keep_first_match(found_jax_array, param_info):
    # A truthy accumulator is carried through unchanged, so later leaves are
    # no longer inspected once a match has been found.
    if found_jax_array:
      return found_jax_array
    return param_info if represents_jax_array(param_info) else None

  return jax.tree_util.tree_reduce(
      _keep_first_match,
      tree=param_infos,
      initializer=None,
  )
def _has_prng_key_dtype(arg: Any) -> bool:
  """Checks whether `arg.dtype` is set and is a JAX PRNG key dtype.

  Args:
    arg: Any object exposing a `dtype` attribute (which may be None).

  Returns:
    False when `arg.dtype` is None; otherwise whether the dtype is a sub-dtype
    of `jax.dtypes.prng_key`.
  """
  dtype = arg.dtype
  if dtype is None:
    return False
  return jax.dtypes.issubdtype(dtype, jax.dtypes.prng_key)
def _get_underlying_shape(
    shape: tuple[int, ...] | None, dtype: Any
) -> tuple[int, ...] | None:
  """Computes the shape of the raw key data behind a PRNG key array.

  Args:
    shape: Shape of the key array, or None when unknown.
    dtype: PRNG key dtype of the array.

  Returns:
    None if `shape` is None; otherwise the shape `jax.random.key_data` would
    produce for an array of the given shape and dtype. The shape is obtained
    abstractly through `jax.eval_shape`.
  """
  if shape is None:
    return None
  key_struct = jax.ShapeDtypeStruct(shape=shape, dtype=dtype)
  key_data_struct = jax.eval_shape(jax.random.key_data, key_struct)
  return key_data_struct.shape
@functools.lru_cache(maxsize=4096)
def _is_replicated_sharding(sharding: jax.sharding.Sharding) -> bool:
  """Heuristically decides whether `sharding` yields replicated data.

  This is a quick check that looks only at the sharding and not at the array
  shape, because the shape isn't known at the point of deserialization. For a
  NamedSharding the answer is True when the PartitionSpec is empty, has fewer
  entries than the mesh has axis names, or contains a None entry. A
  SingleDeviceSharding is always reported as replicated. Any other sharding
  type is logged as unsupported and reported as not replicated.

  Results are cached because `save` is typically called repeatedly on the
  same model (with changing array values).

  Args:
    sharding: The sharding to check.

  Returns:
    True if the sharding is considered replicated.
  """
  if isinstance(sharding, jax.sharding.SingleDeviceSharding):
    return True
  if not isinstance(sharding, jax.sharding.NamedSharding):
    logging.warning(
        'Unsupported sharding type, assuming not replicated: %s', sharding
    )
    return False
  spec = sharding.spec
  num_spec_entries = len(spec)
  num_mesh_axes = len(sharding.mesh.axis_names)
  # Any mesh axis left out of the spec, or any None entry, counts as a
  # replica.
  return (
      num_mesh_axes > num_spec_entries
      or not spec
      or any(entry is None for entry in spec)
  )
async def _async_serialize_shardings(
    shardings: Sequence[jax.sharding.Sharding | None],
    infos: Sequence[types.ParamInfo],
    *,
    primary_host: int | None,
):
  """Writes the serialized form of each sharding in a single transaction.

  Entries whose sharding is None are skipped. For every other entry the
  parameter path is awaited on all hosts, but only a host accepted by
  `multihost.is_primary_host(primary_host)` opens the sharding store and
  writes to it. The transaction is committed once, after the loop.

  Args:
    shardings: Sharding (or None) for each parameter.
    infos: ParamInfo for each parameter, in the same order as `shardings`.
    primary_host: Host id passed to `multihost.is_primary_host`.

  Raises:
    ValueError: If a parameter with a sharding has no `parent_dir`.
  """
  txn = ts.Transaction()
  for sharding, info in zip(shardings, infos):
    if sharding is None:
      continue
    await info.await_path_creation()
    parent_dir = info.parent_dir
    if parent_dir is None:
      raise ValueError('parent_dir cannot be None')
    tspec_sharding = ts_utils.get_sharding_tensorstore_spec(
        parent_dir.as_posix(), info.name
    )
    if not multihost.is_primary_host(primary_host):
      continue
    # OCDBT is not used for sharding metadata.
    sharding_store = await ts.open(
        tspec_sharding,
        open=True,
        context=info.ts_context,
    )
    metadata_value = sharding_metadata.from_jax_sharding(sharding)
    serialized = (
        None if metadata_value is None
        else metadata_value.to_serialized_string()
    )
    if serialized is None:
      continue
    await sharding_store.with_transaction(txn).write(  # pytype: disable=attribute-error
        serialized
    )
  await txn.commit_async()
def _get_replica_slices(
    arrays: Sequence[jax.Array],
    replica_id: int,
    use_replica_parallel: bool,
    min_slice_bytes_for_replica_parallel: int | None = None,
    max_replicas_for_replica_parallel: int | None = None,
) -> Sequence[replica_slices.ReplicaSlices]:
  """Builds ReplicaSlices for `arrays` and marks them as host-resident.

  Args:
    arrays: Arrays to compute replica slices for.
    replica_id: Forwarded to `replica_slices.get_replica_slices`.
    use_replica_parallel: Forwarded to `replica_slices.get_replica_slices`.
    min_slice_bytes_for_replica_parallel: Forwarded as well.
    max_replicas_for_replica_parallel: Forwarded as well.

  Returns:
    One ReplicaSlices per array with `is_on_host=True`, in which every slice
    holds its data as a numpy array and has `slice_args=None`.
  """

  def _as_host_resident(rslices):
    host_slices = [
        dataclasses.replace(
            rslice,
            unsliced_data=np.asarray(rslice.data()),
            slice_args=None,
        )
        for rslice in rslices.replica_slices
    ]
    return dataclasses.replace(
        rslices,
        is_on_host=True,
        replica_slices=host_slices,
    )

  # Compute the slices of every array first, then convert them.
  computed = [
      replica_slices.get_replica_slices(
          array,
          replica_id,
          use_replica_parallel,
          min_slice_bytes_for_replica_parallel,
          max_replicas_for_replica_parallel,
      )
      for array in arrays
  ]
  # D2H copy is performed automatically as part of dispatcher call, but
  # the properties must be set correctly to pass later consistency checks.
  return [_as_host_resident(rslices) for rslices in computed]
def _worker_serialize_arrays(
    arrays: Sequence[jax.Array],
    infos: Sequence[types.ParamInfo],
    args: Sequence[types.SaveArgs],
    replica_id: int,
    use_replica_parallel: bool,
    min_slice_bytes_for_replica_parallel: int | None,
    max_replicas_for_replica_parallel: int | None,
    primary_host: int | None,
    metadata_key: str | None,
    array_metadata_store: array_metadata_store_lib.Store | None,
    enable_replica_parallel_separate_folder: bool,
    ext_metadata: Dict[str, Any],
):
  """Worker entry point: slices `arrays` and writes them synchronously.

  The replica slices are obtained from `_get_replica_slices`, then the
  `_async_serialize_replica_slices` coroutine is handed to
  `asyncio_utils.run_sync`. All remaining arguments are forwarded unchanged
  to those two helpers.
  """
  host_rslices = _get_replica_slices(
      arrays,
      replica_id,
      use_replica_parallel,
      min_slice_bytes_for_replica_parallel,
      max_replicas_for_replica_parallel,
  )
  write_coro = _async_serialize_replica_slices(
      host_rslices,
      infos,
      args,
      primary_host=primary_host,
      metadata_key=metadata_key,
      array_metadata_store=array_metadata_store,
      enable_replica_parallel_separate_folder=enable_replica_parallel_separate_folder,
      use_replica_parallel=use_replica_parallel,
      ext_metadata=ext_metadata,
  )
  asyncio_utils.run_sync(write_coro)
def _get_deprioritized_batches_to_serialize(
    deprioritized_params: Sequence[
        tuple[jax.Array, types.ParamInfo, types.SaveArgs]
    ],
    *,
    device_host_max_bytes: int,
    replica_id: int | None,
    dispatcher: dispatchers.Dispatcher | None,
):
  """Generates memory-budgeted batches of deprioritized parameters.

  Batches come from `worker_memory_utils.next_memory_budgeted_batch`. Each
  one is yielded as `zip(*batch)`, i.e. an iterator over the transposed
  batch (arrays, infos, args). The generator logs progress before and after
  each batch and asserts that, in the end, every parameter was covered.

  Args:
    deprioritized_params: (array, info, args) triples to batch.
    device_host_max_bytes: Memory budget passed to the batching helper.
    replica_id: Forwarded to the batching helper.
    dispatcher: Forwarded to the batching helper.

  Yields:
    `zip(*batch)` for each batch produced by the batching helper.
  """
  logging.info(
      'Option `device_host_max_bytes` was set to %s. Using memory-limited'
      ' saving. Note that this feature may impact saving speed.',
      humanize.naturalsize(device_host_max_bytes, binary=True),
  )
  if not deprioritized_params:
    return
  total_count = len(deprioritized_params)
  completed_count = 0
  batches = worker_memory_utils.next_memory_budgeted_batch(
      deprioritized_params,
      device_host_max_bytes,
      replica_id=replica_id,
      dispatcher=dispatcher,
  )
  for batch in batches:
    assert completed_count < total_count
    batch_size = len(batch)
    logging.info(
        'Scheduling serialization of %d deprioritized arrays. Already'
        ' completed %d / %d arrays. Included keys: %s',
        batch_size,
        completed_count,
        total_count,
        [tree_utils.str_keypath(info.keypath) for _, info, _ in batch],
    )
    yield zip(*batch)
    # Control returns here only once the consumer asks for the next batch.
    logging.info(
        'Serialization of %d deprioritized jax.Array completed.',
        batch_size,
    )
    completed_count += batch_size
  assert completed_count == total_count
def _on_batch_callback(
    infos: Sequence[types.ParamInfo],
    callback_fn: Callable[..., None],
) -> None:
  """Invokes `callback_fn` once per info, passing that info's keypath.

  Args:
    infos: ParamInfo objects of the batch, visited in order.
    callback_fn: Callable that receives `info.keypath`.
  """
  for param_info in infos:
    keypath = param_info.keypath
    callback_fn(keypath)
def _serialize_arrays_batches_without_dispatcher(
    prioritized: Sequence[tuple[jax.Array, types.ParamInfo, types.SaveArgs]],
    deprioritized: Sequence[tuple[jax.Array, types.ParamInfo, types.SaveArgs]],
    device_host_max_bytes: int | None,
    replica_id: int | None,
    use_replica_parallel: bool,
    min_slice_bytes_for_replica_parallel: int | None,
    max_replicas_for_replica_parallel: int | None,
    primary_host: int | None,
    metadata_key: str | None,
    array_metadata_store: array_metadata_store_lib.Store | None,
    enable_replica_parallel_separate_folder: bool,
    ext_metadata: Dict[str, Any],
    enable_pinned_host_transfer: bool,
    callback: types.SerializationStatusCallback,
) -> future.Future:
  """Serializes arrays batches without dispatcher.

  The host transfer of the `prioritized` arrays is requested synchronously,
  before this function returns. Writing them, and transferring and writing
  the `deprioritized` arrays batch by batch, happens inside the coroutine
  wrapped by the returned future.

  Args:
    prioritized: (array, info, args) triples transferred up front.
    deprioritized: (array, info, args) triples handled in memory-budgeted
      batches inside the returned future.
    device_host_max_bytes: Memory budget for the deprioritized batches. Must
      not be None when `deprioritized` is non-empty.
    replica_id: Forwarded to `replica_slices.transfer_arrays_to_host`.
    use_replica_parallel: Forwarded to the transfer and to the writer.
    min_slice_bytes_for_replica_parallel: Forwarded to the transfer.
    max_replicas_for_replica_parallel: Forwarded to the transfer.
    primary_host: Forwarded to `_async_serialize_replica_slices`.
    metadata_key: Forwarded to `_async_serialize_replica_slices`.
    array_metadata_store: Forwarded to `_async_serialize_replica_slices`.
    enable_replica_parallel_separate_folder: Forwarded to
      `_async_serialize_replica_slices`.
    ext_metadata: Forwarded to `_async_serialize_replica_slices`.
    enable_pinned_host_transfer: Forwarded to the transfer.
    callback: Its `on_transfer_end` / `on_write_end` hooks are called with
      each keypath after the corresponding transfer / write call.

  Returns:
    A `CommitFutureAwaitingContractedSignals` wrapping the write coroutine.
  """
  # Complete D2H transfer in parallel for each array for prioritized values.
  # Both partials bind the options that are identical for every batch, so the
  # call sites below only pass the per-batch values.
  replica_slices_transfer_arrays_to_host = functools.partial(
      replica_slices.transfer_arrays_to_host,
      replica_id=replica_id,
      use_replica_parallel=use_replica_parallel,
      enable_pinned_host_transfer=enable_pinned_host_transfer,
      min_slice_bytes_for_replica_parallel=min_slice_bytes_for_replica_parallel,
      max_replicas_for_replica_parallel=max_replicas_for_replica_parallel,
  )
  async_serialize_replica_slices_batch = functools.partial(
      _async_serialize_replica_slices,
      primary_host=primary_host,
      metadata_key=metadata_key,
      array_metadata_store=array_metadata_store,
      enable_replica_parallel_separate_folder=enable_replica_parallel_separate_folder,
      use_replica_parallel=use_replica_parallel,
      ext_metadata=ext_metadata,
  )
  # These stay empty when there are no prioritized values; the coroutine
  # below reads them through its closure.
  prioritized_values_on_host = []
  prioritized_infos = []
  prioritized_args = []
  if prioritized:
    logging.info(
        'Scheduling D2H of %d prioritized jax.Array.',
        len(prioritized),
    )
    prioritized_arrays, prioritized_infos, prioritized_args = zip(*prioritized)
    prioritized_values_on_host = replica_slices_transfer_arrays_to_host(
        prioritized_arrays
    )
    _on_batch_callback(prioritized_infos, callback.on_transfer_end)
  else:
    logging.warning(
        'No prioritized params found for saving. D2H for all values will be'
        ' scheduled asynchronously.'
    )

  async def _serialize_without_dispatcher():
    # Write the values that were already transferred above.
    if prioritized_values_on_host:
      await async_serialize_replica_slices_batch(
          prioritized_values_on_host,
          prioritized_infos,
          prioritized_args,
      )
      _on_batch_callback(prioritized_infos, callback.on_write_end)
    if deprioritized:
      assert device_host_max_bytes is not None
      # Each batch is transferred and fully written before the next batch is
      # requested from the generator.
      for (
          b_arrays,
          b_infos,
          b_args,
      ) in _get_deprioritized_batches_to_serialize(
          deprioritized,
          device_host_max_bytes=device_host_max_bytes,
          # TODO(b/436858989): We overestimate memory usage for now if replica
          # parallel is enabled, as each host has a non-trivial calculation for
          # bytes transferred to host.
          replica_id=None if use_replica_parallel else replica_id,
          dispatcher=None,
      ):
        b_arrays_on_host = replica_slices_transfer_arrays_to_host(b_arrays)
        _on_batch_callback(b_infos, callback.on_transfer_end)
        await async_serialize_replica_slices_batch(
            b_arrays_on_host,
            b_infos,
            b_args,
        )
        _on_batch_callback(b_infos, callback.on_write_end)

  return future.CommitFutureAwaitingContractedSignals(
      _serialize_without_dispatcher(),
      name='array_type_handler',
  )
def _serialize_arrays(
    arrays: Sequence[jax.Array],
    infos: Sequence[types.ParamInfo],
    args: Sequence[types.SaveArgs],
    dispatcher: dispatchers.Dispatcher | None,
    replica_id: int | None,
    use_replica_parallel: bool,
    min_slice_bytes_for_replica_parallel: int | None,
    max_replicas_for_replica_parallel: int | None,
    primary_host: int | None,
    metadata_key: str | None,
    array_metadata_store: array_metadata_store_lib.Store | None,
    enable_replica_parallel_separate_folder: bool,
    ext_metadata: Dict[str, Any],
    callback: types.SerializationStatusCallback,
) -> future.Future:
  """D2H transfer and serialize arrays using dispatcher if provided.

  The arrays are first split into a prioritized and a deprioritized group.
  If `infos[0].device_host_byte_limiter` is not a `LimitInFlightBytes`, every
  array is prioritized. Otherwise `callback.key_priority(info.keypath)`
  decides: SYNCHRONOUS arrays are prioritized, while ASYNCHRONOUS_PRIORITIZED
  followed by ASYNCHRONOUS_DEPRIORITIZED arrays form the deprioritized group.

  Without a dispatcher the work is delegated to
  `_serialize_arrays_batches_without_dispatcher`. With a dispatcher, each
  batch is sent to `_worker_serialize_arrays` via `dispatcher.dispatch`.

  Args:
    arrays: Arrays to save.
    infos: ParamInfo for each array. Must be non-empty, since `infos[0]` is
      read for settings shared by the whole call.
    args: SaveArgs for each array.
    dispatcher: Optional dispatcher used to run the serialization.
    replica_id: Forwarded to the serialization helpers.
    use_replica_parallel: Forwarded to the serialization helpers.
    min_slice_bytes_for_replica_parallel: Forwarded as well.
    max_replicas_for_replica_parallel: Forwarded as well.
    primary_host: Forwarded as well.
    metadata_key: Forwarded as well.
    array_metadata_store: Forwarded as well.
    enable_replica_parallel_separate_folder: Forwarded as well.
    ext_metadata: Forwarded as well.
    callback: Provides `key_priority` and the `on_transfer_end` /
      `on_write_end` hooks.

  Returns:
    A future wrapping the remaining serialization work.

  Raises:
    ValueError: If the priority of a key is `TransferPriority.UNKNOWN`.
  """
  # A memory budget only exists when the first info carries a
  # LimitInFlightBytes limiter; it also switches on per-key prioritization.
  device_host_max_bytes = None
  if byte_limiter := infos[0].device_host_byte_limiter:
    if isinstance(byte_limiter, limits.LimitInFlightBytes):
      device_host_max_bytes = byte_limiter.max_bytes
  prioritized: list[tuple[jax.Array, types.ParamInfo, types.SaveArgs]] = []
  prioritized_async: list[tuple[jax.Array, types.ParamInfo, types.SaveArgs]] = (
      []
  )
  deprioritized: list[tuple[jax.Array, types.ParamInfo, types.SaveArgs]] = []
  if device_host_max_bytes is None:
    for info, arg, value in zip(infos, args, arrays):
      prioritized.append((value, info, arg))
  else:
    for info, arg, value in zip(infos, args, arrays):
      prioritization = callback.key_priority(info.keypath)
      # NOTE(review): a priority outside the four values handled below would
      # leave the array in no group - confirm the enum has no other members.
      if prioritization == types.TransferPriority.SYNCHRONOUS:
        prioritized.append((value, info, arg))
      elif prioritization == types.TransferPriority.ASYNCHRONOUS_PRIORITIZED:
        prioritized_async.append((value, info, arg))
      elif prioritization == types.TransferPriority.ASYNCHRONOUS_DEPRIORITIZED:
        deprioritized.append((value, info, arg))
      elif prioritization == types.TransferPriority.UNKNOWN:
        raise ValueError(
            f'Prioritization is unknown for key {info.keypath}.'
        )
  # Asynchronous-prioritized arrays are batched ahead of deprioritized ones.
  deprioritized = prioritized_async + deprioritized
  if dispatcher is None:
    return _serialize_arrays_batches_without_dispatcher(
        prioritized,
        deprioritized,
        device_host_max_bytes,
        replica_id,
        use_replica_parallel,
        min_slice_bytes_for_replica_parallel,
        max_replicas_for_replica_parallel,
        primary_host,
        metadata_key,
        array_metadata_store,
        enable_replica_parallel_separate_folder,
        ext_metadata,
        infos[0].enable_pinned_host_transfer,
        callback,
    )
  else:

    def _serialize_batch(
        batch_infos: Sequence[types.ParamInfo],
        batch_args: Sequence[types.SaveArgs],
        batch_arrays: Sequence[jax.Array],
    ):
      # Dispatches one batch to the worker function, then blocks until the
      # dispatched result is ready before reporting the write as finished.
      ret = dispatcher.dispatch(
          _worker_serialize_arrays,
          input_arrays=batch_arrays,
          func_kwargs={
              'infos': batch_infos,
              'args': batch_args,
              'replica_id': replica_id,
              'use_replica_parallel': use_replica_parallel,
              'min_slice_bytes_for_replica_parallel': (
                  min_slice_bytes_for_replica_parallel
              ),
              'max_replicas_for_replica_parallel': (
                  max_replicas_for_replica_parallel
              ),
              'primary_host': primary_host,
              'metadata_key': metadata_key,
              'array_metadata_store': array_metadata_store,
              'enable_replica_parallel_separate_folder': (
                  enable_replica_parallel_separate_folder
              ),
              'ext_metadata': ext_metadata,
          },
      )
      _on_batch_callback(batch_infos, callback.on_transfer_end)
      jax.block_until_ready(ret)
      _on_batch_callback(batch_infos, callback.on_write_end)

    # Enqueue D2H operation for prioritized values.
    if prioritized:
      logging.info(
          'Scheduling D2H of %d prioritized jax.Array.',
          len(prioritized),
      )
      prioritized_arrays, prioritized_infos, prioritized_args = zip(
          *prioritized
      )
      prioritized_arrays = dispatcher.device_to_host(prioritized_arrays)
      # Rebuild the triples around the arrays returned by the dispatcher.
      prioritized = [
          (v, i, a)
          for v, i, a in zip(
              prioritized_arrays, prioritized_infos, prioritized_args
          )
      ]
    else:
      logging.warning(
          'No prioritized params found for saving. D2H for all values will be'
          ' scheduled asynchronously.'
      )
    # `infos` is rebound inside `_serialize`, so keep a separate reference to
    # the full sequence for the coroutine to read.
    all_infos = infos

    async def _serialize():
      # Every parameter path must exist before any batch is dispatched.
      for info in all_infos:
        await info.await_path_creation()
      if prioritized:
        arrays, infos, args = zip(*prioritized)
        _serialize_batch(infos, args, arrays)
      if deprioritized:
        assert device_host_max_bytes is not None
        for (
            b_arrays,
            b_infos,
            b_args,
        ) in _get_deprioritized_batches_to_serialize(
            deprioritized,
            device_host_max_bytes=device_host_max_bytes,
            replica_id=replica_id,
            dispatcher=dispatcher,
        ):
          _serialize_batch(b_infos, b_args, b_arrays)

    return future.CommitFutureAwaitingContractedSignals(
        _serialize(),
        name='array_type_handler',
    )
async def _async_serialize_replica_slices(
    values: Sequence[replica_slices.ReplicaSlices],
    infos: Sequence[types.ParamInfo],
    args: Sequence[types.SaveArgs],
    *,
    primary_host: int | None,
    metadata_key: str | None,
    array_metadata_store: array_metadata_store_lib.Store | None,
    enable_replica_parallel_separate_folder: bool,
    use_replica_parallel: bool,
    ext_metadata: Dict[str, Any],
) -> None:
  """This function contains the logic from ArrayHandler._background_serialize.

  For each value a write spec is built and a write coroutine is created. All
  writes, plus the optional array-metadata write, are awaited together. When
  an OCDBT transaction is in use it is committed afterwards.

  Args:
    values: Host-side replica slices to write, one per parameter.
    infos: ParamInfo for each value. Must be non-empty, since `infos[0]` is
      read for settings shared by the whole call.
    args: SaveArgs for each value.
    primary_host: Forwarded to `serialization.async_serialize_from_host`.
    metadata_key: Forwarded to `ts_utils.build_array_write_spec`.
    array_metadata_store: If not None, the metadata of all write specs is
      written through it.
    enable_replica_parallel_separate_folder: Together with
      `use_replica_parallel`, enables the separate-folder decision below.
    use_replica_parallel: See `enable_replica_parallel_separate_folder`.
    ext_metadata: Mapping looked up by `info.name`; the result is passed to
      the write spec.
  """
  write_coros = []
  array_metadatas = []
  # A single atomic transaction is used only for OCDBT checkpoints whose byte
  # limiter is absent or unlimited.
  use_transaction = (
      infos[0].is_ocdbt_checkpoint
      and (
          infos[0].byte_limiter is None
          or isinstance(infos[0].byte_limiter, limits.UnlimitedInFlightBytes)
      )
  )
  ocdbt_transaction = ts.Transaction(atomic=True) if use_transaction else None
  for value, info, arg in zip(values, infos, args):
    # The separate folder applies only with OCDBT, and then only to values
    # whose sharding is considered replicated.
    replica_separate_folder = False
    if use_replica_parallel and enable_replica_parallel_separate_folder:
      if info.is_ocdbt_checkpoint:
        replica_separate_folder = _is_replicated_sharding(value.sharding)
      else:
        logging.log_first_n(
            logging.WARNING,
            'Replica_separate_folder is disabled as OCDBT is not enabled.',
            1,
        )
    await info.await_path_creation()
    array_write_spec = ts_utils.build_array_write_spec(
        info=info,
        arg=arg,
        global_shape=value.global_shape,
        local_shape=value.local_shape,
        dtype=value.dtype,
        use_ocdbt=info.is_ocdbt_checkpoint,
        process_index=ocdbt_utils.get_process_index_for_subdir(
            info.is_ocdbt_checkpoint
        ),
        replica_separate_folder=replica_separate_folder,
        metadata_key=metadata_key,
        ext_metadata=ext_metadata.get(info.name),
    )
    tspec = array_write_spec.json
    ts_context = info.ts_context
    if logging.vlog_is_on(1):
      logging.vlog(1, 'info: %s', info)
      logging.vlog(1, 'arg: %s', arg)
      logging.vlog(
          1,
          'value.global_shape: %s, value.sharding: %s',
          value.global_shape,
          value.sharding,
      )
      logging.vlog(1, 'tspec: %s', tspec)
    # The coroutine is only collected here; it is awaited by the gather
    # below, after the loop.
    write_coros.append(
        serialization.async_serialize_from_host(
            value,
            tspec,
            primary_host=primary_host,
            context=ts_context,
            transaction=ocdbt_transaction,
            byte_limiter=info.byte_limiter,
        )
    )
    array_metadatas.append(array_write_spec.metadata)
  if array_metadata_store is not None:
    # The metadata write is awaited together with the array writes.
    write_coros.append(
        array_metadata_store.write(
            checkpoint_dir=infos[0].parent_dir,
            array_metadatas=array_metadatas,
            process_index=multihost.process_index(),
        )
    )
  await asyncio.gather(*write_coros)
  if ocdbt_transaction is not None:
    await ocdbt_transaction.commit_async()
def _wrap_random_key_data(
    array_metadatas: Any,
    infos: Sequence[types.ParamInfo],
    deserialized_arrays: list[jax.Array],
) -> Sequence[jax.Array]:
  """Parse array_metadatas and wrap deserialized_arrays as random keys.

  Only the first value of `array_metadatas` is consulted. For every parameter
  whose metadata carries a `RANDOM_KEY_IMPL` entry in `ext_metadata`, the
  matching array is replaced by `jax.random.wrap_key_data(array, impl=...)`.

  Args:
    array_metadatas: Dict whose values are iterables of array metadata with
      `param_name` and `ext_metadata` attributes.
    infos: ParamInfo for each array; `info.name` selects the metadata.
    deserialized_arrays: Arrays to update. The list is modified in place.

  Returns:
    `deserialized_arrays`, the same list object that was passed in.

  Raises:
    ValueError: If `array_metadatas` is not a dict.
  """
  if not isinstance(array_metadatas, Dict):
    raise ValueError(
        'Expecting array_metadatas to be a "Dict" but got'
        f' {type(array_metadatas)}.'
    )
  if not array_metadatas:
    # No metadata was recorded, so there is nothing to wrap. Without this
    # guard `next()` below would leak a bare StopIteration to the caller.
    return deserialized_arrays
  logging.vlog(1, 'array_metadatas = %s', array_metadatas)
  # use the first available array_metadata
  array_metadatas_cache = {
      array_metadata.param_name: array_metadata
      for array_metadata in next(iter(array_metadatas.values()))
  }
  for i, (info, v) in enumerate(zip(infos, deserialized_arrays)):
    if meta := array_metadatas_cache.get(info.name):
      assert isinstance(
          meta, array_metadata_lib.SerializedArrayMetadata
      ), f'Expecting SerializedArrayMetadata but got {type(meta)}.'
      # A None `ext_metadata` is not a dict either, so one check covers both.
      if not isinstance(meta.ext_metadata, dict):
        continue
      if impl := meta.ext_metadata.get(array_metadata_lib.RANDOM_KEY_IMPL):  # pytype: disable=attribute-error
        deserialized_arrays[i] = jax.random.wrap_key_data(v, impl=impl)
        logging.vlog(
            1,
            '%s: recreated as a random key: %s',
            info.name,
            deserialized_arrays[i],
        )
  return deserialized_arrays
def _validate_ocdbt_settings(infos: Sequence[types.ParamInfo]) -> bool:
  """Verifies that every parameter carries the same, non-None OCDBT flag.

  Args:
    infos: Non-empty sequence of ParamInfo to compare.

  Returns:
    The `is_ocdbt_checkpoint` value shared by all parameters.

  Raises:
    ValueError: If any parameter's flag differs from the first one, or if the
      shared flag is None.
  """
  assert infos
  use_ocdbt = infos[0].is_ocdbt_checkpoint
  # Find the first parameter that disagrees with the first one, if any.
  info = next(
      (
          candidate
          for candidate in infos
          if candidate.is_ocdbt_checkpoint != use_ocdbt
      ),
      None,
  )
  if info is not None:
    raise ValueError(
        f"OCDBT settings for parameter {info.name} don't match those for the"
        f' rest of parameters: got ({info.is_ocdbt_checkpoint=}, expected'
        f' {use_ocdbt=}.'
    )
  if use_ocdbt is None:
    raise ValueError('Setting of `use_ocdbt` may not be None.')
  return use_ocdbt
async def _validate_non_ocdbt_files(
    infos: Sequence[types.ParamInfo], metadata_key: str
):
  """Runs `ts_utils.assert_parameter_files_exist` for every parameter.

  The checks are created for all parameters first and then awaited
  concurrently.

  Args:
    infos: ParamInfo of the parameters; `parent_dir / name` is the path that
      gets checked and `use_zarr3` is forwarded.
    metadata_key: Forwarded to each check.
  """
  existence_checks = []
  for info in infos:
    param_path = info.parent_dir / info.name
    existence_checks.append(
        ts_utils.assert_parameter_files_exist(  # pylint: disable=protected-access
            param_path, metadata_key, info.use_zarr3
        )
    )
  await asyncio.gather(*existence_checks)
async def _deserialize_shardings(
    infos: Sequence[types.ParamInfo],
    args: Sequence[types.RestoreArgs],
    sharding_file_exists: bool,
) -> Sequence[Any]:
  """Deserializes shardings from file or infers from args.

  For each parameter the sharding is chosen in this order:
    1. `arg.mesh` and `arg.mesh_axes` of an ArrayRestoreArgs, combined into a
       NamedSharding.
    2. `arg.sharding` of an ArrayRestoreArgs, converted with
       `to_jax_sharding()` when it is a ShardingMetadata.
    3. The sharding stored in the checkpoint, if `sharding_file_exists`.

  Args:
    infos: ParamInfo for each parameter.
    args: RestoreArgs for each parameter, in the same order as `infos`.
    sharding_file_exists: Whether reading the sharding from the checkpoint
      is allowed as a fallback.

  Returns:
    One sharding per parameter. An entry can be None when the file fallback
    is taken and `info.name` is None, or when
    `sharding_metadata.get_sharding_or_none` returns None.

  Raises:
    ValueError: If the stored sharding reads back empty, or if no sharding
      source is available for a parameter.
  """
  shardings = []
  for info, arg in zip(infos, args):
    sharding = None
    if (
        isinstance(arg, ArrayRestoreArgs)
        and arg.mesh is not None
        and arg.mesh_axes is not None
    ):
      sharding = jax.sharding.NamedSharding(arg.mesh, arg.mesh_axes)
    elif isinstance(arg, ArrayRestoreArgs) and arg.sharding is not None:
      if isinstance(arg.sharding, sharding_metadata.ShardingMetadata):
        sharding = arg.sharding.to_jax_sharding()
      else:
        sharding = arg.sharding
    elif sharding_file_exists:
      # This branch is reached once per parameter lacking explicit sharding,
      # so the warning is issued inside the loop.
      warnings.warn(
          'Sharding info not provided when restoring. Populating sharding'
          ' info from sharding file. Please note restoration time will be'
          ' slightly increased due to reading from file. Note also that this'
          ' option is unsafe when restoring on a different topology than the'
          ' checkpoint was saved with.'
      )
      assert info.parent_dir is not None
      if info.name is not None:
        tspec_sharding = ts_utils.get_sharding_tensorstore_spec(
            info.parent_dir.as_posix(), info.name
        )
        t = await ts.open(
            tspec_sharding,
            # OCDBT is not used for sharding metadata.
            context=info.ts_context,
            open=True,
            read=True,
        )
        serialized_string = await t.read()  # pytype: disable=attribute-error
        if serialized_string:
          sharding = sharding_metadata.get_sharding_or_none(serialized_string)
        else:
          raise ValueError('Unable to deserialize sharding.')
    else:
      raise ValueError(
          'Sharding of jax.Array cannot be None. Provide `mesh`'
          ' and `mesh_axes` OR `sharding`'
      )
    shardings.append(sharding)
  return shardings
async def _deserialize_arrays(
    infos: Sequence[types.ParamInfo],
    args: Sequence[types.RestoreArgs],
    shardings: Sequence[jax.sharding.Sharding],
    metadata_key: str | None,
    array_metadata_store: array_metadata_store_lib.Store | None,
) -> Sequence[jax.Array]:
  """Deserializes arrays and applies array_metadata if available.

  All arrays are read concurrently. When `array_metadata_store` is given,
  the stored array metadata is read concurrently with the arrays and used to
  re-wrap random-key data. Read duration, requested bytes and throughput are
  logged and recorded through `jax.monitoring`.

  Args:
    infos: ParamInfo for each parameter. `infos[0].parent_dir` is used as
      the checkpoint directory for metadata and for the storage type.
    args: RestoreArgs for each parameter.
    shardings: Target sharding for each parameter.
    metadata_key: Forwarded to the read spec and to the non-OCDBT file check.
    array_metadata_store: Optional store to read array metadata from.

  Returns:
    The deserialized arrays, in the order of `infos`.
  """
  total_start_time = time.time()

  async def _async_deserialize(
      infos: Sequence[types.ParamInfo],
      args: Sequence[types.RestoreArgs],
      shardings: Sequence[jax.sharding.Sharding],
      *,
      metadata_key: str | None,
  ) -> tuple[list[jax.Array], int]:
    """This function contains the core TensorStore read logic from ArrayHandler.deserialize."""
    use_ocdbt = _validate_ocdbt_settings(infos)
    if not use_ocdbt:
      await _validate_non_ocdbt_files(infos, metadata_key)
    deserialize_ops = []
    for info, arg, sharding in zip(infos, args, shardings):
      array_read_spec = ts_utils.build_array_read_spec(
          info,
          use_ocdbt=use_ocdbt,
          metadata_key=metadata_key,
          raise_array_data_missing_error=info.raise_array_data_missing_error,
          target_dtype=arg.dtype,
      )
      tspec = array_read_spec.json
      # Not every RestoreArgs type has a `global_shape` attribute.
      base_shape = arg.global_shape if hasattr(arg, 'global_shape') else None
      if _has_prng_key_dtype(arg):
        # set dtype=None to deserialize for random keys
        dtype = None
        # Read with the shape of the underlying key data rather than the
        # shape of the key array itself.
        arg_global_shape = _get_underlying_shape(base_shape, arg.dtype)
        if isinstance(sharding, jax.sharding.NamedSharding):
          # Extend PartitionSpec with None dims for key trailing shape
          # instead of forcing replicated, so local-mode reads work.
          key_trailing_ndim = (
              len(arg_global_shape) - len(base_shape)
              if arg_global_shape and base_shape
              else 0
          )
          physical_spec = jax.sharding.PartitionSpec(
              *sharding.spec, *([None] * key_trailing_ndim)
          )
          sharding_for_read = jax.sharding.NamedSharding(
              sharding.mesh, physical_spec
          )
        else:
          sharding_for_read = sharding
      else:
        dtype = arg.dtype
        arg_global_shape = base_shape
        sharding_for_read = sharding
      if logging.vlog_is_on(1):
        logging.vlog(1, 'tspec = %s', tspec)
        logging.vlog(1, 'info = %s', info)
        logging.vlog(1, 'arg = %s', arg)
        logging.vlog(1, 'dtype = %s', dtype)
        logging.vlog(1, 'sharding = %s', sharding_for_read)
      deserialize_ops.append(
          serialization.async_deserialize(
              sharding_for_read,
              tspec,
              global_shape=arg_global_shape,
              dtype=dtype,
              byte_limiter=info.byte_limiter,
              context=info.ts_context,
              # `strict` defaults to True for args types without it.
              strict=arg.strict if hasattr(arg, 'strict') else True,
          )
      )
    results = await asyncio.gather(*deserialize_ops)
    # Each result is an (array, io_bytes) pair; split them and sum the bytes.
    deserialized_arrays = []
    total_io_bytes = 0
    for arr, io_bytes in results:
      deserialized_arrays.append(arr)
      total_io_bytes += io_bytes
    return deserialized_arrays, total_io_bytes

  if array_metadata_store is not None:
    # Read the arrays and the array metadata concurrently.
    (ret, total_io_bytes), array_metadatas = await asyncio.gather(
        _async_deserialize(
            infos,
            args,
            shardings,
            metadata_key=metadata_key,
        ),
        array_metadata_store.read(
            checkpoint_dir=infos[0].parent_dir,
        ),
    )
    if array_metadatas:
      ret = _wrap_random_key_data(array_metadatas, infos, ret)
  else:
    ret, total_io_bytes = await _async_deserialize(
        infos,
        args,
        shardings,
        metadata_key=metadata_key,
    )
  total_duration = time.time() - total_start_time
  # Guard against a zero duration to avoid dividing by zero.
  io_throughput = total_io_bytes / total_duration if total_duration > 0 else 0
  storage_type = path_utils.get_storage_type(infos[0].parent_dir)
  logging.info(
      '[process=%d] %s throughput: %s/s (total gbytes: %s) (time elapsed: %s s)'
      ' (per-host)',
      multihost.process_index(),
      '/jax/orbax/read/worker/io/requested',
      humanize.naturalsize(io_throughput, binary=True, format='%.3f'),
      humanize.naturalsize(total_io_bytes, binary=True),
      total_duration,
  )
  # Record total duration of the read operation. Note that for McJAX, it
  # includes IO time and H2D transfer time. For Pathways Remote Python,
  # it includes only IO time.
  jax.monitoring.record_event_duration_secs(
      '/jax/orbax/read/worker/total_duration_secs',
      total_duration,
      storage_type=storage_type,
  )
  # record total bytes requested to be read from IO
  jax.monitoring.record_scalar(
      '/jax/orbax/read/worker/io/requested/gbytes',
      total_io_bytes / (1024**3),
      storage_type=storage_type,
  )
  jax.monitoring.record_scalar(
      '/jax/orbax/read/worker/io/requested/throughput/gbytes_per_sec',
      io_throughput / (1024**3),
      storage_type=storage_type,
  )
  return ret
def _sync_deserialize_arrays(
    infos: Sequence[types.ParamInfo],
    args: Sequence[types.RestoreArgs],
    shardings: Sequence[jax.sharding.Sharding],
    metadata_key: str | None,
    array_metadata_store: array_metadata_store_lib.Store | None,
) -> Sequence[jax.Array]:
  """Synchronous counterpart of `_deserialize_arrays`.

  Creates the `_deserialize_arrays` coroutine with the given arguments,
  passes it to `asyncio_utils.run_sync` and returns the outcome.
  """
  restore_coro = _deserialize_arrays(
      infos,
      args,
      shardings,
      metadata_key,
      array_metadata_store,
  )
  return asyncio_utils.run_sync(restore_coro)
async def _get_abstract_arrays(
    args: Sequence[types.RestoreArgs],
    shardings: Sequence[jax.sharding.Sharding],
    array_metadata_store: array_metadata_store_lib.Store | None = None,
    infos: Sequence[types.ParamInfo] | None = None,
) -> Sequence[jax.ShapeDtypeStruct]:
  """Returns result specs for dispatchers.

  Computes ShapeDtypeStruct specs that describe the expected output of the
  dispatched worker function. For PRNG key parameters (detected via
  array_metadata_store), the specs use PRNG key dtypes and logical shapes.
  The colocated_python framework handles the PRNG key <-> physical
  conversion at the IFRT boundary.

  Args:
    args: ArrayRestoreArgs for each parameter. Each must have `global_shape`
      and `dtype` set.
    shardings: Shardings for each parameter.
    array_metadata_store: Store to read PRNG key impl metadata from.
    infos: ParamInfo for each parameter.

  Returns:
    Sequence of ShapeDtypeStruct result specs for the dispatcher.

  Raises:
    ValueError: If any sharding is None.
  """
  # Maps parameter name to its stored array metadata; it stays empty when no
  # store or infos are given, or when nothing was read.
  metadatas_cache: dict[str, Any] = {}
  if array_metadata_store is not None and infos:
    array_metadatas = await array_metadata_store.read(
        checkpoint_dir=infos[0].parent_dir,
    )
    if array_metadatas:
      # The store result may be a dict, in which case only its first value
      # is used, or directly an iterable of metadata.
      if isinstance(array_metadatas, dict):
        target_list = next(iter(array_metadatas.values()))
      else:
        target_list = array_metadatas
      metadatas_cache = {meta.param_name: meta for meta in target_list}
  abstract_arrays: list[jax.ShapeDtypeStruct] = []
  for i, (arg, sharding) in enumerate(zip(args, shardings)):
    assert isinstance(arg, ArrayRestoreArgs)
    assert arg.global_shape is not None
    assert arg.dtype is not None
    if sharding is None:
      raise ValueError('Sharding of jax.Array cannot be None.')
    shape = arg.global_shape
    dtype = arg.dtype
    if infos and (meta := metadatas_cache.get(infos[i].name)) is not None:
      if meta.ext_metadata and isinstance(meta.ext_metadata, dict):
        if (
            impl := meta.ext_metadata.get(array_metadata_lib.RANDOM_KEY_IMPL)
        ) is not None:
          # A throwaway key is created only to obtain the dtype for `impl`.
          prng_key_dtype = jax.random.key(0, impl=impl).dtype
          if jax.dtypes.issubdtype(dtype, jax.dtypes.prng_key):
            # arg.dtype is already a PRNG key dtype, so arg.global_shape
            # is already the logical shape. Use it as-is.
            dtype = prng_key_dtype
          else:
            # arg.dtype is physical (e.g. uint32), so arg.global_shape
            # is the physical shape. Convert to logical shape.
            dtype = prng_key_dtype
            key_trailing_shape = jax.eval_shape(
                jax.random.key_data,
                jax.ShapeDtypeStruct(shape=(), dtype=prng_key_dtype),
            ).shape
            key_trailing_ndim = len(key_trailing_shape)
            # The guard avoids `shape[:-0]`, which would be an empty shape.
            shape = shape[:-key_trailing_ndim] if key_trailing_ndim else shape
            # Fix rank mismatch between logical shape and physical sharding.
            if isinstance(sharding, jax.sharding.NamedSharding):
              original_spec = sharding.spec
              # Drop trailing dims to match the rank of the logical shape.
              logical_spec = jax.sharding.PartitionSpec(
                  *original_spec[: len(shape)]
              )
              sharding = jax.sharding.NamedSharding(sharding.mesh, logical_spec)
          logging.vlog(
              1,
              '_get_abstract_arrays: PRNG key parameter %s: impl=%s,'
              ' logical shape=%s, dtype=%s, sharding=%s',
              infos[i].name,
              impl,
              shape,
              dtype,
              sharding,
          )
    abstract_arrays.append(
        jax.ShapeDtypeStruct(shape=shape, dtype=dtype, sharding=sharding)
    )
  return abstract_arrays
def _get_default_use_replica_parallel():
  """Derives the default for `use_replica_parallel` from `JAX_PLATFORMS`.

  Returns:
    False when the lower-cased `JAX_PLATFORMS` environment variable contains
    'gpu' or 'cuda'; True otherwise, including when the variable is unset.
  """
  configured_platforms = os.environ.get('JAX_PLATFORMS', '').lower()
  targets_gpu = any(
      marker in configured_platforms for marker in ('gpu', 'cuda')
  )
  return not targets_gpu
[docs]
class ArrayHandler(types.TypeHandler):
"""An implementation of TypeHandler for jax.Array."""
[docs]
def __init__(
    self,
    metadata_key: str | None = None,
    primary_host: int | None = 0,
    replica_id: int | None = 0,
    use_replica_parallel: bool | None = None,
    min_slice_bytes_for_replica_parallel: int | None = None,
    max_replicas_for_replica_parallel: int | None = None,
    enable_write_sharding_file: bool = True,
    array_metadata_store: array_metadata_store_lib.Store | None = None,
    enable_replica_parallel_separate_folder: bool = False,
    dispatcher: dispatchers.Dispatcher | None = None,
    callback: types.SerializationStatusCallback | None = None,
):
  """Initializes the handler.

  Args:
    metadata_key: Name to give to Tensorstore metadata files.
    primary_host: The host id of the primary host. Defaults to 0. If set to
      None, all hosts are considered primary, which is useful when every host
      only works with local storage.
    replica_id: The replica id used for saving. Defaults to 0. If set to None,
      each shard picks the first replica id available, which is useful when
      every host only works with local storage.
    use_replica_parallel: Whether to parallelize saving across replicas. When
      None, the value of `_get_default_use_replica_parallel()` is used.
    min_slice_bytes_for_replica_parallel: Minimum number of bytes per replica
      slice. Replica-parallel is only used when the amount of data written
      per replica is greater than or equal to this number.
    max_replicas_for_replica_parallel: Maximum number of replicas over which
      saving is parallelized if `use_replica_parallel` is True.
    enable_write_sharding_file: Whether to write the sharding file. Defaults
      to True.
    array_metadata_store: Store to manage per-host ArrayMetadata. Set to None
      to disable ArrayMetadata persistence.
    enable_replica_parallel_separate_folder: If True, save replica and
      sharded arrays in separate folders when `use_replica_parallel` is
      active.
    dispatcher: The dispatcher used to execute operations on the workers.
    callback: The serialization status callback. When None, a
      `types.NoopSerializationStatusCallback` is used.

  Raises:
    ValueError: If `primary_host` is None and the installed JAX version is
      0.4.25 or older.
  """
  # Resolve the `None` sentinels to their effective values first.
  if use_replica_parallel is None:
    use_replica_parallel = _get_default_use_replica_parallel()
  if callback is None:
    callback = types.NoopSerializationStatusCallback()

  self._metadata_key = metadata_key
  self._primary_host = primary_host
  self._replica_id = replica_id
  self._enable_write_sharding_file = enable_write_sharding_file
  self._use_replica_parallel = use_replica_parallel
  self._min_slice_bytes_for_replica_parallel = (
      min_slice_bytes_for_replica_parallel
  )
  self._max_replicas_for_replica_parallel = max_replicas_for_replica_parallel
  self._array_metadata_store = array_metadata_store
  self._enable_replica_parallel_separate_folder = (
      enable_replica_parallel_separate_folder
  )
  self._ext_metadata = {}
  self._dispatcher = dispatcher
  self._callback = callback

  handler_name = self.__class__.__qualname__
  logging.vlog(
      1,
      'Created `%s` with primary_host=%s, replica_id=%s,'
      ' use_replica_parallel=%s, array_metadata_store=%s, dispatcher=%s',
      handler_name,
      self._primary_host,
      self._replica_id,
      self._use_replica_parallel,
      self._array_metadata_store,
      self._dispatcher,
  )

  if self._dispatcher:
    dispatcher_name = self._dispatcher.__class__.__qualname__
  else:
    dispatcher_name = 'none'
  jax.monitoring.record_event(
      '/jax/orbax/array_handler/init',
      type=handler_name,
      dispatcher=dispatcher_name,
      use_replica_parallel=self._use_replica_parallel,
      enable_replica_parallel_separate_folder=self._enable_replica_parallel_separate_folder,
  )

  if self._primary_host is None:
    if jax.__version_info__ <= (0, 4, 25):  # pylint:disable=unreachable
      raise ValueError(
          'Setting `primary_host` to None requires JAX version > 0.4.25.'
      )
def has_dispatcher(self) -> bool:
  """Returns True if a dispatcher is set on this handler, else False."""
  dispatcher = self._dispatcher
  return dispatcher is not None
[docs]
def typestr(self) -> str:
  """Returns the module-level `JAX_ARRAY_TYPE_STR` constant."""
  type_name = JAX_ARRAY_TYPE_STR
  return type_name
[docs]
async def serialize(
    self,
    values: Sequence[jax.Array],
    infos: Sequence[types.ParamInfo],
    args: Sequence[types.SaveArgs] | None = None,
) -> Sequence[future.Future]:
  """Builds the futures that write `values` (see superclass documentation).

  PRNG key arrays are converted with `jax.random.key_data` before being
  written, and their key implementation is recorded in the extra metadata
  under `array_metadata_lib.RANDOM_KEY_IMPL`, keyed by parameter name.

  Args:
    values: Arrays to save.
    infos: One `ParamInfo` per value.
    args: One `SaveArgs` per value. Defaults to `SaveArgs()` for every value.

  Returns:
    The future for the array writes, preceded by the future for the sharding
    file when `enable_write_sharding_file` was set.

  Raises:
    ValueError: If a value is a fully addressable `jax.Array` while more than
      one process is running.
  """
  args = args or [types.SaveArgs()] * len(values)
  types.check_input_arguments(values, infos, args)
  # TODO(b/461467565): Raise error when saving zero sized arrays on pathways
  # as well.
  check_array_values(values, infos, raise_error=not self.has_dispatcher())

  # Fresh extra-metadata dict per call; also kept on the instance.
  ext_metadata = {}
  self._ext_metadata = ext_metadata
  arrays_to_write = []
  for value, info in zip(values, infos):
    host_local_in_multihost = (
        isinstance(value, jax.Array)
        and jax.process_count() > 1
        and value.is_fully_addressable
    )
    if host_local_in_multihost:
      debug_param_info = (
          f'ParamInfo=[name={info.name},value_typestr={info.value_typestr}]'
      )
      debug_array = (
          f'jax.Array=[value={value},shape={value.shape},dtype={value.dtype},'
          f'sharding={value.sharding},device={value.device}]'
      )
      raise ValueError(
          f'Cannot serialize host local jax.Array ({debug_param_info},'
          f' {debug_array}) in multi-host setting. Arrays like this are'
          ' typically obtained using pmap. Consider using'
          ' fully_replicated_host_local_array_to_global_array in'
          ' orbax/checkpoint/utils.py to convert your arrays into'
          f' serializable objects. Array.sharding: {value.sharding}'
      )

    logging.vlog(
        1,
        'serialize: param %s, dtype=%s, shape=%s, sharding=%s',
        info.name,
        value.dtype,
        value.shape,
        getattr(value, 'sharding', None),
    )

    if not jax.dtypes.issubdtype(value.dtype, jax.dtypes.prng_key):
      # Regular array: written unchanged.
      arrays_to_write.append(value)
      continue

    # JAX random key: write its key data and remember the implementation.
    logging.vlog(
        1,
        'serialize: PRNG key param %s, logical shape=%s, sharding=%s',
        info.name,
        value.shape,
        value.sharding,
    )
    raw_key_data = jax.random.key_data(value)
    logging.vlog(
        1,
        'serialize: PRNG key param %s, physical shape=%s, sharding=%s',
        info.name,
        raw_key_data.shape,
        raw_key_data.sharding,
    )
    arrays_to_write.append(raw_key_data)
    ext_metadata[info.name] = {
        array_metadata_lib.RANDOM_KEY_IMPL: str(jax.random.key_impl(value))
    }

  # Pinned-host transfer must be enabled for all params or for none.
  pinned_flags = [info.enable_pinned_host_transfer for info in infos]
  assert all(pinned_flags) or not any(pinned_flags)

  pending = []
  if self._enable_write_sharding_file:
    sharding_write = _async_serialize_shardings(
        shardings=[arr.sharding for arr in arrays_to_write],
        infos=infos,
        primary_host=self._primary_host,
    )
    pending.append(
        future.CommitFutureAwaitingContractedSignals(
            sharding_write,
            name='serialize_shardings',
        )
    )
  array_write = _serialize_arrays(
      arrays=arrays_to_write,
      infos=infos,
      args=args,
      dispatcher=self._dispatcher,
      primary_host=self._primary_host,
      replica_id=self._replica_id,
      use_replica_parallel=self._use_replica_parallel,
      min_slice_bytes_for_replica_parallel=self._min_slice_bytes_for_replica_parallel,
      max_replicas_for_replica_parallel=self._max_replicas_for_replica_parallel,
      enable_replica_parallel_separate_folder=self._enable_replica_parallel_separate_folder,
      metadata_key=self._metadata_key,
      ext_metadata=ext_metadata,
      array_metadata_store=self._array_metadata_store,
      callback=self._callback,
  )
  pending.append(array_write)
  return pending
async def _maybe_read_metadata_and_update_restore_args(
    self,
    infos: Sequence[types.ParamInfo],
    args: Sequence[types.RestoreArgs],
) -> Sequence[ArrayRestoreArgs]:
  """Fills in missing `global_shape`/`dtype` restore args from metadata.

  If every arg is already an `ArrayRestoreArgs` with both `global_shape` and
  `dtype` set, the args are returned as-is (in a new list). Otherwise a
  warning is logged, `self.metadata(infos)` is awaited, and each missing
  field is filled from the matching metadata entry when one exists.

  Args:
    infos: One `ParamInfo` per parameter.
    args: One restore arg per parameter. Entries that are not
      `ArrayRestoreArgs` are replaced with a default `ArrayRestoreArgs()`.

  Returns:
    One `ArrayRestoreArgs` per input arg.
  """

  def _is_complete(candidate) -> bool:
    return (
        isinstance(candidate, ArrayRestoreArgs)
        and candidate.global_shape is not None
        and candidate.dtype is not None
    )

  if all(_is_complete(arg) for arg in args):
    return [cast(ArrayRestoreArgs, arg) for arg in args]

  logging.warning(
      '`global_shape` and `dtype` are required for efficient restoration on'
      ' Pathways. Automatically restoring metadata from disk to obtain'
      ' these properties, which involves lightweight reading of metadata'
      ' files, but please provide these properties for optimal restoration.'
  )
  metadatas = await self.metadata(infos)

  updated: list[ArrayRestoreArgs] = []
  for arg, meta in zip(args, metadatas):
    if not isinstance(arg, ArrayRestoreArgs):
      arg = ArrayRestoreArgs()
    if meta is not None:
      # Prefer `write_shape` when the metadata has a truthy one, else `shape`.
      physical_shape = getattr(meta, 'write_shape', None) or getattr(
          meta, 'shape', None
      )
      stored_dtype = getattr(meta, 'dtype', None)
      if arg.global_shape is None:
        arg = dataclasses.replace(
            arg, global_shape=physical_shape, shape=physical_shape
        )
      if arg.dtype is None:
        arg = dataclasses.replace(arg, dtype=stored_dtype)
    updated.append(arg)
  return updated
[docs]
async def deserialize(
    self,
    infos: Sequence[types.ParamInfo],
    args: Sequence[types.RestoreArgs] | None = None,
) -> Sequence[jax.Array]:
  """See superclass documentation.

  Without a dispatcher, arrays are restored with `_deserialize_arrays`. With
  a dispatcher, missing `global_shape`/`dtype` values are first filled in
  from metadata, abstract result arrays are built, and the restore is routed
  through `dispatcher.dispatch`.

  Args:
    infos: ParamInfo.
    args: must be of type `ArrayRestoreArgs`.

  Returns:
    The deserialized parameters. An empty list when `infos` is empty.

  Raises:
    ValueError if `args` is not provided.
    ValueError if the first info's `parent_dir` is None.
    ValueError if `args.sharding` is not provided or `args.mesh` and
      `args.mesh_axes` are not provided (presumably raised while resolving
      shardings; confirm against `_deserialize_shardings`).
  """
  if args is None:
    raise ValueError('Must provide ArrayRestoreArgs to restore as jax.Array.')
  types.check_input_arguments(infos, args)
  if not infos:
    # Nothing to restore. Without this guard the `infos[0]` lookup below
    # raises IndexError on an empty batch.
    return []
  await asyncio.gather(*[info.await_path_creation() for info in infos])
  # All params are expected to share the parent directory of the first one.
  parent_dir = infos[0].parent_dir
  if parent_dir is None:
    raise ValueError('parent_dir cannot be None')
  sharding_file_path = parent_dir / _SHARDING_FILE_NAME
  sharding_file_exists = await async_path.exists(sharding_file_path)
  shardings = await _deserialize_shardings(infos, args, sharding_file_exists)
  if self._dispatcher is None:
    ret = await _deserialize_arrays(
        infos,
        args,
        shardings,
        self._metadata_key,
        self._array_metadata_store,
    )
  else:
    # The dispatcher needs result specs, which require shape and dtype.
    args = await self._maybe_read_metadata_and_update_restore_args(
        infos, args
    )
    result_specs = await _get_abstract_arrays(
        args, shardings, self._array_metadata_store, infos
    )
    ret = self._dispatcher.dispatch(
        _sync_deserialize_arrays,
        result_specs=result_specs,
        func_kwargs={
            'infos': infos,
            'args': args,
            'shardings': shardings,
            'metadata_key': self._metadata_key,
            'array_metadata_store': self._array_metadata_store,
        },
    )
    jax.block_until_ready(ret)
  if logging.vlog_is_on(1):
    for a in ret:
      logging.vlog(
          1,
          'restored jax.Array.shape = %s, jax.array.dtype = %s,'
          ' jax.array.format = %s',
          getattr(a, 'shape', None),
          getattr(a, 'dtype', None),
          getattr(a, 'format', None),
      )
    ts_utils.print_ts_debug_data(self._metadata_key, infos)
  return ret  # pytype: disable=bad-return-type
[docs]
def memory_size(
    self, values: Sequence[jax.Array]
) -> Sequence[Tuple[int, int]]:
  """Returns a `(write_size, read_size)` pair in bytes for each value.

  The write size is the `nbytes` of the replica slices computed with this
  handler's replica settings. The read size is the summed byte size of the
  value's addressable shards.

  Args:
    values: Arrays to measure.

  Returns:
    A list with one `(write_size, read_size)` tuple per value.
  """
  sizes = []
  for value in values:
    slices = replica_slices.get_replica_slices(
        value,
        replica_id=self._replica_id,
        use_replica_parallel=self._use_replica_parallel,
        min_slice_bytes_for_replica_parallel=self._min_slice_bytes_for_replica_parallel,
        max_replicas_for_replica_parallel=self._max_replicas_for_replica_parallel,
    )
    write_bytes = slices.nbytes
    read_bytes = sum(
        shard.data.size * shard.data.dtype.itemsize
        for shard in value.addressable_shards
    )
    sizes.append((write_bytes, read_bytes))
  return sizes
def _is_host_for_primary_replica(primary_replica_ids: set[int]) -> bool:
  """Returns whether this process's index is in `primary_replica_ids`."""
  this_process = multihost.process_index()
  return this_process in primary_replica_ids
class InvalidShardingError(ValueError):
  """Signals that a provided sharding failed validation.

  Subclasses `ValueError`, so callers catching `ValueError` also catch it.
  """
def _validate_sharding_and_get_primary_replica_processes(
    replica_axis_index: int,
    primary_replica_id: int,
    sharding: jax.sharding.Sharding,
) -> Set[int]:
  """Validates `sharding` for single-replica restoration.

  Args:
    replica_axis_index: Passed to `multislice.get_primary_replica_ids_and_pids`
      as `replica_axis_idx`.
    primary_replica_id: Id of the replica treated as primary.
    sharding: The sharding to validate; must be a `NamedSharding`.

  Returns:
    The primary replica process ids reported by
    `multislice.get_primary_replica_ids_and_pids`.

  Raises:
    InvalidShardingError: If `sharding` is not a `NamedSharding`, if the
      primary replica contains as many devices as `jax.devices()`, or if the
      primary replica's devices are not all owned by the primary replica's
      processes.
  """
  if not isinstance(sharding, jax.sharding.NamedSharding):
    raise InvalidShardingError(
        'The provided sharding is not a NamedSharding. Please use'
        ' NamedSharding instead.'
    )

  replica_device_ids, replica_pids = (
      multislice.get_primary_replica_ids_and_pids(
          replica_axis_idx=replica_axis_index,
          mesh=sharding.mesh,
          primary_replica_id=primary_replica_id,
      )
  )

  all_devices = jax.devices()
  if len(replica_device_ids) == len(all_devices):
    raise InvalidShardingError(
        'All devices are in the primary replica. There are no non-primary'
        ' replicas to broadcast to.'
    )

  # Every device owned by a primary-replica process.
  devices_on_primary_pids = {
      device.id
      for device in all_devices
      if multihost.process_index_from_device(device) in replica_pids
  }
  if replica_device_ids.issubset(devices_on_primary_pids):
    return replica_pids

  raise InvalidShardingError(
      'The provided sharding is not valid. The primary replica has the'
      f' following devices: {replica_device_ids}, which is not a'
      ' subset of the expected devices:'
      f' {devices_on_primary_pids}. for the primary processes:'
      f' {replica_pids}.'
  )
async def _single_replica_deserialize_and_broadcast(
    infos: Sequence[types.ParamInfo],
    args: Sequence[SingleReplicaArrayRestoreArgs],
    shardings: Sequence[jax.sharding.Sharding],
    single_replica_shardings: Sequence[jax.sharding.Sharding],
    replica_axis_index: int,
    primary_replica_id: int,
    metadata_key: str | None,
    broadcast_memory_limit_bytes: int | None,
    broadcast_memory_scaling_factor: float | None,
) -> Sequence[jax.Array]:
  """Reads arrays on the primary replica and broadcasts them to all replicas.

  Processes whose index is among the primary replica's process ids read the
  arrays with `_deserialize_arrays`. All other processes build zero-filled
  arrays of the requested shape and dtype instead. The result is then passed
  to `multislice.broadcast_one_replica_to_all`.

  Args:
    infos: One `ParamInfo` per parameter.
    args: One restore arg per parameter; `global_shape` and `dtype` are read
      on non-primary processes.
    shardings: Target shardings; the first one is validated and supplies the
      global mesh.
    single_replica_shardings: Shardings used for the single-replica read and
      for the zero-filled placeholders.
    replica_axis_index: Mesh axis along which replicas are defined.
    primary_replica_id: Id of the replica that reads the checkpoint.
    metadata_key: Forwarded to `_deserialize_arrays`.
    broadcast_memory_limit_bytes: Forwarded to the broadcast as
      `memory_limit_bytes`.
    broadcast_memory_scaling_factor: Forwarded to the broadcast as
      `memory_scaling_factor`.

  Returns:
    The broadcast state returned by
    `multislice.broadcast_one_replica_to_all`.
  """
  primary_pids = _validate_sharding_and_get_primary_replica_processes(
      replica_axis_index=replica_axis_index,
      primary_replica_id=primary_replica_id,
      sharding=shardings[0],
  )
  is_primary_host = _is_host_for_primary_replica(primary_pids)

  if is_primary_host:
    read_started_at = time.time()
    deserialized = await _deserialize_arrays(
        infos,
        args,
        single_replica_shardings,
        metadata_key,
        None,
    )
    read_elapsed_s = time.time() - read_started_at
    jax.monitoring.record_event_duration_secs(
        '/jax/checkpoint/read/primary_replica_deserialization_duration_secs',
        read_elapsed_s,
    )
    logging.info(
        'Finished primary replica deserialization in %.2f seconds',
        read_elapsed_s,
    )
  else:

    @functools.partial(
        jax.jit, static_argnums=0, out_shardings=tuple(single_replica_shardings)
    )
    def create_zeros(shape_dtype_tup):
      return jax.tree.map(
          lambda sd: jnp.zeros(sd.shape, dtype=sd.dtype), shape_dtype_tup
      )

    placeholder_specs = tuple(
        jax.ShapeDtypeStruct(arg.global_shape, arg.dtype) for arg in args
    )
    local_mesh = cast(
        jax.sharding.NamedSharding, single_replica_shardings[0]
    ).mesh
    # Use `jax.set_mesh` when this JAX version provides it; otherwise enter
    # the mesh itself as a context manager.
    if hasattr(jax, 'set_mesh'):
      with jax.set_mesh(local_mesh):
        deserialized = create_zeros(placeholder_specs)
    else:
      with local_mesh:
        deserialized = create_zeros(placeholder_specs)

  deserialized = tuple(deserialized)

  broadcast_started_at = time.time()
  global_mesh = cast(jax.sharding.NamedSharding, shardings[0]).mesh
  shared_state, _ = multislice.broadcast_one_replica_to_all(
      deserialized,
      global_mesh,
      replica_axis_index,
      is_primary_host,
      memory_limit_bytes=broadcast_memory_limit_bytes,
      memory_scaling_factor=broadcast_memory_scaling_factor,
  )
  broadcast_elapsed_s = time.time() - broadcast_started_at
  jax.monitoring.record_event_duration_secs(
      '/jax/checkpoint/read/broadcast_duration_secs', broadcast_elapsed_s
  )
  logging.info('Finished broadcasting in %.2f seconds', broadcast_elapsed_s)
  return shared_state
def _single_replica_deserialize_on_worker(
    _,
    infos: Sequence[types.ParamInfo],
    args: Sequence[SingleReplicaArrayRestoreArgs],
    single_replica_shardings: Sequence[jax.sharding.Sharding],
    metadata_key: str | None,
):
  """Synchronously deserializes a single replica on a worker.

  Args:
    _: Ignored positional argument.
    infos: One `ParamInfo` per parameter.
    args: One restore arg per parameter.
    single_replica_shardings: Shardings to restore the arrays with.
    metadata_key: Forwarded to `_deserialize_arrays`.

  Returns:
    The result of running `_deserialize_arrays` to completion, with no array
    metadata store.
  """
  restore_coro = _deserialize_arrays(
      infos,
      args,
      single_replica_shardings,
      metadata_key,
      None,
  )
  return asyncio_utils.run_sync(restore_coro)
[docs]
class SingleReplicaArrayHandler(ArrayHandler):
  """An ArrayHandler that reads on one replica and broadcasts to the rest.

  This handler optimizes checkpoint read performance during multi-pod or
  multihost training. Normally each host reads relevant data from the
  checkpoint, even if other hosts are reading the exact same data. This can be
  very inefficient with a large number of pods/hosts and a large checkpoint
  size. With SingleReplicaArrayHandler the data is read only on hosts that are
  in the primary replica. These hosts then broadcast the data to the other
  hosts. It is assumed that all hosts have ALL their devices either inside the
  primary replica or outside of it.

  Consider, for example, the following sharding on v4-128 which has 16 hosts
  and 64 devices::

    shape = (32, 2)
    mesh = jax.sharding.Mesh(
        np.array(jax.devices()).reshape(shape), ('x', 'y')
    )
    pspec = jax.sharding.PartitionSpec(None, 'y')
    sharding = jax.sharding.NamedSharding(mesh, pspec)

  This sharding will not work since the primary replica has only two devices,
  and hence there is a host which has 2 devices in the primary replica and 2
  devices outside of the primary replica. However, changing shape, for
  example, to (4, 16) will result in a valid sharding.

  This TypeHandler can be registered by running::

    ocp.type_handlers.register_type_handler(
        jax.Array,
        type_handlers.SingleReplicaArrayHandler(),
        override=True)

  Example usage can be found in MaxText (TO BE MERGED).
  https://github.com/google/maxtext/blob/main/MaxText/checkpointing.py
  """
[docs]
def __init__(
    self,
    replica_axis_index: int | None = 0,
    primary_replica_id: int | None = 0,
    broadcast_memory_limit_bytes: int | None = None,
    broadcast_memory_scaling_factor: float | None = 0.75,
    **kwargs,
):
  """Initializes the handler.

  Args:
    replica_axis_index: The axis of the global mesh along which replicas are
      defined. E.g. all devices in global_mesh.devices[replica_axis_index]
      are part of the same replica.
    primary_replica_id: The id of the replica that loads and broadcasts the
      checkpoint.
    broadcast_memory_limit_bytes: The memory size (in bytes) used for
      broadcasting data.
    broadcast_memory_scaling_factor: The fraction of available memory to use
      for broadcasting data.
    **kwargs: Forwarded to `ArrayHandler.__init__`; see `ArrayHandler` for
      documentation of these arguments.
  """
  super().__init__(**kwargs)
  # Broadcast configuration, exposed as public attributes.
  self.replica_axis_index = replica_axis_index
  self.primary_replica_id = primary_replica_id
  self.broadcast_memory_limit_bytes = broadcast_memory_limit_bytes
  self.broadcast_memory_scaling_factor = broadcast_memory_scaling_factor
def _construct_single_replica_sharding(
    self, sharding: jax.sharding.Sharding
) -> jax.sharding.Sharding:
  """Builds a `NamedSharding` over the local replica's devices only.

  The new mesh keeps the axis names of `sharding.mesh`, and the returned
  sharding keeps `sharding.spec`.

  Args:
    sharding: The full sharding; must be a `NamedSharding`.

  Returns:
    A `NamedSharding` on a mesh built from
    `multislice.local_replica_devices`, with a length-one axis inserted at
    `replica_axis_index`.
  """
  assert isinstance(sharding, jax.sharding.NamedSharding)
  axis = self.replica_axis_index
  full_mesh = sharding.mesh
  # Re-insert the replica axis so the device array has one entry along it.
  devices = np.expand_dims(
      multislice.local_replica_devices(full_mesh, replica_axis_index=axis),
      axis=axis,
  )
  single_replica_mesh = jax.sharding.Mesh(devices, full_mesh.axis_names)
  return jax.sharding.NamedSharding(single_replica_mesh, sharding.spec)
[docs]
async def deserialize(
    self,
    infos: Sequence[types.ParamInfo],
    args: Sequence[SingleReplicaArrayRestoreArgs] | None = None,  # pytype: disable=signature-mismatch
) -> Sequence[jax.Array]:
  """Deserializes on a single replica and broadcasts to the others.

  Without a dispatcher, the read and broadcast are done by
  `_single_replica_deserialize_and_broadcast`. With a dispatcher, the read is
  dispatched onto the primary replica's devices and the result is moved to
  the final shardings with `jax.device_put`. If an array metadata store is
  configured and returns metadata, the result is passed through
  `_wrap_random_key_data`.

  Args:
    infos: ParamInfo.
    args: must be of type `SingleReplicaArrayRestoreArgs`.

  Returns:
    Deserialized parameters. An empty list when `infos` is empty.

  Raises:
    ValueError if `args` is not provided.
    ValueError if any arg is not a `SingleReplicaArrayRestoreArgs`.
    ValueError if any arg has no `sharding`.
  """
  if args is None:
    raise ValueError(
        'Must provide SingleReplicaArrayRestoreArgs to restore as jax.Array.'
    )
  types.check_input_arguments(infos, args)
  for arg in args:
    if not isinstance(arg, SingleReplicaArrayRestoreArgs):
      raise ValueError(
          'Must provide `SingleReplicaArrayRestoreArgs`, but got'
          f' {type(arg)}.'
      )
    if arg.sharding is None:
      raise ValueError(
          'Must provide `sharding` to restore with'
          ' `SingleReplicaArrayHandler`.'
      )
  if not infos:
    # Nothing to restore. Without this guard the `shardings[0]` and
    # `infos[0]` lookups below raise IndexError on an empty batch.
    return []
  # arg.single_replica_sharding is not required to be passed.
  single_replica_shardings = [
      arg.single_replica_sharding
      if arg.single_replica_sharding
      else self._construct_single_replica_sharding(arg.sharding)
      for arg in args
  ]
  shardings = [arg.sharding for arg in args]
  if self._dispatcher is None:
    ret = await _single_replica_deserialize_and_broadcast(
        infos,
        args,
        shardings,
        single_replica_shardings,
        self.replica_axis_index,
        self.primary_replica_id,
        self._metadata_key,
        self.broadcast_memory_limit_bytes,
        self.broadcast_memory_scaling_factor,
    )
  else:
    primary_replica_devices = multislice.replica_devices(
        shardings[0].mesh,
        replica_id=self.primary_replica_id,
        replica_axis_index=self.replica_axis_index,
    ).flatten()
    dummy_input_array = dispatchers.get_dummy_input_array(
        primary_replica_devices
    )
    result_specs = await _get_abstract_arrays(
        args,
        single_replica_shardings,
        self._array_metadata_store,
        infos,
    )
    # Step 1: Deserialize arrays on a single replica.
    ret = self._dispatcher.dispatch(
        _single_replica_deserialize_on_worker,
        input_arrays=dummy_input_array,
        result_specs=result_specs,
        func_kwargs={
            'infos': infos,
            'args': args,
            'single_replica_shardings': single_replica_shardings,
            'metadata_key': self._metadata_key,
        },
    )
    # Step 2: Use `jax.device_put` to broadcast/reshard the data to all
    # devices according to the final desired sharding. This is the equivalent
    # of multislice.broadcast_one_replica_to_all in non-dispatcher based
    # implementation.
    ret = jax.tree.map(jax.device_put, ret, shardings)
    jax.block_until_ready(ret)
  if self._array_metadata_store is not None:
    array_metadatas = await self._array_metadata_store.read(
        checkpoint_dir=infos[0].parent_dir,
    )
    if array_metadatas:
      ret = _wrap_random_key_data(array_metadatas, infos, list(ret))
  return ret
# TODO(b/370396118): Calculation overestimates bytes read.
[docs]
def memory_size(  # pylint: disable=useless-parent-delegation
    self, values: Sequence[jax.Array]
) -> Sequence[Tuple[int, int]]:
  """Returns the parent class's `memory_size` result for `values`."""
  sizes = super().memory_size(values)
  return sizes