ocp.v1.tree module

Public symbols for the tree module.

The standard supported leaf types are listed in the table below. See https://orbax.readthedocs.io/en/latest/guides/checkpoint/v1/checkpointing_pytrees.html#standard-leaf-types for more information.

| Leaf Type | AbstractLeaf Type | Properties |
| :--- | :--- | :--- |
| `jax.Array` | `ocp.arrays.AbstractShardedArray` (`jax.ShapeDtypeStruct`) | `shape`, `dtype`, `sharding` |
| `np.ndarray` | `ocp.arrays.AbstractArray` (`np.ndarray`) | `shape`, `dtype` |
| `int` | `int` | |
| `float` | `float` | |
| `bytes` | `bytes` | |
| `str` | `str` | |
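To make the table concrete, here is a minimal sketch that builds a tree mixing the standard leaf types and maps it to abstract counterparts using only JAX. The helper `to_abstract` is hypothetical, not part of Orbax, and it assumes (based on the table) that `np.ndarray` and scalar leaves already satisfy the abstract leaf types, so only `jax.Array` leaves are swapped for `jax.ShapeDtypeStruct`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_abstract(tree):
    """Hypothetical helper: replace each jax.Array leaf with a ShapeDtypeStruct.

    Assumption: np.ndarray and scalar leaves are returned unchanged, since the
    table lists np.ndarray, int, float, bytes, and str as their own abstract types.
    """

    def _abstract_leaf(leaf):
        if isinstance(leaf, jax.Array):
            # Keep the three properties the table lists for jax.Array.
            return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype, sharding=leaf.sharding)
        return leaf

    return jax.tree_util.tree_map(_abstract_leaf, tree)


state = {
    "params": {"w": jnp.ones((4, 2)), "b": np.zeros(2, dtype=np.float32)},
    "step": 100,
    "learning_rate": 1e-3,
    "name": "run-1",
}
# Only params/w becomes a ShapeDtypeStruct; the other leaves pass through.
abstract_state = to_abstract(state)
```

Describing a `jax.Array` by shape, dtype, and sharding is what lets an abstract tree stand in for real data without allocating device memory.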

Types

orbax.checkpoint.experimental.v1.tree.PyTree = <class 'orbax.checkpoint.experimental.v1._src.tree.types.PyTree'>

A PyTree with any leaf type (PyTreeOf[Any]).

orbax.checkpoint.experimental.v1.tree.PyTreeOf = <class 'orbax.checkpoint.experimental.v1._src.tree.types.PyTreeOf'>

A PyTree whose leaves are of type T.

Functionally, this type is treated as Any, since static type checkers cannot identify JAX PyTrees.

See https://jax.readthedocs.io/en/latest/pytrees.html for information on PyTrees.

At a very high level, a PyTree is a container-like object such as a dict, list, or flax.struct.dataclass. The elements of these containers can be traversed as a nested tree using jax.tree.* functions.

In a checkpointing context, tree leaves are typically arrays or scalars. Although an array is logically a sequence of elements, JAX treats it as a single leaf node.

Every leaf node is, by definition, also a PyTree.
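A short illustration of these points using `jax.tree_util.tree_leaves` on plain dicts and NumPy arrays (the tree contents are made up for the example): containers are traversed, arrays count as one leaf each, and a bare scalar is a PyTree with exactly one leaf.

```python
import jax
import numpy as np

# A nested container: a dict holding another dict and a scalar.
tree = {"params": {"w": np.ones((3, 2)), "b": np.zeros(2)}, "step": 10}

# JAX walks the dicts and stops at arrays and scalars.
# Dict keys are visited in sorted order: params/b, params/w, step.
leaves = jax.tree_util.tree_leaves(tree)

# An array is a single leaf, even though it holds many elements.
array_leaves = jax.tree_util.tree_leaves(np.ones((3, 2)))

# A bare leaf is itself a PyTree with exactly one leaf.
scalar_leaves = jax.tree_util.tree_leaves(5)
```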

orbax.checkpoint.experimental.v1.tree.Leaf = jax.Array | numpy.ndarray | int | float | numpy.number | bytes | bool | str

A PEP 604 union of the concrete leaf types supported for checkpointing.

orbax.checkpoint.experimental.v1.tree.AbstractLeaf = orbax.checkpoint.experimental.v1._src.arrays.types.AbstractArray | orbax.checkpoint.experimental.v1._src.arrays.types.AbstractShardedArray | int | float | numpy.number | bytes | bool | str

A PEP 604 union of the abstract leaf types that stand in for concrete leaves (see the AbstractLeaf Type column of the table above).

orbax.checkpoint.experimental.v1.tree.PyTreeKey

alias of SequenceKey | DictKey | GetAttrKey | FlattenedIndexKey

orbax.checkpoint.experimental.v1.tree.PyTreeKeyPath

A tuple of PyTreeKey entries that identifies the path from the root of a PyTree to one of its nodes.

alias of tuple[SequenceKey | DictKey | GetAttrKey | FlattenedIndexKey, …]
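Key paths of this shape are what `jax.tree_util.tree_flatten_with_path` produces, so a minimal sketch using plain JAX shows which key type appears where. The `Optimizer` named tuple is made up for illustration; dict entries yield `DictKey`, list entries yield `SequenceKey`, and named-tuple fields yield `GetAttrKey`.

```python
from typing import NamedTuple

import jax
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


class Optimizer(NamedTuple):
    """Illustrative container; its fields are addressed by attribute name."""

    lr: float
    momentum: float


tree = {"layers": [1.0, 2.0], "opt": Optimizer(lr=0.1, momentum=0.9)}

# Each entry pairs a key path (a tuple of key objects) with its leaf.
path_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
for path, leaf in path_leaves:
    # keystr renders a key path as a readable string.
    print(jax.tree_util.keystr(path), leaf)
```

FlattenedIndexKey does not appear here; JAX uses it for custom nodes registered without key information.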

orbax.checkpoint.experimental.v1.tree.JsonType = list['JsonValue'] | dict[str, 'JsonValue']

A PEP 604 union describing a JSON container: either a list of JsonValue elements or a dict that maps str keys to JsonValue values.
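JsonValue is not defined on this page, so the following sketch assumes it covers the JSON scalars (str, int, float, bool, None) plus nested JsonType containers. The functions `is_json_type` and `is_json_value` are hypothetical runtime checks written for illustration, not Orbax APIs.

```python
import json
from typing import Any


def is_json_value(value: Any) -> bool:
    """Assumed JsonValue: a JSON scalar or a nested JsonType container."""
    if value is None or isinstance(value, (str, int, float, bool)):
        return True
    return is_json_type(value)


def is_json_type(value: Any) -> bool:
    """Hypothetical check that `value` matches list[JsonValue] | dict[str, JsonValue]."""
    if isinstance(value, list):
        return all(is_json_value(item) for item in value)
    if isinstance(value, dict):
        return all(
            isinstance(key, str) and is_json_value(item)
            for key, item in value.items()
        )
    # Scalars are JSON values, but not JsonType containers.
    return False


# json.loads of a JSON object or array yields a value of this shape.
metadata = json.loads('{"step": 100, "tags": ["best", null], "ok": true}')
```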

Structure Utils

orbax.checkpoint.experimental.v1.tree.merge(*trees, overwrite=False, is_leaf=None)

Recursively merges multiple trees into a single tree.

It handles standard Python containers (dicts, lists, tuples), named tuples, and custom JAX PyTree nodes, with the same breadth of container support as utilities such as tree_trim.

  • Mappings (dict, etc.) are merged by key.

  • Sequences (list, tuple) are merged element-wise if they have the same length; otherwise, a ValueError is raised.

  • Dataclasses (dataclass, etc.) are merged by field name, where non-None values overwrite None values.

  • If overwrite is False, a ValueError is raised for mismatched types.
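The rules above can be illustrated with a simplified, hypothetical re-implementation. It is not the library's code: `merge_sketch` is an illustrative name, and it covers only dicts, plain lists/tuples, and leaves (no dataclasses, named tuples, custom PyTree nodes, or `is_leaf`).

```python
from typing import Any


def merge_sketch(*trees: Any, overwrite: bool = False) -> Any:
    """Hypothetical, simplified merge: folds the trees left to right."""
    result = trees[0]
    for tree in trees[1:]:
        result = _merge_two(result, tree, overwrite)
    return result


def _merge_two(a: Any, b: Any, overwrite: bool) -> Any:
    # Mappings: merge by key, recursing into keys present in both.
    if isinstance(a, dict) and isinstance(b, dict):
        merged = dict(a)
        for key, value in b.items():
            merged[key] = _merge_two(a[key], value, overwrite) if key in a else value
        return merged
    # Plain lists/tuples of the same type: merge element-wise, lengths must match.
    if type(a) in (list, tuple) and type(a) is type(b):
        if len(a) != len(b):
            raise ValueError(f"Sequence length mismatch: {len(a)} vs {len(b)}")
        return type(a)(_merge_two(x, y, overwrite) for x, y in zip(a, b))
    # Conflicting leaves, or nodes whose types differ.
    if overwrite:
        return b
    raise ValueError(f"Conflict between {a!r} and {b!r}")


merged = merge_sketch({"a": 1, "b": {"c": 2}}, {"d": 3, "b": {"e": 4}})
```

Folding left to right is what makes "later trees win" well defined when `overwrite=True`.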

Example

Merge two PyTrees without overlapping leaf paths:

```python
from orbax.checkpoint.experimental.v1.tree import merge

tree1 = {"a": 1, "b": {"c": 2}}
tree2 = {"d": 3, "b": {"e": 4}}

merged_tree = merge(tree1, tree2)
# Result: {"a": 1, "b": {"c": 2, "e": 4}, "d": 3}
```

Merge PyTrees with overlapping leaf paths using overwrite=True:

```python
tree3 = {"a": 100}
overwritten_tree = merge(tree1, tree3, overwrite=True)
# Result: {"a": 100, "b": {"c": 2}}
```

Parameters:
  • *trees – The trees to merge.

  • overwrite (bool) – If True, values from later trees overwrite values from earlier trees where leaf paths conflict. If False, a ValueError is raised for conflicting leaves.

  • is_leaf (Optional[Callable[[Any], bool]]) – Optional function to determine if a node is a leaf. Defaults to jax.tree_util.all_leaves.
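The `is_leaf` predicate follows the usual JAX convention of stopping traversal at any node for which it returns True. This sketch uses `jax.tree_util.tree_leaves` rather than `merge`, so it only shows how the predicate changes what counts as a leaf; that `merge` applies it the same way is an assumption.

```python
import jax

tree = {"a": (1, 2), "b": 3}

# Default: the tuple is a container, so its elements are separate leaves.
default_leaves = jax.tree_util.tree_leaves(tree)

# With is_leaf, any tuple is treated as one leaf and is not traversed.
tuple_as_leaf = jax.tree_util.tree_leaves(
    tree, is_leaf=lambda node: isinstance(node, tuple)
)
```

Under this convention, a tuple marked as a leaf would be compared as a whole, so two trees holding different tuples at the same path would conflict instead of being merged element-wise.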

Return type:

Any

Returns:

A new PyTree representing the merged content of trees.

Raises:

ValueError – If overwrite is False and the trees share a leaf key path, or if overwrite is False and the types of corresponding nodes do not match.