TypeHandlers

Public symbols for type_handlers module.

Arguments for PyTreeCheckpointHandler

class orbax.checkpoint.type_handlers.SaveArgs(aggregate=False, dtype=None, chunk_byte_size=None, shard_axes=())

Extra arguments that can be provided for saving.

aggregate:

Deprecated. Please use a custom TypeHandler (https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.type_handlers.html#typehandler) or contact the Orbax team to migrate before August 1st, 2024. If True, saves the given parameter in an aggregated tree format rather than individually. See AggregateHandler.

dtype:

If provided, casts the parameter to the given dtype before saving. Note that the parameter must be compatible with the given type (e.g. jnp.bfloat16 is not compatible with np.ndarray).

chunk_byte_size:

This is an experimental feature that automatically chooses the largest chunk shape possible, while keeping the chunk byte size less than or equal to the specified chunk_byte_size. Both the write_chunk_shape and read_chunk_shape are automatically set to the chosen shape. This uses a greedy algorithm that prioritizes splitting the largest dimensions first.

shard_axes:

An optional list of axes that should be prioritized when sharding the array for storage. If empty, the storage sharding implementation will prioritize axes that are already sharded.
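The chunk_byte_size behavior can be made concrete with a minimal, stdlib-only sketch of a greedy chunk-shape chooser. This is not Orbax's implementation, and the function names are hypothetical. The sketch assumes that each split divides the currently largest dimension by its smallest prime factor, so every chunk dimension still divides the corresponding array dimension evenly.

```python
import math
from typing import Sequence, Tuple


def smallest_prime_factor(n: int) -> int:
    """Returns the smallest prime factor of n (n itself if n is prime)."""
    for p in range(2, math.isqrt(n) + 1):
        if n % p == 0:
            return p
    return n


def choose_chunk_shape(
    shape: Sequence[int], itemsize: int, chunk_byte_size: int
) -> Tuple[int, ...]:
    """Hypothetical greedy chooser: shrinks the largest dimension first."""
    chunk = list(shape)
    # Keep splitting until the chunk fits within the byte budget.
    while math.prod(chunk) * itemsize > chunk_byte_size:
        splittable = [i for i, d in enumerate(chunk) if d > 1]
        if not splittable:
            break  # Every dimension is already 1; cannot shrink further.
        largest = max(splittable, key=lambda i: chunk[i])
        chunk[largest] //= smallest_prime_factor(chunk[largest])
    return tuple(chunk)


# A (1024, 256) float32 array (4 bytes per element) with a 64 KiB budget.
print(choose_chunk_shape((1024, 256), itemsize=4, chunk_byte_size=64 * 1024))  # → (128, 128)
```

Splitting by the smallest prime factor keeps each chunk as large as possible at every step. This matches the "largest chunk shape possible" goal while still tiling the array exactly.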

__eq__(other)

Return self==value.

__hash__ = None

__init__(aggregate=False, dtype=None, chunk_byte_size=None, shard_axes=())

class orbax.checkpoint.type_handlers.RestoreArgs(restore_type=None, dtype=None)

Extra arguments that can be provided for restoration.

restore_type:

Specifies the object type of the restored parameter. The type must have a corresponding TypeHandler for restoration. Ignored if the parameter is restored from an aggregated checkpoint file.

dtype:

If provided, casts the parameter to the given dtype after restoring. Note that the parameter must be compatible with the given type (e.g. jnp.bfloat16 is not compatible with np.ndarray).

__eq__(other)

Return self==value.

__hash__ = None

__init__(restore_type=None, dtype=None)

class orbax.checkpoint.type_handlers.ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=None, mesh=None, mesh_axes=None, sharding=None, global_shape=None, shape=None, strict=True)

Arguments used when restoring with ArrayHandler.

restore_type:

See parent class.

mesh:

The device mesh that the array should be restored onto. Cannot be None unless sharding is provided.

mesh_axes:

The mesh_axes that the array should be restored with. Cannot be None unless sharding is provided.

sharding:

A jax.sharding.Sharding, ShardingMetadata, or Layout object. If provided, it takes precedence over mesh and mesh_axes. Otherwise, mesh and mesh_axes are used to construct either a NamedSharding object or a ShardingMetadata object (an Orbax representation of jax.sharding.Sharding that stores the same properties but does not require access to real devices).

global_shape:

The global shape that the array should be restored into. If not provided, the array is restored with the shape it was written with. Arbitrary shape transformations (for example, reshaping to different dimensions) are not currently supported, but padding and truncating are. If global_shape is larger than the shape of the saved array, zeros are appended. If global_shape is smaller than the shape of the saved array, excess elements are dropped from the end of the array.

shape:

Interchangeable with global_shape.

strict:

True by default. If True, enforces that the target global shape and the original global shape (as recorded by the saved array) are the same. If False, the returned array is silently truncated or padded to fit the target global shape as necessary.
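The padding, truncation, and strict rules above can be illustrated with a small stdlib-only sketch on nested Python lists. The helper names are hypothetical, and this is not how ArrayHandler restores data. The sketch only mirrors the documented semantics:

- The number of dimensions must match.
- Missing trailing elements become zeros.
- Excess trailing elements are dropped.
- strict=True rejects any shape mismatch.

```python
from typing import Any, Sequence, Tuple


def _shape(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list (inspects the first element per level)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def _zeros(shape: Tuple[int, ...]) -> Any:
    """Builds a zero-filled nested list of the given shape."""
    if not shape:
        return 0
    return [_zeros(shape[1:]) for _ in range(shape[0])]


def _fit(x: Any, shape: Tuple[int, ...]) -> Any:
    if not shape:
        return x
    # Truncate: keep at most shape[0] leading entries along this axis.
    kept = [_fit(item, shape[1:]) for item in x[: shape[0]]]
    # Pad: append zero-filled entries at the end of this axis.
    kept += [_zeros(shape[1:]) for _ in range(shape[0] - len(kept))]
    return kept


def fit_to_global_shape(saved: list, global_shape: Sequence[int], strict: bool = True) -> list:
    """Hypothetical sketch of the global_shape / strict semantics."""
    target = tuple(global_shape)
    source = _shape(saved)
    if len(source) != len(target):
        raise ValueError(f"Cannot change rank from {len(source)} to {len(target)}.")
    if strict and source != target:
        raise ValueError(f"Saved shape {source} != target shape {target}.")
    return _fit(saved, target)


print(fit_to_global_shape([[1, 2], [3, 4]], (3, 1), strict=False))  # → [[1], [3], [0]]
```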

restore_type

alias of Array

__eq__(other)

Return self==value.

__hash__ = None

__init__(restore_type=<class 'jax.Array'>, dtype=None, mesh=None, mesh_axes=None, sharding=None, global_shape=None, shape=None, strict=True)

TypeHandler

class orbax.checkpoint.type_handlers.TypeHandler

Interface for reading and writing a PyTree leaf.

abstractmethod typestr()

A string representation of the type.

Cannot conflict with other types.

Return type:

str

Returns:

The type as a string.

abstractmethod async metadata(infos)

Constructs object metadata from a stored parameter location.

Parameters:

infos (Sequence[ParamInfo]) – sequence of ParamInfo

Return type:

Sequence[Metadata]

Returns:

Sequence of Metadata for each provided ParamInfo.

abstractmethod async serialize(values, infos, args=None)

Writes the parameter to a storage location.

This method is responsible for copying the parameter from a remote device in a synchronous fashion (if applicable). It should then return a list of futures that can be awaited later to complete the final commit to a storage location.

Note: Any operation that writes to the storage location should use future.CommitFutureAwaitingContractedSignals to wait for the directories to be created.

The function can be used in a multihost setting, but should not implement extra logic to ensure atomicity.

Parameters:
  • values (Sequence[Any]) – a sequence of parameters to save.

  • infos (Sequence[ParamInfo]) – a sequence of ParamInfo containing relevant information for serialization of each value.

  • args (Optional[Sequence[SaveArgs]]) – a sequence of additional arguments for serialization, provided by the user.

Return type:

Sequence[Future]

Returns:

Sequence of commit futures which can be awaited to complete the save operation.
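The serialize contract above can be sketched with the standard library alone: copy the value synchronously, then return futures that finish the commit later. ToyParamInfo and ToyJsonHandler are hypothetical. They only mirror the documented method names and do not subclass TypeHandler. The sketch also skips the directory-creation signaling (CommitFutureAwaitingContractedSignals) that real Orbax handlers must use, and concurrent.futures.Future stands in for the futures Orbax expects.

```python
import asyncio
import copy
import json
import os
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence


@dataclass
class ToyParamInfo:
    """Hypothetical stand-in for ParamInfo: a parameter name and its directory."""

    name: str
    parent_dir: str


def _write_json(path: str, value: Any) -> None:
    with open(path, "w") as f:
        json.dump(value, f)


class ToyJsonHandler:
    """Hypothetical handler that mirrors TypeHandler's method names only."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=4)

    def typestr(self) -> str:
        return "toy_json"

    def _path(self, info: ToyParamInfo) -> str:
        return os.path.join(info.parent_dir, f"{info.name}.json")

    async def serialize(
        self,
        values: Sequence[Any],
        infos: Sequence[ToyParamInfo],
        args: Optional[Sequence[Any]] = None,
    ) -> List[Future]:
        futures = []
        for value, info in zip(values, infos):
            # Synchronous part: snapshot the value before returning, so later
            # mutations by the caller cannot change what is committed.
            snapshot = copy.deepcopy(value)
            # Deferred part: the write runs in the background; callers wait on
            # the returned future to complete the commit.
            futures.append(self._executor.submit(_write_json, self._path(info), snapshot))
        return futures

    async def deserialize(
        self, infos: Sequence[ToyParamInfo], args: Optional[Sequence[Any]] = None
    ) -> List[Any]:
        restored = []
        for info in infos:
            with open(self._path(info)) as f:
                restored.append(json.load(f))
        return restored


async def demo(directory: str) -> List[Any]:
    handler = ToyJsonHandler()
    values = [{"step": 1}, [1, 2, 3]]
    infos = [ToyParamInfo("a", directory), ToyParamInfo("b", directory)]
    futures = await handler.serialize(values, infos)
    values[0]["step"] = 99  # Not persisted: serialize() already took a snapshot.
    for future in futures:
        future.result()  # Block until each write has been committed.
    return await handler.deserialize(infos)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(asyncio.run(demo(tmp)))  # → [{'step': 1}, [1, 2, 3]]
```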

abstractmethod async deserialize(infos, args=None)

Reads the parameter from a storage location.

Parameters:
  • infos (Sequence[ParamInfo]) – Sequence of ParamInfo for deserialization.

  • args (Optional[Sequence[RestoreArgs]]) – Sequence of user-provided restoration information.

Return type:

Sequence[Any]

Returns:

The deserialized parameters.

finalize(directory)

Performs any logic to finalize parameter files written by this class.

By default, does nothing.

Parameters:

directory (Path) – A path to the location of the checkpoint. This corresponds to param_info.parent_dir.

memory_size(values)

For a batch of values, returns the size of each value in bytes.

Note that the default implementation uses sys.getsizeof, which is not likely to be accurate for many types.

The value returned is intended to be per-host.

Parameters:

values (Sequence[Any]) – A batch of values.

Return type:

Sequence[Tuple[int, int]]

Returns:

A sequence of elements corresponding to values. Each element is a tuple of [write_size, read_size]. In many cases these values may be the same.

Raises:

NotImplementedError – Raises error by default since we will rely on a backup implementation.
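The (write_size, read_size) contract can be illustrated with a handler for a hypothetical leaf type such as Python's array.array. Such a handler could report exact payload sizes instead of relying on sys.getsizeof, which also counts interpreter object overhead. This is a sketch, not an Orbax handler. It assumes the full value is both written and read on the host, so both sizes are equal.

```python
import sys
from array import array
from typing import List, Sequence, Tuple


def memory_size(values: Sequence[array]) -> List[Tuple[int, int]]:
    """Returns (write_size, read_size) in bytes for each array.array value."""
    sizes = []
    for value in values:
        nbytes = len(value) * value.itemsize  # Payload bytes only.
        sizes.append((nbytes, nbytes))  # Whole value is written and read back.
    return sizes


values = [array("d", [1.0, 2.0, 3.0]), array("b", [1, 2])]
print(memory_size(values))  # → [(24, 24), (2, 2)]
# sys.getsizeof also counts the object header, so it overestimates the payload.
print(sys.getsizeof(values[0]) > 24)
```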

NumpyHandler

class orbax.checkpoint.type_handlers.NumpyHandler(metadata_key=None, ocdbt_process_id=None)

Provides an implementation of TypeHandler for replicated numpy arrays.

__init__(metadata_key=None, ocdbt_process_id=None)

Constructor.

Parameters:
  • metadata_key (Optional[str]) – name to give to Tensorstore metadata files.

  • ocdbt_process_id (Optional[str]) – The process ID used by single-controller systems to write in OCDBT format. Checkpoints are written to a subdirectory with this name to avoid collisions with the subdirectory names used by other host processes managed by this controller.

typestr()

A string representation of the type.

Cannot conflict with other types.

Return type:

str

Returns:

The type as a string.

async metadata(infos)

Constructs object metadata from a stored parameter location.

Parameters:

infos (Sequence[ParamInfo]) – sequence of ParamInfo

Return type:

Sequence[ArrayMetadata]

Returns:

Sequence of Metadata for each provided ParamInfo.

async serialize(values, infos, args=None)

Uses Tensorstore to serialize a numpy array.

Return type:

Sequence[Future]

async deserialize(infos, args=None)

Deserializes the array using Tensorstore.

Return type:

Sequence[ndarray]

memory_size(values)

For a batch of values, returns the size of each value in bytes.

Note that the default implementation uses sys.getsizeof, which is not likely to be accurate for many types.

The value returned is intended to be per-host.

Parameters:

values (Sequence[ndarray]) – A batch of values.

Return type:

Sequence[Tuple[int, int]]

Returns:

A sequence of elements corresponding to values. Each element is a tuple of [write_size, read_size]. In many cases these values may be the same.

Raises:

NotImplementedError – Raises error by default since we will rely on a backup implementation.

ScalarHandler

class orbax.checkpoint.type_handlers.ScalarHandler(metadata_key=None, ocdbt_process_id=None)

A wrapper around NumpyHandler to deal with scalar types (int, float, etc.).

typestr()

A string representation of the type.

Cannot conflict with other types.

Return type:

str

Returns:

The type as a string.

async metadata(infos)

Constructs object metadata from a stored parameter location.

Parameters:

infos (Sequence[ParamInfo]) – sequence of ParamInfo

Return type:

Sequence[ScalarMetadata]

Returns:

Sequence of Metadata for each provided ParamInfo.

async serialize(values, infos, args=None)

See superclass documentation.

Return type:

Sequence[Future]

async deserialize(infos, args=None)

See superclass documentation.

Return type:

Sequence[Union[int, float, number]]

memory_size(values)

For a batch of values, returns the size of each value in bytes.

Note that the default implementation uses sys.getsizeof, which is not likely to be accurate for many types.

The value returned is intended to be per-host.

Parameters:

values (Sequence[Union[int, float, number]]) – A batch of values.

Return type:

Sequence[Tuple[int, int]]

Returns:

A sequence of elements corresponding to values. Each element is a tuple of [write_size, read_size]. In many cases these values may be the same.

Raises:

NotImplementedError – Raises error by default since we will rely on a backup implementation.

ArrayHandler

class orbax.checkpoint.type_handlers.ArrayHandler(metadata_key=None, primary_host=0, replica_id=0, use_replica_parallel=None, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_write_sharding_file=True, array_metadata_store=None, enable_replica_parallel_separate_folder=False, dispatcher=None)

An implementation of TypeHandler for jax.Array.

__init__(metadata_key=None, primary_host=0, replica_id=0, use_replica_parallel=None, min_slice_bytes_for_replica_parallel=None, max_replicas_for_replica_parallel=None, enable_write_sharding_file=True, array_metadata_store=None, enable_replica_parallel_separate_folder=False, dispatcher=None)

Constructor.

Parameters:
  • metadata_key (Optional[str]) – name to give to Tensorstore metadata files.

  • primary_host (Optional[int]) – the host ID of the primary host. Defaults to 0. If set to None, all hosts are considered primary, which is useful when every host works only with local storage.

  • replica_id (Optional[int]) – the replica ID to use for saving. Defaults to 0. If set to None, each shard picks its first replica_id. This is useful when every host works only with local storage.

  • use_replica_parallel (Optional[bool]) – Whether to parallelize saving across replicas.

  • min_slice_bytes_for_replica_parallel (Optional[int]) – Minimum number of bytes per replica slice. Replica-parallel saving is used only when the amount of data written per replica is greater than or equal to this number.

  • max_replicas_for_replica_parallel (Optional[int]) – Maximum number of replicas over which saving is parallelized if use_replica_parallel is True.

  • enable_write_sharding_file (bool) – whether to write the sharding file. Defaults to True.

  • array_metadata_store (Optional[Store]) – Store that manages per-host ArrayMetadata. To disable ArrayMetadata persistence, set it to None.

  • enable_replica_parallel_separate_folder (bool) – If True, save replica and sharded arrays in separate folders when use_replica_parallel is active.

  • dispatcher (Optional[Dispatcher]) – The dispatcher to use for executing operations on the workers.
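The sketch below shows one plausible reading of how use_replica_parallel, min_slice_bytes_for_replica_parallel, and max_replicas_for_replica_parallel interact, using only the standard library. The function names are hypothetical. The even, contiguous row split is an assumption for illustration, and ArrayHandler's actual slicing logic may differ.

```python
from typing import List, Optional, Tuple


def replicas_to_use(
    shard_bytes: int,
    num_replicas: int,
    min_slice_bytes: Optional[int] = None,
    max_replicas: Optional[int] = None,
) -> int:
    """Hypothetical: number of replicas that share the write of one shard."""
    n = num_replicas if max_replicas is None else min(num_replicas, max_replicas)
    if n <= 1:
        return 1
    # Fall back to a single writer if each replica's slice would be too small.
    if min_slice_bytes is not None and shard_bytes / n < min_slice_bytes:
        return 1
    return n


def split_rows(num_rows: int, n: int) -> List[Tuple[int, int]]:
    """Splits [0, num_rows) into n contiguous, near-equal [start, stop) ranges."""
    base, extra = divmod(num_rows, n)
    bounds, start = [], 0
    for i in range(n):
        stop = start + base + (1 if i < extra else 0)
        bounds.append((start, stop))
        start = stop
    return bounds


n = replicas_to_use(shard_bytes=1_000_000, num_replicas=4, min_slice_bytes=100_000)
print(n, split_rows(10, n))  # → 4 [(0, 3), (3, 6), (6, 8), (8, 10)]
```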

typestr()

A string representation of the type.

Cannot conflict with other types.

Return type:

str

Returns:

The type as a string.

async metadata(infos)

Constructs object metadata from a stored parameter location.

Parameters:

infos (Sequence[ParamInfo]) – sequence of ParamInfo

Return type:

Sequence[ArrayMetadata]

Returns:

Sequence of Metadata for each provided ParamInfo.

async serialize(values, infos, args=None)

See superclass documentation.

Return type:

Sequence[Future]

async deserialize(infos, args=None)

See superclass documentation.

Parameters:
  • infos (Sequence[ParamInfo]) – ParamInfo.

  • args (Optional[Sequence[RestoreArgs]]) – must be of type ArrayRestoreArgs.

Return type:

Sequence[Array]

Returns:

The deserialized parameters.

Raises:
  • ValueError if args is not provided.

  • ValueError if neither args.sharding nor both args.mesh and args.mesh_axes are provided.
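The argument checks listed under Raises, together with the precedence rule documented for ArrayRestoreArgs.sharding, amount to the resolution order sketched below. ToyRestoreArgs and resolve_target_sharding are hypothetical stdlib stand-ins. The returned tuple is only a placeholder for a NamedSharding built from mesh and mesh_axes.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyRestoreArgs:
    """Hypothetical stand-in for the sharding-related ArrayRestoreArgs fields."""

    sharding: Any = None
    mesh: Any = None
    mesh_axes: Any = None


def resolve_target_sharding(args: Optional[ToyRestoreArgs]) -> Any:
    if args is None:
        raise ValueError("Restore args must be provided.")
    if args.sharding is not None:
        return args.sharding  # An explicit sharding takes precedence.
    if args.mesh is not None and args.mesh_axes is not None:
        # Placeholder for constructing a NamedSharding from mesh and mesh_axes.
        return ("NamedSharding", args.mesh, args.mesh_axes)
    raise ValueError("Provide sharding, or both mesh and mesh_axes.")


print(resolve_target_sharding(ToyRestoreArgs(mesh="mesh", mesh_axes=("x",))))
```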

memory_size(values)

For a batch of values, returns the size of each value in bytes.

Note that the default implementation uses sys.getsizeof, which is not likely to be accurate for many types.

The value returned is intended to be per-host.

Parameters:

values (Sequence[Array]) – A batch of values.

Return type:

Sequence[Tuple[int, int]]

Returns:

A sequence of elements corresponding to values. Each element is a tuple of [write_size, read_size]. In many cases these values may be the same.

Raises:

NotImplementedError – Raises error by default since we will rely on a backup implementation.

SingleReplicaArrayHandler

class orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(replica_axis_index=0, primary_replica_id=0, broadcast_memory_limit_bytes=None, broadcast_memory_scaling_factor=0.75, **kwargs)

A TypeHandler for jax.Array, built on ArrayHandler, that optimizes checkpoint read performance during multi-pod or multihost training.

Normally, each host reads the relevant data from the checkpoint, even if other hosts are reading exactly the same data. This can be very inefficient with a large number of pods/hosts and a large checkpoint. With SingleReplicaArrayHandler, the data is read only on hosts in the primary replica, and these hosts then broadcast the data to the other hosts. It is assumed that every host has ALL of its devices either inside the primary replica or outside it. Consider, for example, the following sharding on a v4-128, which has 16 hosts and 64 devices:

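import jax
import numpy as np

shape = (32, 2)
# jax.devices() returns a list, so convert it to an array before reshaping.
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(shape), ('x', 'y'))
pspec = jax.sharding.PartitionSpec(None, 'y')
sharding = jax.sharding.NamedSharding(mesh, pspec)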

This sharding will not work: the primary replica has only two devices, so some host has 2 devices inside the primary replica and 2 devices outside it. Changing the shape to, for example, (4, 16) results in a valid sharding.
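The host-alignment requirement above can be checked with a small stdlib sketch. The helpers are hypothetical and only handle 2-D meshes. They assume that device IDs are laid out row-major in the mesh and that consecutive IDs belong to the same host (4 devices per v4 host); real jax.devices() ordering may differ. Under these assumptions, the primary replica is the mesh slice at primary_replica_id along replica_axis_index.

```python
from typing import Dict, List, Set, Tuple


def build_mesh(num_devices: int, shape: Tuple[int, int]) -> List[List[int]]:
    """Row-major 2-D mesh of device IDs, like reshaping a flat device list."""
    rows, cols = shape
    if rows * cols != num_devices:
        raise ValueError("Mesh shape must use every device exactly once.")
    return [[r * cols + c for c in range(cols)] for r in range(rows)]


def primary_replica(
    mesh: List[List[int]], replica_axis_index: int = 0, primary_replica_id: int = 0
) -> Set[int]:
    """Devices at index primary_replica_id along the replica axis."""
    if replica_axis_index == 0:
        return set(mesh[primary_replica_id])
    return {row[primary_replica_id] for row in mesh}


def is_valid_for_single_replica(
    mesh: List[List[int]],
    devices_per_host: int,
    replica_axis_index: int = 0,
    primary_replica_id: int = 0,
) -> bool:
    primary = primary_replica(mesh, replica_axis_index, primary_replica_id)
    hosts: Dict[int, Set[int]] = {}
    for row in mesh:
        for device in row:
            # Assumption: consecutive device IDs share a host.
            hosts.setdefault(device // devices_per_host, set()).add(device)
    # Each host must be entirely inside or entirely outside the primary replica.
    return all(devs <= primary or devs.isdisjoint(primary) for devs in hosts.values())


# v4-128: 16 hosts x 4 devices = 64 devices.
print(is_valid_for_single_replica(build_mesh(64, (32, 2)), devices_per_host=4))  # → False
print(is_valid_for_single_replica(build_mesh(64, (4, 16)), devices_per_host=4))  # → True
```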

This TypeHandler can be registered by running:

import jax
import orbax.checkpoint as ocp

ocp.type_handlers.register_type_handler(
    jax.Array,
    ocp.type_handlers.SingleReplicaArrayHandler(),
    override=True)

Example usage can be found in MaxText (google/maxtext; TO BE MERGED).

__init__(replica_axis_index=0, primary_replica_id=0, broadcast_memory_limit_bytes=None, broadcast_memory_scaling_factor=0.75, **kwargs)

Constructor.

Parameters:
  • replica_axis_index (Optional[int]) – Defines the axis of the global mesh along which replicas are defined. For example, all devices in global_mesh.devices[replica_axis_index] are part of the same replica.

  • primary_replica_id (Optional[int]) – The ID of the replica that is used to load and broadcast the checkpoint.

  • broadcast_memory_limit_bytes (Optional[int]) – Specifies the memory size (in bytes) used for broadcasting data.

  • broadcast_memory_scaling_factor (Optional[float]) – Specifies the fraction of available memory to use for broadcasting data.

  • **kwargs – Look at ArrayHandler for documentation of other arguments.

async deserialize(infos, args=None)

Deserializes parameters using single-replica broadcasting.

Parameters:
  • infos (Sequence[ParamInfo]) – ParamInfo.

  • args (Optional[Sequence[SingleReplicaArrayRestoreArgs]]) – must be of type SingleReplicaArrayRestoreArgs.

Return type:

Sequence[Array]

Returns:

Deserialized parameters.

Raises:
  • ValueError if args is not provided.

  • ValueError if neither args.sharding nor both args.mesh and args.mesh_axes are provided.

memory_size(values)

For a batch of values, returns the size of each value in bytes.

Note that the default implementation uses sys.getsizeof, which is not likely to be accurate for many types.

The value returned is intended to be per-host.

Parameters:

values (Sequence[Array]) – A batch of values.

Return type:

Sequence[Tuple[int, int]]

Returns:

A sequence of elements corresponding to values. Each element is a tuple of [write_size, read_size]. In many cases these values may be the same.

Raises:

NotImplementedError – Raises error by default since we will rely on a backup implementation.

StringHandler

class orbax.checkpoint.type_handlers.StringHandler(filename=None)

TypeHandler for strings.

__init__(filename=None)

typestr()

A string representation of the type.

Cannot conflict with other types.

Return type:

str

Returns:

The type as a string.

async metadata(infos)

Constructs object metadata from a stored parameter location.

Parameters:

infos (Sequence[ParamInfo]) – sequence of ParamInfo

Return type:

Sequence[StringMetadata]

Returns:

Sequence of Metadata for each provided ParamInfo.

async serialize(values, infos, args=None)

See superclass documentation.

Return type:

Sequence[Future]

async deserialize(infos, args=None)

See superclass documentation.

Return type:

Sequence[Optional[str]]

memory_size(values)

For a batch of values, returns the size of each value in bytes.

Note that the default implementation uses sys.getsizeof, which is not likely to be accurate for many types.

The value returned is intended to be per-host.

Parameters:

values (Sequence[str]) – A batch of values.

Return type:

Sequence[Tuple[int, int]]

Returns:

A sequence of elements corresponding to values. Each element is a tuple of [write_size, read_size]. In many cases these values may be the same.

Raises:

NotImplementedError – Raises error by default since we will rely on a backup implementation.

Tensorstore functions

orbax.checkpoint.type_handlers.is_ocdbt_checkpoint(path)

Determines whether a checkpoint uses OCDBT format.

Return type:

bool

async orbax.checkpoint.type_handlers.merge_ocdbt_per_process_files(directory, ts_context, use_zarr3, enable_validation=True)

Merges OCDBT files written to per-process subdirectories.

With Tensorstore’s OCDBT format, arrays are initially written to per-process subdirectories, depending on which host is doing the writing. This function can be called to merge the per-process files into a global key-value store.

The original per-process subdirectories are not deleted and should not be deleted, because the global kvstore continues to reference them.

NOTE: If no suitable subdirectories with OCDBT checkpoints are found, this function does not raise an error, and no merged checkpoint is created.

Parameters:
  • directory (Path) – checkpoint location.

  • ts_context (Context) – Tensorstore context.

  • use_zarr3 (bool) – If True, use the zarr3 driver for params validation; otherwise, use the zarr driver.

  • enable_validation (bool) – If True, validate params after merging. May have a performance impact.

orbax.checkpoint.type_handlers.get_json_tspec_write(info, use_ocdbt, global_shape, local_shape, dtype, process_index=None, metadata_key=None, arg=None)

Gets Tensorstore spec for writing.

Return type:

dict[str, Any]

orbax.checkpoint.type_handlers.get_json_tspec_read(info, use_ocdbt, metadata_key=None, raise_array_data_missing_error=True)

Gets Tensorstore spec for reading.

Return type:

dict[str, Any]

orbax.checkpoint.type_handlers.get_ts_context(*, use_ocdbt=True, file_io_concurrency_limit=None, data_copy_concurrency_limit=None)

Creates a TensorStore context object.

For use with Orbax serialization APIs, or when directly opening a TensorStore object.

Parameters:
  • use_ocdbt (bool) – Whether to use the OCDBT driver. Adds options specific to OCDBT if True.

  • file_io_concurrency_limit (Optional[int]) – Optionally overrides the thread pool size for file I/O.

  • data_copy_concurrency_limit (Optional[int]) – Optionally overrides the thread pool size for compressing and copying data.

Return type:

Context

Returns:

A TensorStore context object.
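When opening a TensorStore object directly, the two optional limits most likely correspond to a context spec like the one built below. This is an assumption: it supposes that Orbax maps them onto TensorStore's standard file_io_concurrency and data_copy_concurrency context resources. The helper name is hypothetical, and the sketch omits the OCDBT-specific options that get_ts_context adds when use_ocdbt is True.

```python
import json
from typing import Any, Dict, Optional


def sketch_context_spec(
    file_io_concurrency_limit: Optional[int] = None,
    data_copy_concurrency_limit: Optional[int] = None,
) -> Dict[str, Any]:
    """Hypothetical: builds a TensorStore context spec with thread-pool limits."""
    spec: Dict[str, Any] = {}
    if file_io_concurrency_limit is not None:
        spec["file_io_concurrency"] = {"limit": file_io_concurrency_limit}
    if data_copy_concurrency_limit is not None:
        spec["data_copy_concurrency"] = {"limit": data_copy_concurrency_limit}
    return spec


print(json.dumps(sketch_context_spec(file_io_concurrency_limit=64)))  # → {"file_io_concurrency": {"limit": 64}}
```

With the tensorstore package installed, a dict like this can be passed to tensorstore.Context(...).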

orbax.checkpoint.type_handlers.get_cast_tspec_serialize(tspec, value, args)

Creates a Tensorstore spec for casting a param during serialize.

Return type:

dict[str, Any]

orbax.checkpoint.type_handlers.get_cast_tspec_deserialize(tspec, args)

Creates a Tensorstore spec for casting a param during deserialize.

Return type:

dict[str, Any]
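Both cast helpers return a modified Tensorstore spec. One common way to express such a cast is TensorStore's cast driver: the base spec describes the dtype stored on disk, and the outer dtype is the dtype seen in memory. This is shown here as an assumption, not as what these functions actually return. The helper names below are hypothetical.

```python
import copy
from typing import Any, Dict


def wrap_with_cast(base_spec: Dict[str, Any], view_dtype: str) -> Dict[str, Any]:
    """Hypothetical: presents base_spec's stored data as view_dtype."""
    return {"driver": "cast", "dtype": view_dtype, "base": copy.deepcopy(base_spec)}


def sketch_serialize_cast(
    base_spec: Dict[str, Any], value_dtype: str, save_dtype: str
) -> Dict[str, Any]:
    """Stores the array as save_dtype while writing values of value_dtype."""
    stored = dict(base_spec, dtype=save_dtype)
    return wrap_with_cast(stored, view_dtype=value_dtype)


def sketch_deserialize_cast(base_spec: Dict[str, Any], restore_dtype: str) -> Dict[str, Any]:
    """Reads the stored array and presents it as restore_dtype."""
    return wrap_with_cast(base_spec, view_dtype=restore_dtype)


base = {"driver": "zarr", "kvstore": {"driver": "file", "path": "/tmp/ckpt/param"}}
print(sketch_serialize_cast(base, value_dtype="float32", save_dtype="bfloat16"))
```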

TypeHandler registry

class orbax.checkpoint.type_handlers.TypeHandlerRegistry(*args, **kwargs)

A registry for TypeHandlers.

This internal base class is used for the global registry, which serves as a default for any type not found in a local registry. It is also accessed through the module functions get/set/has_type_handler.

orbax.checkpoint.type_handlers.create_type_handler_registry(*handlers)

Creates a type registry.

Parameters:

*handlers – optional pairs of (<type>, <handler>) to initialize the registry with.

Return type:

TypeHandlerRegistry

Returns:

A TypeHandlerRegistry instance with only the specified handlers.

orbax.checkpoint.type_handlers.register_type_handler(ty, handler, func=None, override=False)

Registers a type for serialization/deserialization with a given handler.

Note that it is possible for a type to match multiple different entries in the registry, each with a different handler. In this case, only the first match is used.

Parameters:
  • ty (Any) – A type to register.

  • handler (TypeHandler) – a TypeHandler capable of reading and writing parameters of type ty.

  • func (Optional[Callable[[Any], bool]]) – A function that accepts a type and returns True if the type should be handled by the provided TypeHandler. If this parameter is not specified, defaults to lambda t: issubclass(t, ty).

  • override (bool) – if True, will override an existing mapping of type to handler.

Raises:

ValueError if a type is already registered and override is False.
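A toy stdlib registry shows the first-match rule and the default issubclass predicate in action. ToyTypeHandlerRegistry is hypothetical and only mirrors the documented behavior of register_type_handler and get_type_handler. In particular, treating "already registered" as an exact type match is an assumption. Because bool is a subclass of int, a handler registered for int also captures bool, and registering bool afterwards does not change which handler wins.

```python
from typing import Any, Callable, List, Optional, Tuple


class ToyTypeHandlerRegistry:
    """Hypothetical registry: ordered (type, predicate, handler) entries."""

    def __init__(self, *handlers: Tuple[type, Any]) -> None:
        self._entries: List[Tuple[type, Callable[[type], bool], Any]] = []
        for ty, handler in handlers:
            self.add(ty, handler)

    def add(
        self,
        ty: type,
        handler: Any,
        func: Optional[Callable[[type], bool]] = None,
        override: bool = False,
    ) -> None:
        if func is None:
            func = lambda t: issubclass(t, ty)  # The documented default predicate.
        for i, (registered, _, _) in enumerate(self._entries):
            if registered is ty:
                if not override:
                    raise ValueError(f"{ty} is already registered.")
                self._entries[i] = (ty, func, handler)  # Replace in place.
                return
        self._entries.append((ty, func, handler))

    def get(self, ty: type) -> Any:
        for _, func, handler in self._entries:
            if func(ty):
                return handler  # Only the first match is used.
        raise ValueError(f"No handler registered for {ty}.")

    def has(self, ty: type) -> bool:
        return any(func(ty) for _, func, _ in self._entries)


registry = ToyTypeHandlerRegistry((int, "int_handler"), (str, "str_handler"))
registry.add(bool, "bool_handler")
print(registry.get(bool))  # → int_handler
```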

orbax.checkpoint.type_handlers.get_type_handler(ty)

Returns the handler registered for a given type, if available.

Return type:

TypeHandler

orbax.checkpoint.type_handlers.has_type_handler(ty)

Returns whether a handler is registered for a given type.

Return type:

bool

orbax.checkpoint.type_handlers.register_standard_handlers_with_options(**kwargs)

Re-registers a select set of handlers with the given options.

This is intended to override options en masse for the standard numeric TypeHandlers and their corresponding types (scalars, numpy arrays and jax.Arrays).

Parameters:

**kwargs – keyword arguments to pass to each of the standard handlers.