Transformations#

Overview#

The transform_utils library provides functions to perform structural PyTree transformations, which can facilitate model surgery for finetuning, migrations between different checkpoint versions, etc.

The API consists of a Transform class and an apply_transformations function.

apply_transformations#

The apply_transformations function accepts an original PyTree, a PyTree of Transform objects, and the desired structure of the returned PyTree. The function returns a newly generated PyTree.

def apply_transformations(
    original_tree: PyTree,
    transformations: PyTree,
    new_tree: PyTree,
    default_to_original: Optional[bool] = True) -> PyTree:

Transform#

Transform consists of the following elements:

  • original_key: Denotes the original name of the key. Represented as a string with ‘/’ denoting successive levels of nesting. If the key corresponding to this Transform is a regex, backreferences (such as \1) will be replaced with the appropriate matched group in the regex. Note: not needed if multi_value_fn is provided.

  • use_fallback: if True, takes the value from the fallback tree. If default_to_original=True in apply_transformations, the fallback tree is new_tree. If default_to_original=False in apply_transformations, the fallback tree is original_tree.

  • value_fn: A function accepting a single value and returning a single value. The value provided as an argument is the value of the transformation key in the original PyTree.

  • multi_value_fn: A function that receives the original PyTree and returns any value; the return value becomes the value of the key in the new PyTree. In the examples on this page the function is written with two arguments: the transformation key, followed by the original PyTree.
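
The backreference substitution described for original_key can be pictured with Python's re module. This is a minimal sketch, not Orbax's implementation: resolve_original_key is a hypothetical helper, and it assumes the regex has to match the whole key of the new tree.

```python
import re
from typing import Optional


def resolve_original_key(
    pattern: str, new_key: str, original_key: str) -> Optional[str]:
  """Hypothetical helper: expands backreferences in original_key.

  Returns None when the pattern does not match the whole of new_key.
  """
  match = re.fullmatch(pattern, new_key)
  if match is None:
    return None
  # Match.expand substitutes numeric and named backreferences with the
  # text captured by the corresponding groups.
  return match.expand(original_key)


flat_rename = resolve_original_key(r'([a-z])_([0-9])', 'a_1', r'\1\2')
nested_rename = resolve_original_key(
    r'([a-z]+)_([0-9])_NEW/([a-z]+)_1', 'dense_1_NEW/kernel_1', r'\1_\2/\3')
no_match = resolve_original_key(r'([a-z])_([0-9])', 'unrelated', r'\1\2')
print(flat_rename, nested_rename, no_match)
```

The first call maps the new key 'a_1' back to 'a1', and the second maps a nested new key back to a '/'-separated original key.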

Fallbacks#

Note that there is an additional option for apply_transformations, which is default_to_original (True by default). This means that the values of keys unspecified in transformations but present in both trees will be taken from the original tree. If False, such values will be taken from the new tree.

Remember that if a key is present in the new tree but not in the original, the value will simply be taken from the new tree. If a key is present in the original tree but not in the new tree, it will be dropped from the result.
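
The fallback rules above can be summarized with a toy model. The sketch below is only an illustration on flat dicts for keys that have no Transform; fill_unspecified is a hypothetical name and is not part of Orbax.

```python
from typing import Any, Dict


def fill_unspecified(
    original: Dict[str, Any],
    new: Dict[str, Any],
    default_to_original: bool = True) -> Dict[str, Any]:
  """Toy model of the fallback rules for keys without a Transform."""
  result = {}
  # Only keys of the new tree appear in the result, so keys that exist
  # solely in the original tree are dropped.
  for key, new_value in new.items():
    if key in original and default_to_original:
      result[key] = original[key]
    else:
      result[key] = new_value
  return result


original = {'shared': 'from_original', 'only_original': 'dropped'}
new = {'shared': 'from_new', 'only_new': 'from_new'}

with_default = fill_unspecified(original, new, default_to_original=True)
without_default = fill_unspecified(original, new, default_to_original=False)
print(with_default)
print(without_default)
```

In both cases 'only_original' is absent from the result and 'only_new' keeps the value from the new tree; only the source of 'shared' changes.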

Examples#

# Setup
import orbax.checkpoint as ocp
import numpy as np

Renaming keys#

Key renames are common when reusing existing checkpointed state between different models or between different versions of the same model.

# Example: Migrate original tree into the new_tree, which has the same
# nested structure but different keys.
original_tree = {
    'a': 1,
    'b': 2
}

transformations = {
    'a2': ocp.Transform(original_key='a'),
    'b2': ocp.Transform(original_key='b')
}

new_tree = {
    'a2': ...,
    'b2': ...
}

ocp.apply_transformations(original_tree, transformations, new_tree)
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'a2': 1, 'b2': 2}
# Example 2: Renaming with regex

original_tree = {
    'a1': 1,
    'b5': 2
}

transformations = {
    r'([a-z])_([0-9])': ocp.Transform(original_key=r'\1\2'),
}

new_tree = {
    'a_1': ...,
    'b_5': ...
}

ocp.apply_transformations(original_tree, transformations, new_tree)
{'a_1': 1, 'b_5': 2}
# Example 3: Renaming nested trees

original_tree = {
    'a': 1,
    'dense_1': {'kernel': 2, 'bias': 3},
    'dense_2': {'kernel': 4, 'bias': 5},
}

# Nested keys can be represented by a single string by separating each level
# with '/'.
transformations = {
    r'([a-z]+)_NEW': ocp.Transform(original_key=r'\1'),
    r'([a-z]+)_([0-9])_NEW/([a-z]+)_1': ocp.Transform(original_key=r'\1_\2/\3'),
}

# This is equivalent to:
transformations = {
    r'([a-z]+)_NEW': ocp.Transform(original_key=r'\1'),
    r'([a-z]+)_([0-9])_NEW': {
        r'([a-z]+)_1': ocp.Transform(original_key=r'\1_\2/\3'),
    },
}

new_tree = {
    'a_NEW': ...,
    'dense_1_NEW': {'kernel_1': ..., 'bias_1': ...},
    'dense_2_NEW': {'kernel_1': ..., 'bias_1': ...},
}

ocp.apply_transformations(original_tree, transformations, new_tree)
{'a_NEW': 1,
 'dense_1_NEW': {'bias_1': 3, 'kernel_1': 2},
 'dense_2_NEW': {'bias_1': 5, 'kernel_1': 4}}
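
The '/'-separated key convention used in Example 3 can be pictured as flattening a nested dict into single-string paths. This is a minimal sketch under that assumption; flatten_with_slashes is a hypothetical helper and is not an Orbax function.

```python
from typing import Any, Dict


def flatten_with_slashes(
    tree: Dict[str, Any], prefix: str = '') -> Dict[str, Any]:
  """Hypothetical helper: maps nested dict keys to '/'-joined paths."""
  flat = {}
  for key, value in tree.items():
    path = f'{prefix}/{key}' if prefix else key
    if isinstance(value, dict):
      # Recurse into sub-dicts, carrying the path built so far.
      flat.update(flatten_with_slashes(value, path))
    else:
      flat[path] = value
  return flat


flat_tree = flatten_with_slashes(
    {'a': 1, 'dense_1': {'kernel': 2, 'bias': 3}})
print(flat_tree)
```

Seen this way, a key such as 'dense_1/kernel' names the same leaf as tree['dense_1']['kernel'].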

Updating the value#

To change a leaf node in the PyTree, define a Transform with a value_fn. This transformation could be used for quantization, modifying hyperparameters, etc.

# Example: Transform the values in a tree.
original_tree = {
    'a': 1,
    'b': 2
}

transformations = {
    'a': ocp.Transform(value_fn=lambda v: v * 2),
    'b2': ocp.Transform(value_fn=lambda v: v * 3, original_key='b')
}

new_tree = {
    'a': ...,
    'b2': ...  # Output different key
}

ocp.apply_transformations(original_tree, transformations, new_tree)
{'a': 2, 'b2': 6}
# Example 2: Transform values in a tree with regex (multiply all 'a' keys by 2
# and all 'b' keys by 3).
original_tree = {
    'a1': 1,
    'a2': 2,
    'b': 3
}

transformations = {
    r'a([0-9]?)\*2': ocp.Transform(value_fn=lambda v: v * 2,
                                     original_key=r'a\1'),
    r'b([0-9]?)\*3': ocp.Transform(value_fn=lambda v: v * 3,
                                     original_key=r'b\1')
}

new_tree = {
    'a1*2': ...,
    'a2*2': ...,
    'b*3': ...
}

ocp.apply_transformations(original_tree, transformations, new_tree)
{'a1*2': 2, 'a2*2': 4, 'b*3': 9}
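
The transformation keys in Example 2 escape the asterisk (\*) because the new keys contain a literal '*', which is otherwise a regex quantifier. The following sketch uses only Python's re module to show the difference, assuming the key is matched as a full regex.

```python
import re

literal_key = 'b*3'

# Unescaped, 'b*3' means "zero or more 'b' characters followed by '3'",
# so it does not match the literal text 'b*3'.
unescaped = re.fullmatch('b*3', literal_key)

# re.escape produces a pattern that matches the text literally.
escaped = re.fullmatch(re.escape(literal_key), literal_key)

# The pattern from Example 2: an optional digit, then a literal '*3'.
grouped = re.fullmatch(r'b([0-9]?)\*3', literal_key)
source_key = grouped.expand(r'b\1')
print(unescaped, escaped is not None, source_key)
```

Here the optional digit group captures an empty string, so the backreference expands to the original key 'b'.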

Restructuring PyTrees#

# Example: Flatten nested structure
original_tree = {
    'a': 1,
    'dense_1': {'kernel': 2, 'bias': 3},
    'dense_2': {'kernel': 4, 'bias': 5},
}

transformations = {
    r'([a-z]+)': ocp.Transform(original_key=r'\1'),
    r'([a-z]+)_([0-9])_([a-z]+)': ocp.Transform(original_key=r'\1_\2/\3'),
}


new_tree = {
    'a': ...,
    'dense_1_kernel': ...,
    'dense_1_bias': ...,
    'dense_2_kernel': ...,
    'dense_2_bias': ...,
}

ocp.apply_transformations(original_tree, transformations, new_tree)
{'a': 1,
 'dense_1_bias': 3,
 'dense_1_kernel': 2,
 'dense_2_bias': 5,
 'dense_2_kernel': 4}

Multi-value transformations#

Multi-value transformations can be used to combine multiple values from the original tree into the new tree.

# Example: various multi_value_fn usage
original_tree = {
    'a': np.array([1, 2, 3, 4]),
    'b': {'c': np.array([5, 6, 7, 8])},
}

transformations = {
    'a': ocp.Transform(multi_value_fn=lambda _, kv: kv['a'][-1]),
    'b': {
        'c': ocp.Transform(multi_value_fn=lambda _, kv: kv['a'] + kv['b']['c'])},
}


new_tree = {
    'a': ...,
    'b': {'c': ...}
}

ocp.apply_transformations(original_tree, transformations, new_tree)
{'a': 4, 'b': {'c': array([ 6,  8, 10, 12])}}
# Example: Average the weights
original_tree = {
    'a': {'a_1': 1, 'a_2': 2},
    'b': {'b_1': 3, 'b_2': 4, 'b_3': 5},

}

transformations = {
    r'([a-z]+)': ocp.Transform(
        multi_value_fn=lambda k, kv: sum(kv[k].values()) / len(kv[k])),
}


new_tree = {
    'a': ...,
    'b': ...,
}

ocp.apply_transformations(original_tree, transformations, new_tree)
{'a': 1.5, 'b': 4.0}
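
When a multi_value_fn grows beyond a one-line lambda, a named function can be easier to read and to test on its own. The sketch below mirrors the two-argument (key, tree) shape of the lambdas above; average_children is a hypothetical name, and the function is exercised directly here rather than through apply_transformations.

```python
from typing import Any, Dict


def average_children(key: str, tree: Dict[str, Any]) -> float:
  """Averages the leaf values stored under tree[key]."""
  children = tree[key]
  return sum(children.values()) / len(children)


original_tree = {
    'a': {'a_1': 1, 'a_2': 2},
    'b': {'b_1': 3, 'b_2': 4, 'b_3': 5},
}

average_a = average_children('a', original_tree)
average_b = average_children('b', original_tree)
print(average_a, average_b)
```

Such a function could then be passed as multi_value_fn=average_children in place of the lambda.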

Real world example#

Let’s consider a real-world example. In this scenario, we have a saved checkpoint with parameters Dense_0, Dense_1. We want to restore this checkpoint, with modifications, into a model for training with layers Dense_0, Dense_1, Dense_2, Dense_3.

In this example, we will map original layers 0 and 1 onto the new layers 1 and 2, respectively. We want the new layers 0 and 3 to be initialized randomly, or with some new values.

The new model may be initialized as a Flax TrainState, for example.

import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState

# Assumes `model`, `optimizer`, `manager`, and `step` are already defined.
params = model.init(
    jax.random.PRNGKey(0), jnp.ones([1, model.input_size]))
new_state = TrainState.create(
    apply_fn=model.apply, params=params, tx=optimizer)
# Restore original state.
original_state = manager.restore(step)
transformations = {
    # NewModel layer 0 is a newly inserted layer, thus use_fallback=True.
    r'(.*)Dense_0(.*)': ocp.Transform(use_fallback=True),
    # OriginalModel layer 0 maps to NewModel layer 1
    r'(.*)Dense_1(.*)': ocp.Transform(original_key=r'\1Dense_0\2'),
    # OriginalModel layer 1 maps to NewModel layer 2
    r'(.*)Dense_2(.*)': ocp.Transform(original_key=r'\1Dense_1\2'),
}  # Note: NewModel layer 3 is newly added.
restored_state = ocp.apply_transformations(
    original_state, transformations, new_state)

Let’s unpack what’s happening with these transformations.

For layer 0, we want to instruct the function to ignore what’s in original_state, and to instead use the value from new_state. For this, we set use_fallback=True.

For Dense_1 and Dense_2, we simply provide a regex that maps each new key back to its original name (Dense_0 and Dense_1, respectively) using the original_key field. Note that we can use a regex to match any key containing the desired pattern, since a PyTree checkpoint will typically represent a single layer with multiple different arrays, each containing the pattern.

Finally, we can simply omit Dense_3 from transformations: because Dense_3 is a key in new_state but not in original_state, the function takes its value from new_state and puts it in the result.
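
To check a set of regex rules like these before restoring a large checkpoint, it can help to trace where each new key would come from. The sketch below does this with plain Python; the 'params/Dense_N/kernel' key names, RULES, and describe_source are illustrative assumptions, and the matching logic is a simplified stand-in for what apply_transformations does.

```python
import re
from typing import List, Optional, Tuple

# Each rule is (regex over new-tree keys, original_key template).
# A template of None stands for Transform(use_fallback=True).
RULES: List[Tuple[str, Optional[str]]] = [
    (r'(.*)Dense_0(.*)', None),
    (r'(.*)Dense_1(.*)', r'\1Dense_0\2'),
    (r'(.*)Dense_2(.*)', r'\1Dense_1\2'),
]


def describe_source(new_key: str) -> str:
  """Reports which tree and key a value for new_key would come from."""
  for pattern, template in RULES:
    match = re.fullmatch(pattern, new_key)
    if match is None:
      continue
    if template is None:
      return 'new_state (use_fallback)'
    return 'original_state: ' + match.expand(template)
  # No rule matched: the key exists only in the new tree.
  return 'new_state (no transformation)'


new_keys = [
    'params/Dense_0/kernel',
    'params/Dense_1/kernel',
    'params/Dense_2/bias',
    'params/Dense_3/kernel',
]
plan = {key: describe_source(key) for key in new_keys}
for key, source in plan.items():
  print(f'{key} <- {source}')
```

The printed plan shows layers 1 and 2 reading from the original layers 0 and 1, while layers 0 and 3 keep their freshly initialized values.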

Restoring a Checkpoint#

import flax.struct

@flax.struct.dataclass
class Small:
    key1: int

@flax.struct.dataclass
class Big:
    key1: int
    key2: int

to_save = Big(key1=10, key2=100)
to_restore = Small(key1=0)

path = '/tmp/my-checkpoints/'
ckptr = ocp.PyTreeCheckpointer()
ckptr.save(path, to_save)

restored1 = ckptr.restore(
  path, args=ocp.args.PyTreeRestore(
    to_restore,
    restore_args=ocp.checkpoint_utils.construct_restore_args(to_restore),
    transforms={}
  )
)
restored2 = ckptr.restore(
  path, args=ocp.args.PyTreeRestore(
    to_restore,
    restore_args=ocp.checkpoint_utils.construct_restore_args(to_restore),
    transforms={
        r'(.*)key1(.*)': ocp.Transform(original_key=r'\1key2\2')
    }
  )
)
restored1
restored2

Tips and Tricks#

Regex group names#

If your regex is getting complicated, you can set group names using (?P<name>...). This group can be referenced using the standard \N, where N is the numeric backreference, or \g<name> where name is the named backreference.

# Example:
original_tree = {
    'dense_1': {'kernel': 2, 'bias': 3},
}

transformations = {
    r'(?P<layer>[a-z]+)_(?P<num>[0-9])_(?P<weight>[a-z]+)': ocp.Transform(
        original_key=r'\g<layer>_\g<num>/\g<weight>'),
}


new_tree = {
    'dense_1_kernel': ...,
    'dense_1_bias': ...,
}

ocp.apply_transformations(original_tree, transformations, new_tree)
{'dense_1_bias': 3, 'dense_1_kernel': 2}
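
The two backreference styles can be compared directly with Python's re module. This is a small standalone sketch that reuses the pattern from the example above and does not call Orbax.

```python
import re

pattern = r'(?P<layer>[a-z]+)_(?P<num>[0-9])_(?P<weight>[a-z]+)'
match = re.fullmatch(pattern, 'dense_1_kernel')

# Named groups can be referenced by name or by position.
by_name = match.expand(r'\g<layer>_\g<num>/\g<weight>')
by_number = match.expand(r'\1_\2/\3')
print(by_name, by_number)
```

Both templates expand to the same original key, so the named form is purely a readability aid.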