Tree Utilities
Public symbols for tree module.
- orbax.checkpoint.tree.get_param_names(item, *, include_empty_nodes=True)
Gets parameter names for PyTree elements.
- Return type:
Any
- orbax.checkpoint.tree.serialize_tree(tree, keep_empty_nodes=False)
Transforms a PyTree to a serializable format.
IMPORTANT: The returned tree replaces tuple container nodes with list nodes.
IMPORTANT: The returned tree replaces NamedTuple container nodes with dict nodes.
- Parameters:
tree (Any) – The tree to serialize. If tree is empty and keep_empty_nodes is False, an error is raised, as there is no valid representation.
keep_empty_nodes (bool) – If True, does not filter out empty nodes.
- Return type:
Any
- Returns:
The serialized PyTree.
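The two IMPORTANT notes above are easiest to see on a small tree. The following is a minimal pure-Python sketch of that container replacement only. `serialize_sketch` and `Point` are hypothetical names introduced for illustration, not part of Orbax, and the sketch does not reproduce empty-node filtering or anything else the real `serialize_tree` may do.

```python
from collections import namedtuple
from typing import Any


def serialize_sketch(tree: Any) -> Any:
    """Hypothetical stand-in that mimics only the documented container replacement."""
    # Check for NamedTuple first: every NamedTuple instance is also a tuple.
    if isinstance(tree, tuple) and hasattr(tree, "_fields"):
        return {name: serialize_sketch(getattr(tree, name)) for name in tree._fields}
    # Plain tuples (and lists) become lists.
    if isinstance(tree, (tuple, list)):
        return [serialize_sketch(child) for child in tree]
    if isinstance(tree, dict):
        return {key: serialize_sketch(child) for key, child in tree.items()}
    return tree  # Anything else is treated as a leaf.


Point = namedtuple("Point", ["x", "y"])  # hypothetical example type
tree = {"pair": (1, 2), "point": Point(x=3, y=4)}
serialized = serialize_sketch(tree)
print(serialized)  # {'pair': [1, 2], 'point': {'x': 3, 'y': 4}}
```

Because the container types are replaced, the original tuple and NamedTuple types cannot be recovered from the serialized form alone, which is presumably why `deserialize_tree` takes a `target` argument.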
- orbax.checkpoint.tree.deserialize_tree(serialized, target, keep_empty_nodes=False)
Deserializes a PyTree to the same structure as target.
- Return type:
Any
- orbax.checkpoint.tree.to_flat_dict(tree, sep=None, keep_empty_nodes=False, is_leaf=None)
Converts a tree into a flattened dictionary.
The nested keys are flattened to a tuple.
Example:
tree = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
to_flat_dict(tree)
{
    ('foo',): 1,
    ('bar', 'a'): 2,
}
- Parameters:
tree (Any) – A PyTree to be flattened.
sep (Optional[str]) – If provided, keys will be returned as sep-separated strings. Otherwise, keys are returned as tuples.
keep_empty_nodes (bool) – If True, empty nodes are not filtered out.
is_leaf (Optional[Callable[[Any],bool]]) – If provided, a function that returns True if a value is a leaf. Overrides keep_empty_nodes if that is also provided.
- Return type:
Any
- Returns:
A flattened dictionary and the tree structure.
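To make the effect of `sep` and `keep_empty_nodes` concrete, here is a minimal pure-Python sketch of the flattening described above. It handles nested dicts only and omits `is_leaf`. `flatten_sketch` is a hypothetical name for illustration and is not Orbax's implementation.

```python
from typing import Any, Dict, Optional, Tuple


def flatten_sketch(
    tree: Any,
    sep: Optional[str] = None,
    keep_empty_nodes: bool = False,
) -> Dict[Any, Any]:
    """Hypothetical stand-in: flattens nested dicts into path-keyed entries."""
    flat: Dict[Tuple[Any, ...], Any] = {}

    def visit(node: Any, path: Tuple[Any, ...]) -> None:
        if isinstance(node, dict) and node:
            # Non-empty dict: recurse, extending the key path.
            for key, child in node.items():
                visit(child, path + (key,))
        elif isinstance(node, dict):
            # Empty dict: kept only when explicitly requested.
            if keep_empty_nodes:
                flat[path] = node
        else:
            flat[path] = node  # Leaf value.

    visit(tree, ())
    if sep is not None:
        # Join each tuple path into a single separator-delimited string.
        return {sep.join(str(part) for part in path): value for path, value in flat.items()}
    return flat


tree = {"foo": 1, "bar": {"a": 2, "b": {}}}
print(flatten_sketch(tree))           # {('foo',): 1, ('bar', 'a'): 2}
print(flatten_sketch(tree, sep="/"))  # {'foo': 1, 'bar/a': 2}
print(flatten_sketch(tree, keep_empty_nodes=True))
# {('foo',): 1, ('bar', 'a'): 2, ('bar', 'b'): {}}
```

The first call matches the documented example: the empty node under `'b'` is dropped unless `keep_empty_nodes=True`.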
- orbax.checkpoint.tree.from_flat_dict(flat_dict, target=None, sep=None, *, inplace=False)
Reconstructs the original tree object from a flattened dictionary.
- Parameters:
flat_dict (Any) – A dictionary conforming to the return value of to_flat_dict.
target (Optional[Any]) – A reference PyTree. The returned value will conform to this structure. If not provided, an unflattened dict will be returned with the inferred structure of the original tree, without necessarily matching it exactly. Note that if target is not provided, the keys in flat_dict need to match sep.
sep (Optional[str]) – Separator used for nested keys in flat_dict.
inplace (bool) – If True, removes items from flat_dict as they are added to the result.
- Return type:
Any
- Returns:
A dict matching the structure of tree with the values of flat_dict.
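The case where `target` is not provided can be sketched in pure Python as below. This is an illustration of the documented behavior only: `unflatten_sketch` is a hypothetical name, it always returns nested dicts, it assumes every key has at least one path component, and it does not model `target` at all.

```python
from typing import Any, Dict, Optional


def unflatten_sketch(
    flat_dict: Dict[Any, Any],
    sep: Optional[str] = None,
    inplace: bool = False,
) -> Dict[Any, Any]:
    """Hypothetical stand-in: rebuilds nested dicts from path-keyed entries."""
    result: Dict[Any, Any] = {}
    # Iterate over a snapshot of the keys so that popping from flat_dict is safe.
    for key in list(flat_dict.keys()):
        value = flat_dict.pop(key) if inplace else flat_dict[key]
        # With sep, keys are strings to split; without it, keys are tuples.
        path = tuple(key.split(sep)) if sep is not None else tuple(key)
        node = result
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return result


flat = {"foo": 1, "bar/a": 2}
print(unflatten_sketch(flat, sep="/"))  # {'foo': 1, 'bar': {'a': 2}}
print(flat)                             # unchanged: {'foo': 1, 'bar/a': 2}

consumed = {("foo",): 1, ("bar", "a"): 2}
rebuilt = unflatten_sketch(consumed, inplace=True)
print(rebuilt)   # {'foo': 1, 'bar': {'a': 2}}
print(consumed)  # {} because items were removed as they were added
```

Emptying `flat_dict` as the result is built, as `inplace=True` does here, avoids holding two references to every value at once, which may matter when the values are large arrays.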
- orbax.checkpoint.tree.to_shape_dtype_struct(arr, dtype=None, scalar_dtype=None, support_format=False)
Get ShapeDtypeStruct from array-like object.
- Parameters:
arr (UnionType[AbstractArrayLike,AbstractArrayLikeGlobalShape,ndarray]) – Array-like object. This can include jax.Array, jax.ShapeDtypeStruct, ArrayRestoreArgs, or value_metadata.ArrayMetadata: anything that has shape/global_shape, dtype, and sharding properties. It may also be a numpy array or a scalar value.
dtype (UnionType[dtype,None]) – Optional dtype that overrides the dtype of arr in the result.
scalar_dtype (UnionType[int,float,None]) – Optional dtype to use for scalars. Useful for converting to Python scalar types.
support_format (bool) – Whether to support layout in the result.
- Return type:
UnionType[ShapeDtypeStruct,int,float]
- Returns:
jax.ShapeDtypeStruct or scalar value.