PyTree Transformations

Provides utilities for transforming PyTrees from one version to another.

class orbax.checkpoint.transform_utils.Transform(original_key=None, use_fallback=False, value_fn=None, multi_value_fn=None)

A representation of a transformation applied to pytree keys/values.

See apply_transformations for usage examples. Transform represents an operation on a single key/value pair. For example, the following mapping:

{'a': Transform(original_key='b')}

This denotes that the original key was named ‘b’, but we are changing it to ‘a’. A regex can also be used as follows:

{r'(.*)a(.*)': Transform(original_key=r'\1b\2')}

This denotes that, in every matching key, ‘b’ should be renamed to ‘a’; the backreferences \1 and \2 carry over the text captured on either side. The pattern may match multiple keys at different levels of nesting. The ‘/’ character denotes a successive level of nesting.
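The regex form can be illustrated with the standard `re` module alone. The following is a minimal sketch, not Orbax's implementation: the helper `original_key_for` is a hypothetical name, and it assumes the pattern must match the whole new key. It shows how a new key is mapped back to the original key by expanding the backreferences in `original_key`.

```python
import re


def original_key_for(new_key, pattern, original_key_template):
    """Returns the original-tree key for new_key, or None if there is no match.

    Hypothetical helper for illustration; assumes whole-key matching.
    """
    match = re.fullmatch(pattern, new_key)
    if match is None:
        return None
    # Match.expand substitutes \1, \2, ... with the captured groups.
    return match.expand(original_key_template)


pattern = r'(.*)a(.*)'
template = r'\1b\2'

top_level = original_key_for('a', pattern, template)    # no nesting
nested = original_key_for('x/a/y', pattern, template)   # '/' marks nesting
unmatched = original_key_for('c', pattern, template)    # pattern does not apply
```

Here the same pattern covers both the top-level key and the nested one, which is what allows a single entry to rename keys at several levels of nesting.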

We also have the following example:

{'a': Transform(multi_value_fn=lambda _, kv: kv['b'] * 2)}

This signifies that the value of the new key ‘a’ is the value of the old key ‘b’ multiplied by two.

original_key:

Denotes the original name of the key. Represented as a string with ‘/’ denoting successive levels of nesting. If the key corresponding to this Transform is a regex, backreferences (such as \1) will be replaced with the appropriate matched group in the regex. Note: not needed if multi_value_fn is provided.

use_fallback:

If True, takes the value from the fallback tree. If default_to_original=True in apply_transformations, the fallback tree is new_tree. If default_to_original=False in apply_transformations, the fallback tree is original_tree.

value_fn:

A function accepting a single value and returning a single value. The value provided as an argument is the value of the transformation key in the original PyTree.

multi_value_fn:

A function accepting a string and a PyTree and returning any value. The string is the result key associated with the returned value, so the function knows which key it is expected to produce a value for. The PyTree argument will be the original PyTree, and the function should return the value of the key in the new PyTree.

multi_value_fn_input_args:

A dict mapping key names (in the original tree) to required input arguments (typically RestoreArgs; see PyTreeCheckpointHandler). These arguments are not used directly in apply_transformations, but are necessary when applying transformations while restoring from a checkpoint in PyTreeCheckpointHandler. They identify the “dependencies” in the original tree (the checkpoint) that the function needs as inputs, and provide additional information needed for restoration. IMPORTANT: using multi_value_fn during PyTreeCheckpointHandler.restore REQUIRES inputs to be identified.
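The two callback contracts described above differ only in what they receive. The sketch below calls plain Python functions by hand to show the shapes involved; it does not import Orbax, and the function names `double` and `a_plus_sum_of_d` are made up for illustration. In real use, functions like these would be passed as `value_fn` and `multi_value_fn`.

```python
original_tree = {'a': 1, 'b': {'c': 5, 'd': [0, 1, 2, 3]}}


def double(value):
    # value_fn contract: a single value in, a single value out.
    return value * 2


def a_plus_sum_of_d(key, tree):
    # multi_value_fn contract: (result key, original tree) -> new value.
    # The key is unused here, but lets one function serve several keys.
    return tree['a'] + sum(tree['b']['d'])


doubled_c = double(original_tree['b']['c'])
combined = a_plus_sum_of_d('z', original_tree)
```

A `value_fn` only ever sees the one value it transforms, so anything that combines several original entries needs the `multi_value_fn` form.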

__eq__(other)

Return self==value.

__hash__ = None

__init__(original_key=None, use_fallback=False, value_fn=None, multi_value_fn=None)

orbax.checkpoint.transform_utils.apply_transformations(original_tree, transformations, new_tree, default_to_original=True)

Applies transformations to a pytree.

Also uses transformations to provide structure to the output tree.

Example:

from orbax.checkpoint.transform_utils import Transform

original_tree = {
  'a': 1,
  'b': {'c': 5, 'd': [0, 1, 2, 3]},
  'f': 2,
  'b1': {'c': 2},
  'b2': {'c': 3},
}
transformations = {
  'a1': Transform(original_key='a'),  # rename
  # Another way of doing the above. Shown commented out because a dict
  # literal keeps only the last entry for a repeated key.
  # 'a1': Transform(multi_value_fn=lambda _, kv: kv['a']),
  'b': {
    # doubled original, and drop b/d
    'c': Transform(multi_value_fn=lambda _, kv: kv['b']['c'] * 2)
  },
  # Copy original into multiple new keys
  'c1': Transform(original_key='b/c'),
  'c2': Transform(original_key='b/c'),
  # one to many mapping
  'x': Transform(multi_value_fn=lambda _, kv: kv['b']['d'][0]),
  'y': Transform(multi_value_fn=lambda _, kv: kv['b']['d'][1:]),
  # many to one mapping
  'z': Transform(
      multi_value_fn=lambda _, kv: kv['a'] * 2 + sum(kv['b']['d'])),
  # regex rename: 'b1', 'b2' become 'x1', 'x2'
  r'x(\d.*)': Transform(original_key=r'b\1'),
}

# defines the structure of the result
new_tree = {
  'a1': ...,
  'b': {'c': ...},
  'c1': ...,
  'c2': ...,
  'x': ...,
  'y': ...,
  'z': ...,
  # defined in original_tree and new_tree, but not in transforms. Value
  # carried over from original_tree.
  'f': ...,
  # This value matters since it is not present in original_tree or
  # transformations, so the value here will simply be preserved in the
  # result.
  'g': 5,
  # These are just 'b1', 'b2', but renamed to 'x1', 'x2', with all values
  # copied over.
  'x1': {'c': 2},
  'x2': {'c': 3},
}

Parameters:
  • original_tree (Any) – a PyTree to be transformed.

  • transformations (Any) – a PyTree of Transform objects.

  • new_tree (Any) – a PyTree defining the structure of the output. A leaf value is only relevant if the key is not present in transformations or original_tree. Note: values in the provided tree must not be None, or they will be filtered out.

  • default_to_original (Optional[bool]) – If True, the values of keys unspecified in transformations will be taken from original_tree. If False, they will be taken from new_tree.

Return type:

Any

Returns:

a transformed PyTree with the structure of new_tree
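The interaction between new_tree, transformations, and default_to_original can be modelled on flat dictionaries. This is a toy model under stated assumptions, not the Orbax algorithm: `ToyTransform` and `toy_apply` are hypothetical names, keys are already flattened with ‘/’, the multi_value_fn receives the flat dict rather than a nested PyTree, and regexes, value_fn, and use_fallback are left out.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ToyTransform:
    """Hypothetical stand-in for Transform; not the Orbax class."""
    original_key: Optional[str] = None
    multi_value_fn: Optional[Callable[[str, dict], Any]] = None


def toy_apply(original, transformations, new, default_to_original=True):
    """Toy model: new supplies the output keys, and values where unspecified."""
    result = {}
    for key, placeholder in new.items():
        transform = transformations.get(key)
        if transform is not None:
            if transform.multi_value_fn is not None:
                # Computed from the whole original tree.
                result[key] = transform.multi_value_fn(key, original)
            else:
                # Plain rename: read the value stored under the old key.
                result[key] = original[transform.original_key]
        elif default_to_original and key in original:
            # Unspecified key: carried over from the original tree.
            result[key] = original[key]
        else:
            # Unspecified key: value preserved from new.
            result[key] = placeholder
    return result


original = {'a': 1, 'b/c': 5, 'f': 2}
transformations = {
    'a1': ToyTransform(original_key='a'),
    'b/c': ToyTransform(multi_value_fn=lambda _, tree: tree['b/c'] * 2),
}
new = {'a1': 0, 'b/c': 0, 'f': 0, 'g': 5}

from_original = toy_apply(original, transformations, new)
from_new = toy_apply(original, transformations, new, default_to_original=False)
```

In this model, only ‘f’ changes between the two calls: it is the one key that appears in both trees but has no transformation, so default_to_original decides which tree supplies its value. The key ‘g’ exists only in new, so its value is kept either way.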

orbax.checkpoint.transform_utils.merge_trees(*trees, target=None)

Merges the provided PyTrees into a single result.

If trees have overlapping keys, the value from the last tree in the list takes precedence.

Parameters:
  • *trees – PyTrees to merge.

  • target (Optional[Any]) – A PyTree to provide structure for the returned value. If not provided, the result will take the form of a dictionary.

Return type:

Any

Returns:

A single merged PyTree.
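The precedence rule can be sketched with plain nested dictionaries. The helper `toy_merge` below is a hypothetical illustration, not the Orbax implementation; it assumes that nested dictionaries are merged leaf by leaf rather than replaced wholesale, and it ignores the target argument.

```python
def toy_merge(*trees):
    """Merges nested dicts; on overlapping keys, later trees win."""
    merged = {}
    for tree in trees:
        for key, value in tree.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                # Both sides are subtrees: merge them recursively.
                merged[key] = toy_merge(merged[key], value)
            else:
                # Leaf, or no existing subtree: the later value replaces it.
                merged[key] = value
    return merged


first = {'a': 1, 'b': {'c': 2, 'd': 3}}
second = {'b': {'c': 20}, 'e': 5}
merged = toy_merge(first, second)
```

Under that assumption, ‘b/c’ takes its value from the second tree, while ‘b/d’ survives from the first because the second tree does not define it.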