model.core.python.tree_util

Generic trees, i.e. nested structures, similar to JAX’s PyTrees.

orbax.experimental.model.core.python.tree_util.assert_tree(assert_leaf, tree)

Checks that tree, an object of type Any, is a valid Tree.

If assert_leaf checks the type of each leaf, the caller can safely infer the type parameter of the Tree once assert_tree returns without raising. For example, if assert_leaf is defined as def assert_leaf(x): assert isinstance(x, str), then after assert_tree passes, one can safely conclude that tree is a Tree[str].

Parameters:
  • assert_leaf (Callable[[Any], None]) – A function that checks whether a leaf is valid and raises an exception if it is not. While traversing tree, every node that is not a list, tuple, dict, or None is passed to assert_leaf.

  • tree (Any) – The tree whose leaves are all checked by assert_leaf.

Return type:

None
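The traversal rule described under assert_leaf can be sketched in plain Python as follows. This is a hypothetical re-implementation for illustration, not the Orbax source: assert_tree_sketch, assert_str_leaf, as_str_tree, and the Tree alias are made-up names, and the sketch assumes that a dict node's children are its values (keys are not passed to assert_leaf).

```python
from typing import Any, Callable, Dict, List, Tuple, TypeVar, Union, cast

T = TypeVar("T")

# Hypothetical recursive alias: a Tree[T] is a leaf of type T, None,
# or a list/tuple/dict whose children are themselves Tree[T].
Tree = Union[T, None, List["Tree[T]"], Tuple["Tree[T]", ...], Dict[Any, "Tree[T]"]]


def assert_tree_sketch(assert_leaf: Callable[[Any], None], tree: Any) -> None:
    """Hypothetical sketch: recursively pass every leaf of `tree` to `assert_leaf`."""
    if tree is None:
        return  # None is not treated as a leaf, so it is never checked.
    if isinstance(tree, (list, tuple)):
        for child in tree:
            assert_tree_sketch(assert_leaf, child)
    elif isinstance(tree, dict):
        # Assumption: only dict values are subtrees; keys are not checked.
        for child in tree.values():
            assert_tree_sketch(assert_leaf, child)
    else:
        assert_leaf(tree)  # Any other node is a leaf.


def assert_str_leaf(x: Any) -> None:
    # Defined with `def` because `assert`/`raise` are statements and cannot
    # appear in a lambda. Raising TypeError still works under `python -O`.
    if not isinstance(x, str):
        raise TypeError(f"expected str leaf, got {type(x).__name__}")


def as_str_tree(obj: Any) -> "Tree[str]":
    # A normal return means every leaf passed the check, which is what
    # justifies narrowing the static type to Tree[str].
    assert_tree_sketch(assert_str_leaf, obj)
    return cast("Tree[str]", obj)


if __name__ == "__main__":
    print(as_str_tree({"a": ["x", ("y", None)], "b": "z"}))
    try:
        as_str_tree({"a": [1]})
    except TypeError as e:
        print(e)  # expected str leaf, got int
```

Because the leaf checker returns None and signals failure only by raising, the absence of an exception is the sole evidence the caller needs before treating the object as a Tree[str].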