Metadata Utilities
Defines exported symbols for package orbax.checkpoint.metadata.
Tree Metadata
- class orbax.checkpoint.metadata.Metadata(name, directory)
Metadata describing PyTree values.
- name:
A string representing the original name of the parameter.
- directory:
The directory where the parameter can be found, after taking name into account.
- __hash__ = None
- __init__(name, directory)
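Python sets `__hash__ = None` on a class that defines `__eq__` without also defining `__hash__`. A non-frozen `@dataclass` does this by default. The sketch below uses a hypothetical stand-in class, `SketchMetadata` (not part of Orbax), to show the practical consequence: two instances with equal fields compare equal, but neither can be used as a dict key or set member. Whether Orbax's `Metadata` is itself a dataclass is an assumption here, as is storing `directory` as a plain string.

```python
import dataclasses


# Hypothetical stand-in for orbax.checkpoint.metadata.Metadata; the real
# class may be implemented differently.
@dataclasses.dataclass
class SketchMetadata:
  name: str
  directory: str  # a plain string here; the real field type is an assumption


a = SketchMetadata(name="params.dense.kernel", directory="/tmp/ckpt")
b = SketchMetadata(name="params.dense.kernel", directory="/tmp/ckpt")

# eq=True (the dataclass default) generates __eq__, so instances compare by value.
print(a == b)  # True

# Because frozen=False, dataclasses sets __hash__ to None: instances are unhashable.
print(SketchMetadata.__hash__ is None)  # True
try:
  {a}  # building a set requires hashing its elements
except TypeError as err:
  print("unhashable:", err)
```

If hashable metadata objects were needed, `@dataclasses.dataclass(frozen=True)` would generate a `__hash__` from the fields instead.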
- class orbax.checkpoint.metadata.ArrayMetadata(name, directory, shape, sharding, dtype)
Metadata describing an array.
- shape:
Tuple of integers describing the array shape.
- sharding:
ShardingMetadata indicating how the array is sharded. ShardingMetadata is an Orbax representation of jax.sharding.Sharding that stores the same properties but does not require access to real devices.
- dtype:
Dtype of array elements.
- __hash__ = None
- __init__(name, directory, shape, sharding, dtype)
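The constructor lists the `Metadata` fields (`name`, `directory`) first, followed by the array-specific ones. That is the field order dataclass inheritance produces, so the sketch below models `ArrayMetadata` as a subclass of a `Metadata`-like base. Both classes are hypothetical stand-ins. `sharding` is an untyped placeholder rather than a real `ShardingMetadata`, and `dtype` is a plain string. The point is that shape and dtype alone are enough to reason about an array (here, its element count) without reading any array data.

```python
import dataclasses
import math
from typing import Any, Optional, Tuple


@dataclasses.dataclass
class SketchMetadata:
  """Hypothetical stand-in for Metadata."""
  name: str
  directory: str


@dataclasses.dataclass
class SketchArrayMetadata(SketchMetadata):
  """Hypothetical stand-in for ArrayMetadata; inherits name and directory."""
  shape: Tuple[int, ...]
  sharding: Optional[Any]  # placeholder for ShardingMetadata
  dtype: str  # e.g. "float32"; the real field holds a dtype, not a string

  def num_elements(self) -> int:
    # math.prod(()) == 1, so a scalar (shape ()) has one element.
    return math.prod(self.shape)


# Positional arguments follow the documented order:
# (name, directory, shape, sharding, dtype).
meta = SketchArrayMetadata(
    "params.dense.kernel", "/tmp/ckpt", (512, 256), None, "float32"
)
print([f.name for f in dataclasses.fields(meta)])
# ['name', 'directory', 'shape', 'sharding', 'dtype']
print(meta.num_elements())  # 131072
```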
Internal Metadata
- class orbax.checkpoint.metadata.TreeMetadata(tree_metadata_entries, use_zarr3)
Metadata representation of a PyTree.
- classmethod build(tree, *, type_handler_registry, save_args=None, use_zarr3=False)
Builds the tree metadata.
- Return type:
TreeMetadata
- to_json()
Returns a JSON representation of the metadata.
- Uses the following JSON format:
    {
      _TREE_METADATA_KEY: {
        "(top_level_key, lower_level_key)": {
          _KEY_METADATA_KEY: (
            {_KEY_NAME: "top_level_key", _KEY_TYPE: <_KeyType (int)>},
            {_KEY_NAME: "lower_level_key", _KEY_TYPE: <_KeyType (int)>},
          ),
          _VALUE_METADATA_KEY: {
            _VALUE_TYPE: "jax.Array",
            _SKIP_DESERIALIZE: True/False,
          }
        }
      }
    }
- Return type:
Dict[str, Any]
- classmethod from_json(json_dict)
Converts a JSON representation, as produced by to_json, into a TreeMetadata.
- Return type:
TreeMetadata
- as_nested_tree(*, keep_empty_nodes)
Converts the metadata to a nested tree whose leaf values are ValueMetadataEntry objects.
- Return type:
Dict[str, Any]
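Converting to a nested tree amounts to walking each flattened key path and creating one nested dict level per path element, with the value metadata at the leaf. The sketch below shows that idea on a plain `{tuple_path: value}` mapping, using strings as stand-ins for `ValueMetadataEntry` objects. The function name `sketch_as_nested_tree` is hypothetical. The sketch also makes two simplifications: it treats every path element as a dict key (a real PyTree can also contain lists and tuples), and it does not model `keep_empty_nodes`, whose exact behavior is not described here.

```python
from typing import Any, Dict, Tuple


def sketch_as_nested_tree(flat: Dict[Tuple[str, ...], Any]) -> Dict[str, Any]:
  """Rebuilds a nested dict from {(k1, k2, ..., kn): leaf_value}.

  Assumes every key path is non-empty.
  """
  nested: Dict[str, Any] = {}
  for path, value in flat.items():
    node = nested
    # Create (or reuse) an intermediate dict for every key except the last.
    for key in path[:-1]:
      node = node.setdefault(key, {})
    node[path[-1]] = value
  return nested


flat = {
    ("params", "dense", "kernel"): "<entry: jax.Array>",
    ("params", "dense", "bias"): "<entry: jax.Array>",
    ("step",): "<entry: scalar>",
}
print(sketch_as_nested_tree(flat))
# {'params': {'dense': {'kernel': '<entry: jax.Array>', 'bias': '<entry: jax.Array>'}}, 'step': '<entry: scalar>'}
```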
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(tree_metadata_entries, use_zarr3)