TypeHandlers

Provides utilities for PyTreeCheckpointHandler.

Arguments for PyTreeCheckpointHandler

class orbax.checkpoint.type_handlers.SaveArgs(aggregate=False, dtype=None, write_chunk_shape=None, read_chunk_shape=None, chunk_byte_size=None)

Extra arguments that can be provided for saving.

aggregate:

Deprecated. Please use a custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact the Orbax team to migrate before August 1st, 2024. If True, saves the given parameter in an aggregated tree format rather than individually. See AggregateHandler.

dtype:

If provided, casts the parameter to the given dtype before saving. Note that the parameter must be compatible with the given type (e.g. jnp.bfloat16 is not compatible with np.ndarray).

write_chunk_shape:

Applies only to Zarr version 3. Specifies the shape of a shard used for writing. The default (None) is the array shard size, so the number of write chunks equals the number of shards. write_chunk_shape must be a divisor of the array shape.

read_chunk_shape:

Applies only to Zarr version 3. Specifies the chunk sizes within a write chunk. The default (None) is write_chunk_shape. read_chunk_shape must be a divisor of write_chunk_shape.
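The defaults and divisibility rules for the two chunk-shape options can be summarized in a few lines. This is a minimal pure-Python sketch, assuming that "divisor of a shape" means divisibility dimension by dimension; resolve_chunk_shapes and its shard_shape argument are hypothetical and not part of the Orbax API.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def resolve_chunk_shapes(
    array_shape: Sequence[int],
    shard_shape: Sequence[int],
    write_chunk_shape: Optional[Sequence[int]] = None,
    read_chunk_shape: Optional[Sequence[int]] = None,
) -> Tuple[Shape, Shape]:
    """Hypothetical helper: applies the documented defaults and divisibility rules."""
    # Default: one write chunk per shard.
    write = tuple(write_chunk_shape) if write_chunk_shape is not None else tuple(shard_shape)
    # Default: read chunks are the same as write chunks.
    read = tuple(read_chunk_shape) if read_chunk_shape is not None else write

    if len(write) != len(array_shape) or len(read) != len(write):
        raise ValueError("Chunk shapes must have the same rank as the array.")
    # Assumed: each write-chunk dimension must evenly divide the array dimension.
    if any(a % w for a, w in zip(array_shape, write)):
        raise ValueError(f"write_chunk_shape {write} must divide array shape {tuple(array_shape)}.")
    # Assumed: each read-chunk dimension must evenly divide the write-chunk dimension.
    if any(w % r for w, r in zip(write, read)):
        raise ValueError(f"read_chunk_shape {read} must divide write_chunk_shape {write}.")
    return write, read
```

For an (8, 8) array with (4, 8) shards, leaving both options unset yields ((4, 8), (4, 8)), while read_chunk_shape=(2, 4) yields ((4, 8), (2, 4)).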

chunk_byte_size:

This is an experimental feature that automatically chooses the largest possible chunk shape while keeping the chunk byte size less than or equal to chunk_byte_size. Both write_chunk_shape and read_chunk_shape are set to the chosen shape. The choice uses a greedy algorithm that splits the largest dimensions first. To enable this feature, both write_chunk_shape and read_chunk_shape must be set to None.

__eq__(other)

Return self==value.

__hash__ = None
__init__(aggregate=False, dtype=None, write_chunk_shape=None, read_chunk_shape=None, chunk_byte_size=None)

class orbax.checkpoint.type_handlers.RestoreArgs(restore_type=None, dtype=None)

Extra arguments that can be provided for restoration.

restore_type:

Specifies the object type of the restored parameter. The type must have a corresponding TypeHandler for restoration. Ignored if the parameter is restored from an aggregated checkpoint file.

dtype:

If provided, casts the parameter to the given dtype after restoring. Note that the parameter must be compatible with the given type (e.g. jnp.bfloat16 is not compatible with np.ndarray).

__eq__(other)

Return self==value.

__hash__ = None
__init__(restore_type=None, dtype=None)

class orbax.checkpoint.type_handlers.ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=None, mesh=None, mesh_axes=None, sharding=None, global_shape=None)

Arguments used when restoring with ArrayHandler.

restore_type:

See parent class.

mesh:

The device mesh onto which the array should be restored. Cannot be None unless sharding is provided.

mesh_axes:

The mesh_axes with which the array should be restored. Cannot be None unless sharding is provided.

sharding:

A jax.sharding.Sharding object, which takes precedence over mesh and mesh_axes if provided. Otherwise, mesh and mesh_axes are used to construct a NamedSharding object or a ShardingMetadata, which is an Orbax representation of jax.sharding.Sharding that stores the same properties but does not require access to real devices.

global_shape:

The global shape into which the array should be restored. If not provided, the array is restored with the shape it was written with. Arbitrary shape transformations (for example, reshaping to different dimensions) are not currently supported, but padding and truncation are. When global_shape is larger than the shape of the saved array, zeros are appended. When global_shape is smaller, excess elements are dropped from the end of the array.

restore_type

alias of Array

__eq__(other)

Return self==value.

__hash__ = None
__init__(restore_type=<class 'jax.Array'>, dtype=None, mesh=None, mesh_axes=None, sharding=None, global_shape=None)

TypeHandler

class orbax.checkpoint.type_handlers.TypeHandler

Interface for reading and writing a PyTree leaf.

abstract typestr()

A string representation of the type.

Cannot conflict with other types.

Return type:

str

Returns:

The type as a string.

abstract async metadata(infos)

Constructs object metadata from a stored parameter location.

Parameters:

infos (Sequence[ParamInfo]) – sequence of ParamInfo

Return type:

Sequence[Metadata]

Returns:

Sequence of Metadata for each provided ParamInfo.

abstract async serialize(values, infos, args=None)

Writes the parameter to a storage location.

This method is responsible for copying the parameter from a remote device synchronously (if applicable). It should then return a list of futures that can later be awaited to complete the final commit to a storage location.

The function can be used in a multihost setting, but should not implement extra logic to ensure atomicity.

Parameters:
  • values (Sequence[Any]) – a sequence of parameters to save.

  • infos (Sequence[ParamInfo]) – a sequence of ParamInfo containing relevant information for serialization of each value.

  • args (Optional[Sequence[SaveArgs]]) – a sequence of additional arguments for serialization, provided by the user.

Return type:

Sequence[Future]

Returns:

Sequence of commit futures which can be awaited to complete the save operation.

abstract async deserialize(infos, args=None)

Reads the parameter from a storage location.

Parameters:
  • infos (Sequence[ParamInfo]) – Sequence of ParamInfo for deserialization.

  • args (Optional[Sequence[RestoreArgs]]) – Sequence of user-provided restoration information.

Return type:

Sequence[Any]

Returns:

The deserialized parameters.

finalize(directory)

Performs any logic to finalize parameter files written by this class.

By default, does nothing.

Parameters:

directory (Path) – A path to the location of the checkpoint. This corresponds to param_info.parent_dir.
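The serialize/deserialize contract above (copy the data synchronously inside serialize, return futures that perform the slow commit, then read the data back in deserialize) can be illustrated with a self-contained toy. It does not subclass the real TypeHandler: ToyParamInfo is a hypothetical stand-in for ParamInfo, concurrent.futures.Future stands in for Orbax's Future type, and JSON files stand in for Tensorstore.

```python
import asyncio
import dataclasses
import json
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Any, List, Optional, Sequence


@dataclasses.dataclass
class ToyParamInfo:
    """Hypothetical stand-in for ParamInfo: names one leaf and its checkpoint directory."""
    name: str
    parent_dir: Path


class ToyJsonHandler:
    """Toy handler following the TypeHandler method layout for JSON-serializable leaves."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=2)

    def typestr(self) -> str:
        return "toy_json"

    def _path(self, info: ToyParamInfo) -> Path:
        return info.parent_dir / f"{info.name}.json"

    async def serialize(
        self,
        values: Sequence[Any],
        infos: Sequence[ToyParamInfo],
        args: Optional[Sequence[Any]] = None,
    ) -> List[Future]:
        futures = []
        for value, info in zip(values, infos):
            # Synchronous copy: later mutation of `value` cannot change what is saved.
            snapshot = json.dumps(value)
            # The slow write runs in the background; the caller awaits the future later.
            futures.append(self._executor.submit(self._path(info).write_text, snapshot))
        return futures

    async def deserialize(
        self, infos: Sequence[ToyParamInfo], args: Optional[Sequence[Any]] = None
    ) -> List[Any]:
        return [json.loads(self._path(info).read_text()) for info in infos]


async def roundtrip(directory: Path) -> List[Any]:
    """Saves two leaves, mutates one afterwards, waits for the commits, and reads back."""
    handler = ToyJsonHandler()
    infos = [ToyParamInfo("a", directory), ToyParamInfo("b", directory)]
    values = [{"x": 1}, {"y": [2, 3]}]
    futures = await handler.serialize(values, infos)
    values[0]["x"] = 99  # Does not affect the checkpoint: the data was already copied.
    # Wait for every background commit to finish.
    await asyncio.gather(*(asyncio.wrap_future(f) for f in futures))
    return await handler.deserialize(infos)
```

Running asyncio.run(roundtrip(Path(some_empty_dir))) returns [{'x': 1}, {'y': [2, 3]}], showing that the value captured by serialize is the one that gets committed.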

NumpyHandler

class orbax.checkpoint.type_handlers.NumpyHandler(metadata_key=None)

Provides an implementation of TypeHandler for replicated numpy arrays.

__init__(metadata_key=None)

Constructor.

Parameters:

metadata_key (Optional[str]) – name to give to Tensorstore metadata files.

typestr()

A string representation of the type.

Cannot conflict with other types.

Return type:

str

Returns:

The type as a string.

async metadata(infos)

Constructs object metadata from a stored parameter location.

Parameters:

infos (Sequence[ParamInfo]) – sequence of ParamInfo

Return type:

Sequence[ArrayMetadata]

Returns:

Sequence of Metadata for each provided ParamInfo.

async serialize(values, infos, args=None)

Uses Tensorstore to serialize a numpy array.

Return type:

Sequence[Future]

async deserialize(infos, args=None)

Deserializes the array using Tensorstore.

Return type:

Sequence[ndarray]

ScalarHandler

class orbax.checkpoint.type_handlers.ScalarHandler(metadata_key=None)

A wrapper around NumpyHandler to deal with scalar types (int, float, etc.).

typestr()

A string representation of the type.

Cannot conflict with other types.

Return type:

str

Returns:

The type as a string.

async metadata(infos)

Constructs object metadata from a stored parameter location.

Parameters:

infos (Sequence[ParamInfo]) – sequence of ParamInfo

Return type:

Sequence[ScalarMetadata]

Returns:

Sequence of Metadata for each provided ParamInfo.

async serialize(values, infos, args=None)

See superclass documentation.

Return type:

Sequence[Future]

async deserialize(infos, args=None)

See superclass documentation.

Return type:

Sequence[Union[int, float, number]]

ArrayHandler

class orbax.checkpoint.type_handlers.ArrayHandler(metadata_key=None, primary_host=0, replica_id=0)

An implementation of TypeHandler for jax.Array.

__init__(metadata_key=None, primary_host=0, replica_id=0)

Constructor.

Parameters:
  • metadata_key (Optional[str]) – name to give to Tensorstore metadata files.

  • primary_host (Optional[int]) – the host id of the primary host. Defaults to 0. If set to None, all hosts are considered primary. This is useful when all hosts work only with local storage.

  • replica_id (Optional[int]) – the replica id to use for saving. Defaults to 0. If set to None, each shard picks the first available replica_id. This is useful when all hosts work only with local storage.

typestr()

A string representation of the type.

Cannot conflict with other types.

Return type:

str

Returns:

The type as a string.

async metadata(infos)

Constructs object metadata from a stored parameter location.

Parameters:

infos (Sequence[ParamInfo]) – sequence of ParamInfo

Return type:

Sequence[ArrayMetadata]

Returns:

Sequence of Metadata for each provided ParamInfo.

async serialize(values, infos, args=None)

See superclass documentation.

Return type:

Sequence[Future]

async deserialize(infos, args=None)

See superclass documentation.

Parameters:
  • infos (Sequence[ParamInfo]) – ParamInfo.

  • args (Optional[Sequence[RestoreArgs]]) – must be of type ArrayRestoreArgs.

Return type:

Sequence[Array]

Returns:

The deserialized parameter.

Raises:
  • ValueError if args is not provided.

  • ValueError if neither args.sharding nor both args.mesh and args.mesh_axes are provided.


StringHandler

class orbax.checkpoint.type_handlers.StringHandler(filename=None)

TypeHandler for strings.

__init__(filename=None)

typestr()

A string representation of the type.

Cannot conflict with other types.

Return type:

str

Returns:

The type as a string.

async metadata(infos)

Constructs object metadata from a stored parameter location.

Parameters:

infos (Sequence[ParamInfo]) – sequence of ParamInfo

Return type:

Sequence[StringMetadata]

Returns:

Sequence of Metadata for each provided ParamInfo.

async serialize(values, infos, args=None)

See superclass documentation.

Return type:

Sequence[Future]

async deserialize(infos, args=None)

See superclass documentation.

Return type:

Sequence[Optional[str]]

Tensorstore functions

orbax.checkpoint.type_handlers.is_ocdbt_checkpoint(path)

Determines whether a checkpoint uses OCDBT format.

Return type:

bool

orbax.checkpoint.type_handlers.merge_ocdbt_per_process_files(directory)

Merges OCDBT files written to per-process subdirectories.

With Tensorstore’s OCDBT format, arrays are initially written to per-process subdirectories, depending on which host is doing the writing. This function can be called to merge the per-process files into a global key-value store.

The original per-process subdirectories are not deleted, and should not be deleted, because the global kvstore continues to reference them.

Parameters:

directory (Path) – checkpoint location.

orbax.checkpoint.type_handlers.get_json_tspec_write(info, use_ocdbt, global_shape, local_shape, dtype, process_index=None, metadata_key=None, arg=None)

Gets Tensorstore spec for writing.

orbax.checkpoint.type_handlers.get_json_tspec_read(info, use_ocdbt, metadata_key=None)

Gets Tensorstore spec for reading.

orbax.checkpoint.type_handlers.get_ts_context(use_ocdbt)

Returns a shared global TensorStore Context instance to use.

Return type:

Context

orbax.checkpoint.type_handlers.get_cast_tspec_serialize(tspec, value, args)

Creates a Tensorstore spec for casting a param during serialize.

orbax.checkpoint.type_handlers.get_cast_tspec_deserialize(tspec, args)

Creates a Tensorstore spec for casting a param during deserialize.
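To make the spec helpers above more concrete, here is a hypothetical sketch of the kind of Tensorstore JSON specs involved. The zarr and cast drivers and the file and ocdbt key-value stores are real Tensorstore components, but the exact fields and directory layout that get_json_tspec_read and get_cast_tspec_deserialize produce are assumptions here, and the function names below are made up.

```python
from typing import Any, Dict, Optional

Spec = Dict[str, Any]


def sketch_read_spec(directory: str, param_name: str, use_ocdbt: bool) -> Spec:
    """Hypothetical: a Tensorstore spec for reading one Zarr-encoded parameter."""
    if use_ocdbt:
        # Assumed layout: all parameters share one OCDBT store rooted at the checkpoint directory.
        kvstore = {"driver": "ocdbt", "base": f"file://{directory}", "path": param_name}
    else:
        # Assumed layout: each parameter lives in its own subdirectory.
        kvstore = {"driver": "file", "path": f"{directory}/{param_name}"}
    return {"driver": "zarr", "kvstore": kvstore}


def sketch_cast_spec(base: Spec, dtype: Optional[str]) -> Spec:
    """Hypothetical: wraps `base` in Tensorstore's cast driver so values appear as `dtype`."""
    if dtype is None:
        return base  # No cast requested; use the stored dtype as-is.
    return {"driver": "cast", "dtype": dtype, "base": base}
```

With the tensorstore package installed, a dict like this could be passed to tensorstore.open; the cast layer converts values between the stored dtype and the requested one as they are read or written.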

TypeHandler registry

class orbax.checkpoint.type_handlers.TypeHandlerRegistry(*args, **kwargs)

A registry for TypeHandlers.

This internal base class is used for the global registry, which serves as a default for any type not found in a local registry. It is also accessed through the module-level functions get_type_handler, register_type_handler, and has_type_handler.

orbax.checkpoint.type_handlers.create_type_handler_registry(*handlers)

Creates a type handler registry.

Parameters:

*handlers – optional pairs of (<type>, <handler>) to initialize the registry with.

Return type:

TypeHandlerRegistry

Returns:

A TypeHandlerRegistry instance with only the specified handlers.

orbax.checkpoint.type_handlers.register_type_handler(ty, handler, func=None, override=False)

Registers a type for serialization/deserialization with a given handler.

Note that it is possible for a type to match multiple different entries in the registry, each with a different handler. In this case, only the first match is used.

Parameters:
  • ty (Any) – A type to register.

  • handler (TypeHandler) – a TypeHandler capable of reading and writing parameters of type ty.

  • func (Optional[Callable[[Any], bool]]) – A function that accepts a type and returns True if the type should be handled by the provided TypeHandler. If this parameter is not specified, defaults to lambda t: issubclass(t, ty).

  • override (bool) – if True, will override an existing mapping of type to handler.

Raises:

ValueError if a type is already registered and override is False.

orbax.checkpoint.type_handlers.get_type_handler(ty)

Returns the handler registered for a given type, if available.

Return type:

TypeHandler

orbax.checkpoint.type_handlers.has_type_handler(ty)

Returns whether a handler is registered for a given type.

Return type:

bool
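The registry rules described above (an optional matcher that defaults to issubclass, first match wins, and ValueError on re-registration unless override is True) can be modeled with a small class. ToyTypeHandlerRegistry and create_toy_registry are hypothetical illustrations of the documented behavior, not Orbax's TypeHandlerRegistry; in particular, raising ValueError from get for an unregistered type is this sketch's own choice.

```python
from typing import Any, Callable, List, Optional, Tuple

Matcher = Callable[[Any], bool]


class ToyTypeHandlerRegistry:
    """Hypothetical registry: an ordered list of (type, matcher, handler) entries."""

    def __init__(self) -> None:
        self._entries: List[Tuple[Any, Matcher, Any]] = []

    def add(self, ty: Any, handler: Any, func: Optional[Matcher] = None, override: bool = False) -> None:
        if func is None:
            # Documented default: handle `ty` and all of its subclasses.
            func = lambda t: issubclass(t, ty)
        for i, (existing, _, _) in enumerate(self._entries):
            if existing is ty:
                if not override:
                    raise ValueError(f"A handler for {ty!r} is already registered.")
                self._entries[i] = (ty, func, handler)  # Replace in place, keeping its position.
                return
        self._entries.append((ty, func, handler))

    def get(self, ty: Any) -> Any:
        for _, func, handler in self._entries:
            if func(ty):
                return handler  # Only the first matching entry is used.
        raise ValueError(f"Unknown type: {ty!r}.")

    def has(self, ty: Any) -> bool:
        return any(func(ty) for _, func, _ in self._entries)


def create_toy_registry(*handlers: Tuple[Any, Any]) -> ToyTypeHandlerRegistry:
    """Builds a registry containing only the given (type, handler) pairs."""
    registry = ToyTypeHandlerRegistry()
    for ty, handler in handlers:
        registry.add(ty, handler)
    return registry
```

With create_toy_registry((int, "int handler"), (object, "fallback")), looking up bool returns "int handler": bool is a subclass of both int and object, but the int entry comes first.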

orbax.checkpoint.type_handlers.register_standard_handlers_with_options(**kwargs)

Re-registers a select set of handlers with the given options.

This is intended to override options en masse for the standard numeric TypeHandlers and their corresponding types (scalars, numpy arrays and jax.Arrays).

Parameters:

**kwargs – keyword arguments to pass to each of the standard handlers.