Metadata Utilities

Defines the exported symbols for the orbax.checkpoint.metadata package.

Tree Metadata

class orbax.checkpoint.metadata.ArrayMetadata(name, directory, shape, sharding, dtype, storage=None) [source]

Metadata describing an array.

shape:

Tuple of integers describing the array shape.

sharding:

ShardingMetadata indicating how the array is sharded. ShardingMetadata is an Orbax representation of jax.sharding.Sharding that stores the same properties but does not require access to real devices.

dtype:

Dtype of array elements.

storage:

Optional metadata describing how the array is stored in a checkpoint.

__eq__(other)[source][source]#

Return self==value.

Return type:

bool

__hash__ = None#
__init__(name, directory, shape, sharding, dtype, storage=None)#
class orbax.checkpoint.metadata.ScalarMetadata(name, directory, shape=(), sharding=None, dtype=None, storage=None) [source]

Metadata describing a scalar value.

dtype:

Scalar dtype.

__eq__(other)[source][source]#

Return self==value.

Return type:

bool

__hash__ = None#
__init__(name, directory, shape=(), sharding=None, dtype=None, storage=None)#
class orbax.checkpoint.metadata.StringMetadata(name, directory) [source]

Metadata describing a string value.

__eq__(other)[source][source]#

Return self==value.

Return type:

bool

__hash__ = None#
__init__(name, directory)#
class orbax.checkpoint.metadata.StorageMetadata(chunk_shape, write_shape=None) [source]

Metadata describing how arrays are stored in a checkpoint.

__delattr__(name)#

Implement delattr(self, name).

__eq__(other)#

Return self==value.

__hash__()#

Return hash(self).

__init__(chunk_shape, write_shape=None)#
__setattr__(name, value)#

Implement setattr(self, name, value).

Sharding Metadata

class orbax.checkpoint.metadata.ShardingMetadata [source]

ShardingMetadata representing a jax.sharding.Sharding property.

ShardingMetadata can only represent the following jax.sharding.Sharding types:

- jax.sharding.NamedSharding
- jax.sharding.SingleDeviceSharding

abstractmethod classmethod from_jax_sharding(jax_sharding)[source][source]#

Converts jax.sharding.Sharding to ShardingMetadata.

Return type:

ShardingMetadata

abstractmethod to_jax_sharding()[source][source]#

Converts ShardingMetadata to jax.sharding.Sharding.

Return type:

Sharding

abstractmethod classmethod from_deserialized_dict(deserialized_dict)[source][source]#

Converts a serialized string, already deserialized into a dict[str, str], to ShardingMetadata.

Return type:

ShardingMetadata

abstractmethod to_serialized_string()[source][source]#

Converts ShardingMetadata to a serialized string.

Return type:

str

__eq__(other)#

Return self==value.

__hash__ = None#
__init__()#
class orbax.checkpoint.metadata.NamedShardingMetadata(shape, axis_names, partition_spec, axis_types=None, device_mesh=None) [source]

NamedShardingMetadata representing jax.sharding.NamedSharding.

classmethod from_jax_sharding(jax_sharding)[source][source]#

Converts jax.sharding.Sharding to ShardingMetadata.

Return type:

NamedShardingMetadata

to_jax_sharding()[source][source]#

Converts ShardingMetadata to jax.sharding.Sharding.

Return type:

NamedSharding

classmethod from_deserialized_dict(deserialized_dict)[source][source]#

Converts a serialized string, already deserialized into a dict[str, str], to ShardingMetadata.

Return type:

NamedShardingMetadata

to_serialized_string()[source][source]#

Converts ShardingMetadata to a serialized string.

Return type:

str

__eq__(other)[source][source]#

Return self==value.

__hash__ = None#
__init__(shape, axis_names, partition_spec, axis_types=None, device_mesh=None)#
class orbax.checkpoint.metadata.SingleDeviceShardingMetadata(device_str) [source]

SingleDeviceShardingMetadata representing jax.sharding.SingleDeviceSharding.

classmethod from_jax_sharding(jax_sharding)[source][source]#

Converts jax.sharding.Sharding to ShardingMetadata.

Return type:

SingleDeviceShardingMetadata

to_jax_sharding()[source][source]#

Converts ShardingMetadata to jax.sharding.Sharding.

Return type:

SingleDeviceSharding

classmethod from_deserialized_dict(deserialized_dict)[source][source]#

Converts a serialized string, already deserialized into a dict[str, str], to ShardingMetadata.

Return type:

SingleDeviceShardingMetadata

to_serialized_string()[source][source]#

Converts ShardingMetadata to a serialized string.

Return type:

str

__eq__(other)[source][source]#

Return self==value.

__hash__ = None#
__init__(device_str)#