Metadata Utilities

Contents

Metadata Utilities

Defines the exported symbols of the orbax.checkpoint.metadata package.

Tree Metadata

class orbax.checkpoint.metadata.Metadata(name, directory)[source][source]#

Metadata describing PyTree values.

name:

A string representing the original name of the parameter.

directory:

The directory where the parameter can be found, after taking name into account.

__eq__(other)[source][source]#

Return self==value.

Return type:

bool

__hash__ = None#
__init__(name, directory)#
class orbax.checkpoint.metadata.ArrayMetadata(name, directory, shape, sharding, dtype)[source][source]#

Metadata describing an array.

shape:

Tuple of integers describing the array shape.

sharding:

ShardingMetadata indicating how the array is sharded. ShardingMetadata is an Orbax representation of jax.sharding.Sharding that stores the same properties but does not require access to real devices.

dtype:

Dtype of array elements.

__eq__(other)[source][source]#

Return self==value.

Return type:

bool

__hash__ = None#
__init__(name, directory, shape, sharding, dtype)#
class orbax.checkpoint.metadata.ScalarMetadata(name, directory, shape=(), sharding=None, dtype=None)[source][source]#

Metadata describing a scalar value.

dtype:

Scalar dtype.

__eq__(other)[source][source]#

Return self==value.

Return type:

bool

__hash__ = None#
__init__(name, directory, shape=(), sharding=None, dtype=None)#
class orbax.checkpoint.metadata.StringMetadata(name, directory)[source][source]#

Metadata describing a string value.

__eq__(other)[source][source]#

Return self==value.

Return type:

bool

__hash__ = None#
__init__(name, directory)#

Sharding Metadata

class orbax.checkpoint.metadata.ShardingMetadata[source][source]#

ShardingMetadata representing the properties of a jax.sharding.Sharding.

ShardingMetadata only represents the following jax.sharding.Sharding types:

- jax.sharding.NamedSharding
- jax.sharding.SingleDeviceSharding
- jax.sharding.GSPMDSharding
- jax.sharding.PositionalSharding

abstract classmethod from_jax_sharding(jax_sharding)[source][source]#

Converts jax.sharding.Sharding to ShardingMetadata.

Return type:

ShardingMetadata

abstract to_jax_sharding()[source][source]#

Converts ShardingMetadata to jax.sharding.Sharding.

Return type:

Sharding

abstract classmethod from_deserialized_dict(deserialized_dict)[source][source]#

Converts a serialized string, already deserialized into a dict[str, str], to ShardingMetadata.

Return type:

ShardingMetadata

abstract to_serialized_string()[source][source]#

Converts ShardingMetadata to serialized_string.

Return type:

str

__eq__(other)#

Return self==value.

__hash__ = None#
__init__()#
class orbax.checkpoint.metadata.NamedShardingMetadata(shape, axis_names, partition_spec)[source][source]#

NamedShardingMetadata representing jax.sharding.NamedSharding.

classmethod from_jax_sharding(jax_sharding)[source][source]#

Converts jax.sharding.Sharding to ShardingMetadata.

Return type:

NamedShardingMetadata

to_jax_sharding()[source][source]#

Converts ShardingMetadata to jax.sharding.Sharding.

Return type:

NamedSharding

classmethod from_deserialized_dict(deserialized_dict)[source][source]#

Converts a serialized string, already deserialized into a dict[str, str], to ShardingMetadata.

Return type:

NamedShardingMetadata

to_serialized_string()[source][source]#

Converts ShardingMetadata to serialized_string.

Return type:

str

__eq__(other)[source][source]#

Return self==value.

__hash__ = None#
__init__(shape, axis_names, partition_spec)#
class orbax.checkpoint.metadata.SingleDeviceShardingMetadata(device_str)[source][source]#

SingleDeviceShardingMetadata representing jax.sharding.SingleDeviceSharding.

classmethod from_jax_sharding(jax_sharding)[source][source]#

Converts jax.sharding.Sharding to ShardingMetadata.

Return type:

SingleDeviceShardingMetadata

to_jax_sharding()[source][source]#

Converts ShardingMetadata to jax.sharding.Sharding.

Return type:

SingleDeviceSharding

classmethod from_deserialized_dict(deserialized_dict)[source][source]#

Converts a serialized string, already deserialized into a dict[str, str], to ShardingMetadata.

Return type:

SingleDeviceShardingMetadata

to_serialized_string()[source][source]#

Converts ShardingMetadata to serialized_string.

Return type:

str

__eq__(other)[source][source]#

Return self==value.

__hash__ = None#
__init__(device_str)#
class orbax.checkpoint.metadata.GSPMDShardingMetadata[source][source]#
__eq__(other)#

Return self==value.

__hash__ = None#
__init__()#
class orbax.checkpoint.metadata.PositionalShardingMetadata[source][source]#
__eq__(other)#

Return self==value.

__hash__ = None#
__init__()#
class orbax.checkpoint.metadata.ShardingTypes(value)[source][source]#

An enumeration of sharding types.

orbax.checkpoint.metadata.from_jax_sharding(jax_sharding)[source][source]#

Converts jax.sharding.Sharding to ShardingMetadata.

Return type:

Optional[ShardingMetadata]

orbax.checkpoint.metadata.from_serialized_string(serialized_str)[source][source]#

Converts serialized_string to ShardingMetadata.

Return type:

ShardingMetadata

orbax.checkpoint.metadata.get_sharding_or_none(serialized_string)[source][source]#

Internal Metadata

class orbax.checkpoint.metadata.TreeMetadata(tree_metadata_entries, use_zarr3)[source][source]#

Metadata representation of a PyTree.

classmethod build(tree, *, type_handler_registry, save_args=None, use_zarr3=False)[source][source]#

Builds the tree metadata.

Return type:

TreeMetadata

to_json()[source][source]#

Returns a JSON representation of the metadata.

Uses JSON format::

    {
      _TREE_METADATA_KEY: {
        "(top_level_key, lower_level_key)": {
          _KEY_METADATA_KEY: (
            {_KEY_NAME: "top_level_key", _KEY_TYPE: <_KeyType (int)>},
            {_KEY_NAME: "lower_level_key", _KEY_TYPE: <_KeyType (int)>},
          ),
          _VALUE_METADATA_KEY: {
            _VALUE_TYPE: "jax.Array",
            _SKIP_DESERIALIZE: True/False,
          }
        }
      }
    }

Return type:

Dict[str, Any]

classmethod from_json(json_dict)[source][source]#

Converts a JSON representation into a TreeMetadata.

Return type:

TreeMetadata

as_nested_tree(*, keep_empty_nodes)[source][source]#

Converts to a nested tree whose values are ValueMetadataEntry instances.

Return type:

Dict[str, Any]

__eq__(other)#

Return self==value.

__hash__ = None#
__init__(tree_metadata_entries, use_zarr3)#