model.core.python.tree_util
Generic trees (i.e. nested structures), similar to JAX’s PyTree.
- orbax.experimental.model.core.python.tree_util.assert_tree(assert_leaf, tree)
Checks that tree (typed as Any) is a valid Tree.
If assert_leaf checks the type of each leaf, the caller can safely infer the type parameter of the Tree once assert_tree passes. For example, if assert_leaf is a function whose body is assert isinstance(x, str) (it must be a def, because a lambda cannot contain an assert statement), then after assert_tree passes, one can safely claim that tree is a Tree[str].
- Parameters:
  - assert_leaf (Callable[[Any], None]) – A function that checks that a leaf is valid. While traversing tree, every node that is not a list, tuple, dict, or None is passed to assert_leaf.
  - tree (Any) – The tree whose leaves are all checked by assert_leaf.
- Return type:
None
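The traversal rule above (lists, tuples and dicts are interior nodes, None is skipped, everything else is a leaf) can be illustrated with a small stand-alone sketch. This is not the library's implementation: assert_tree_sketch and assert_str are hypothetical names, and the choices to check only dict values (not keys) and to treat None as an empty subtree are assumptions made for illustration.

```python
from typing import Any, Callable


def assert_tree_sketch(assert_leaf: Callable[[Any], None], tree: Any) -> None:
  """Hypothetical re-implementation of the documented traversal rule."""
  if tree is None:
    return  # None is never passed to assert_leaf (treated as an empty node).
  if isinstance(tree, (list, tuple)):
    for child in tree:
      assert_tree_sketch(assert_leaf, child)
  elif isinstance(tree, dict):
    # Assumption: only dict values are subtrees; keys are not checked.
    for child in tree.values():
      assert_tree_sketch(assert_leaf, child)
  else:
    assert_leaf(tree)  # Any other object is a leaf.


def assert_str(x: Any) -> None:
  # A named function, since a lambda cannot contain an `assert` statement.
  assert isinstance(x, str), f"expected str leaf, got {type(x).__name__}"


# Passes: every leaf is a str, so the caller may treat the tree as Tree[str].
assert_tree_sketch(assert_str, {"a": ["x", ("y", None)], "b": "z"})

try:
  assert_tree_sketch(assert_str, {"a": ["x", 3]})
except AssertionError as e:
  print(e)  # expected str leaf, got int
```

With the real module, the equivalent check would be a call to assert_tree(assert_str, tree) from orbax.experimental.model.core.python.tree_util, using the signature documented above.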