Transformations
Overview
The transform_utils library provides functions to perform structural PyTree transformations, which can facilitate model surgery for finetuning, migrations between different checkpoint versions, etc.
The API consists of a Transform class and an apply_transformations function.
apply_transformations
The apply_transformations function accepts an original PyTree, a PyTree of Transform objects, and the desired structure of the returned PyTree. It returns a newly generated PyTree.
def apply_transformations(
    original_tree: PyTree,
    transformations: PyTree,
    new_tree: PyTree,
    default_to_original: Optional[bool] = True) -> PyTree:
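Keys in transformations can address nested entries with '/'-joined paths (see original_key below and the nested renaming example). As a rough sketch of that path convention, the code below flattens and rebuilds plain nested dicts. flatten_dict and unflatten_dict are hypothetical helpers written for illustration, not orbax functions, and real PyTrees may contain lists, dataclasses, and other node types that this sketch ignores.

```python
from typing import Any, Dict


def flatten_dict(tree: Dict[str, Any], prefix: str = '') -> Dict[str, Any]:
    """Flattens nested dicts into {'outer/inner': leaf} (illustrative helper)."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        path = f'{prefix}/{key}' if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_dict(value, path))  # recurse one level down
        else:
            flat[path] = value
    return flat


def unflatten_dict(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Inverse of flatten_dict: splits each key on '/' to rebuild nesting."""
    tree: Dict[str, Any] = {}
    for path, value in flat.items():
        *parents, leaf = path.split('/')
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


original_tree = {'a': 1, 'dense_1': {'kernel': 2, 'bias': 3}}
flat = flatten_dict(original_tree)
# flat == {'a': 1, 'dense_1/kernel': 2, 'dense_1/bias': 3}
assert unflatten_dict(flat) == original_tree
```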
Transform
Transform consists of the following elements:
original_key: Denotes the original name of the key. Represented as a string with '/' denoting successive levels of nesting. If the key corresponding to this Transform is a regex, backreferences (such as \1) will be replaced with the appropriate matched group in the regex. Note: not needed if multi_value_fn is provided.
use_fallback: If True, takes the value from the fallback tree. If default_to_original=True in apply_transformations, the fallback tree is new_tree. If default_to_original=False in apply_transformations, the fallback tree is original_tree.
value_fn: A function accepting a single value and returning a single value. The value provided as an argument is the value of the transformation key in the original PyTree.
multi_value_fn: A function accepting a key and a PyTree and returning any value. The key is the key in the new PyTree whose value is being computed, the PyTree argument will be the original PyTree, and the function should return the value of that key in the new PyTree.
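To make the backreference rule for original_key concrete, here is a minimal sketch using Python's re module. It assumes that a transformation key must match the whole '/'-joined key of the new tree, which is consistent with the examples on this page; resolve_original_key is a hypothetical helper, not an orbax function.

```python
import re
from typing import Optional


def resolve_original_key(pattern: str, new_key: str,
                         original_key: str) -> Optional[str]:
    """Returns the source key for new_key, or None if pattern does not match.

    Hypothetical helper: assumes the pattern must match the whole new key.
    """
    match = re.fullmatch(pattern, new_key)
    if match is None:
        return None
    # Match.expand fills numeric (\1) and named (\g<name>) backreferences
    # in the template with the groups captured from new_key.
    return match.expand(original_key)


# Mirrors the regex renaming example: new key 'a_1' comes from 'a1'.
print(resolve_original_key(r'([a-z])_([0-9])', 'a_1', r'\1\2'))  # a1
print(resolve_original_key(r'([a-z]+)_([0-9])_NEW/([a-z]+)_1',
                           'dense_1_NEW/kernel_1', r'\1_\2/\3'))  # dense_1/kernel
```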
Fallbacks
Note that apply_transformations has an additional option, default_to_original (True by default). This means that values for keys unspecified in transformations but present in both trees will be taken from the original tree. If False, such values will be taken from the new tree.
Remember that if a key is present in the new tree but not in the original, the value will simply be taken from the new tree. If a key is present in the original tree but not in the new, it will be dropped from the result.
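The rules above for keys that no Transform covers can be summarized with plain flat dicts. This is a hypothetical sketch (resolve_untransformed is not an orbax function); it ignores regex matching and nested structure and only models which tree an untransformed key's value comes from.

```python
from typing import Any, Dict


def resolve_untransformed(original: Dict[str, Any], new: Dict[str, Any],
                          default_to_original: bool = True) -> Dict[str, Any]:
    """Sketch of the fallback rules for keys with no matching Transform.

    Operates on flat dicts; the result has exactly the keys of `new`.
    """
    result = {}
    for key, new_value in new.items():
        if key in original and default_to_original:
            result[key] = original[key]  # present in both: original wins
        else:
            result[key] = new_value      # only in new, or defaulting to new
    # Keys that exist only in `original` never reach `result`: they are dropped.
    return result


original = {'a': 1, 'b': 2, 'old_only': 3}
new = {'a': 10, 'b': 20, 'new_only': 30}
print(resolve_untransformed(original, new))
# {'a': 1, 'b': 2, 'new_only': 30}
print(resolve_untransformed(original, new, default_to_original=False))
# {'a': 10, 'b': 20, 'new_only': 30}
```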
Examples
# Setup
import orbax.checkpoint as ocp
import numpy as np
Renaming keys
Key renames are common when reusing existing checkpointed state between different models, or between different versions of the same model.
# Example: Migrate original tree into the new_tree, which has the same
# nested structure but different keys.
original_tree = {
'a': 1,
'b': 2
}
transformations = {
'a2': ocp.Transform(original_key='a'),
'b2': ocp.Transform(original_key='b')
}
new_tree = {
'a2': ...,
'b2': ...
}
ocp.apply_transformations(original_tree, transformations, new_tree)
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'a2': 1, 'b2': 2}
# Example 2: Renaming with regex
original_tree = {
'a1': 1,
'b5': 2
}
transformations = {
r'([a-z])_([0-9])': ocp.Transform(original_key=r'\1\2'),
}
new_tree = {
'a_1': ...,
'b_5': ...
}
ocp.apply_transformations(original_tree, transformations, new_tree)
{'a_1': 1, 'b_5': 2}
# Example 3: Renaming nested trees
original_tree = {
    'a': 1,
    'dense_1': {'kernel': 2, 'bias': 3},
    'dense_2': {'kernel': 4, 'bias': 5},
}
# Nested keys can be represented by a single string by separating each level
# with '/'.
transformations = {
    r'([a-z]+)_NEW': ocp.Transform(original_key=r'\1'),
    r'([a-z]+)_([0-9])_NEW/([a-z]+)_1': ocp.Transform(original_key=r'\1_\2/\3'),
}
# This is equivalent to:
transformations = {
    r'([a-z]+)_NEW': ocp.Transform(original_key=r'\1'),
    r'([a-z]+)_([0-9])_NEW': {
        r'([a-z]+)_1': ocp.Transform(original_key=r'\1_\2/\3'),
    },
}
new_tree = {
    'a_NEW': ...,
    'dense_1_NEW': {'kernel_1': ..., 'bias_1': ...},
    'dense_2_NEW': {'kernel_1': ..., 'bias_1': ...},
}
ocp.apply_transformations(original_tree, transformations, new_tree)
{'a_NEW': 1,
'dense_1_NEW': {'bias_1': 3, 'kernel_1': 2},
'dense_2_NEW': {'bias_1': 5, 'kernel_1': 4}}
Updating the value
To change a leaf node in the PyTree, define a Transform with a value_fn. This transformation could be used for quantization, modifying hyperparameters, etc.
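value_fn only ever sees one value from the original tree. The hypothetical helper below mimics that contract on a flat dict, paired with a toy int8 quantization function of the kind the text mentions. apply_value_fn and quantize_int8 are illustrative names, not orbax APIs, and plain lists stand in for arrays.

```python
from typing import Any, Callable, Dict, List, Optional


def apply_value_fn(original: Dict[str, Any], new_key: str,
                   value_fn: Callable[[Any], Any],
                   original_key: Optional[str] = None) -> Any:
    """Hypothetical helper: computes one new leaf from one original leaf.

    value_fn never sees any other leaf of the original tree.
    """
    source_key = original_key if original_key is not None else new_key
    return value_fn(original[source_key])


def quantize_int8(values: List[float]) -> List[int]:
    """Toy symmetric int8 quantization; a real scheme would also keep the scale."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Clamp in case floating-point rounding pushes a value just past 127.
    return [max(-127, min(127, round(v / scale))) for v in values]


original = {'dense/kernel': [0.5, -1.0, 0.25], 'learning_rate': 1e-3}
q = apply_value_fn(original, 'dense/kernel', quantize_int8)
# Rename plus value change, like 'b2' in the example: scale a hyperparameter.
lr = apply_value_fn(original, 'lr', lambda v: v * 0.1,
                    original_key='learning_rate')
```

With orbax itself, the same function would be passed as ocp.Transform(value_fn=quantize_int8), typically operating on NumPy or JAX arrays rather than lists.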
# Example: Transform the values in a tree.
original_tree = {
'a': 1,
'b': 2
}
transformations = {
'a': ocp.Transform(value_fn=lambda v: v * 2),
'b2': ocp.Transform(value_fn=lambda v: v * 3, original_key='b')
}
new_tree = {
'a': ...,
'b2': ... # Output different key
}
ocp.apply_transformations(original_tree, transformations, new_tree)
{'a': 2, 'b2': 6}
# Example 2: Transform values in a tree with regex (multiply all 'a' keys by 2
# and all 'b' keys by 3).
original_tree = {
'a1': 1,
'a2': 2,
'b': 3
}
transformations = {
r'a([0-9]?)\*2': ocp.Transform(value_fn=lambda v: v * 2,
original_key=r'a\1'),
r'b([0-9]?)\*3': ocp.Transform(value_fn=lambda v: v * 3,
original_key=r'b\1')
}
new_tree = {
'a1*2': ...,
'a2*2': ...,
'b*3': ...
}
ocp.apply_transformations(original_tree, transformations, new_tree)
{'a1*2': 2, 'a2*2': 4, 'b*3': 9}
Restructuring PyTrees
# Example: Flatten nested structure
original_tree = {
'a': 1,
'dense_1': {'kernel': 2, 'bias': 3},
'dense_2': {'kernel': 4, 'bias': 5},
}
transformations = {
r'([a-z]+)': ocp.Transform(original_key=r'\1'),
r'([a-z]+)_([0-9])_([a-z]+)': ocp.Transform(original_key=r'\1_\2/\3'),
}
new_tree = {
'a': ...,
'dense_1_kernel': ...,
'dense_1_bias': ...,
'dense_2_kernel': ...,
'dense_2_bias': ...,
}
ocp.apply_transformations(original_tree, transformations, new_tree)
{'a': 1,
'dense_1_bias': 3,
'dense_1_kernel': 2,
'dense_2_bias': 5,
'dense_2_kernel': 4}
Multi-value transformations
Multi-value transformations can be used to combine multiple values from the original tree into the new tree.
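Judging from the examples on this page, multi_value_fn is called with the key being produced and the whole original tree. The sketch below is a hypothetical multi_value_fn (fuse_qkv is not an orbax function) that combines three separate kernels into one value; it additionally assumes that a nested key arrives '/'-joined, such as 'attn/qkv', and uses plain lists in place of arrays.

```python
from typing import Any, Dict, List


def fuse_qkv(new_key: str, tree: Dict[str, Any]) -> List[List[float]]:
    """Hypothetical multi_value_fn: stacks separate q/k/v kernels into one.

    new_key is the key being produced (assumed '/'-joined, e.g. 'attn/qkv');
    tree is the whole original PyTree, so any of its leaves can be read.
    """
    layer = new_key.split('/')[0]  # 'attn/qkv' -> 'attn'
    block = tree[layer]
    return [block['q'], block['k'], block['v']]


original_tree = {'attn': {'q': [1.0, 2.0], 'k': [3.0, 4.0], 'v': [5.0, 6.0]}}
fused = fuse_qkv('attn/qkv', original_tree)
print(fused)  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
```

It could then be registered as {'attn': {'qkv': ocp.Transform(multi_value_fn=fuse_qkv)}}; with real arrays you would return a single stacked array (for example via np.stack) rather than a Python list.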
# Example: various multi_value_fn usage
original_tree = {
'a': np.array([1, 2, 3, 4]),
'b': {'c': np.array([5, 6, 7, 8])},
}
transformations = {
'a': ocp.Transform(multi_value_fn=lambda _, kv: kv['a'][-1]),
'b': {
'c': ocp.Transform(multi_value_fn=lambda _, kv: kv['a'] + kv['b']['c'])},
}
new_tree = {
'a': ...,
'b': {'c': ...}
}
ocp.apply_transformations(original_tree, transformations, new_tree)
{'a': 4, 'b': {'c': array([ 6, 8, 10, 12])}}
# Example: Average the weights
original_tree = {
'a': {'a_1': 1, 'a_2': 2},
'b': {'b_1': 3, 'b_2': 4, 'b_3': 5},
}
transformations = {
r'([a-z]+)': ocp.Transform(
multi_value_fn=lambda k, kv: sum(kv[k].values()) / len(kv[k])),
}
new_tree = {
'a': ...,
'b': ...,
}
ocp.apply_transformations(original_tree, transformations, new_tree)
{'a': 1.5, 'b': 4.0}
Real world example
Let’s consider a real-world example. In this scenario, we have a saved checkpoint with parameters Dense_0 and Dense_1. We want to restore this checkpoint, with modifications, into a model for training with layers Dense_0, Dense_1, Dense_2, and Dense_3.
In this example, we will map original layers 0 and 1 onto new layers 1 and 2, respectively. We want new layers 0 and 3 to be initialized randomly, or with some new values.
The new model may be initialized as a Flax TrainState, for example.
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax.training.train_state import TrainState
# Assumes `model` (a Flax module with an `input_size` attribute), `optimizer`
# (an Optax optimizer), `manager` (an ocp.CheckpointManager) and `step` are
# defined elsewhere.
params = model.init(
    jax.random.PRNGKey(0), jnp.ones([1, model.input_size]))
new_state = TrainState.create(
    apply_fn=model.apply, params=params, tx=optimizer)
# Restore original state.
original_state = manager.restore(step)
transformations = {
    # NewModel layer 0 is a newly inserted layer, thus use_fallback=True.
    r'(.*)Dense_0(.*)': ocp.Transform(use_fallback=True),
    # OriginalModel layer 0 maps to NewModel layer 1.
    r'(.*)Dense_1(.*)': ocp.Transform(original_key=r'\1Dense_0\2'),
    # OriginalModel layer 1 maps to NewModel layer 2.
    r'(.*)Dense_2(.*)': ocp.Transform(original_key=r'\1Dense_1\2'),
}  # Note: NewModel layer 3 is newly added.
restored_state = ocp.apply_transformations(
    original_state, transformations, new_state)
Let’s unpack what’s happening with these transformations.
For layer 0, we want to instruct the function to ignore what’s in original_state and instead use the value from new_state. For this, we set use_fallback=True.
For Dense_1 and Dense_2, we simply provide a regex that maps each new key to the original name of the key (Dense_0 and Dense_1, respectively) using the original_key field. Note that we can use a regex to match any key containing the desired pattern, since a PyTree checkpoint will typically represent a single layer with multiple different arrays, each containing the pattern.
Finally, we can simply omit Dense_3 from transformations: Dense_3 is present in new_state but not in original_state, so the function will simply take its value from new_state and put it in the result.
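The mapping can be traced without Flax or a checkpoint. The sketch below uses flat, '/'-joined dicts as stand-ins for original_state and new_state (key names such as 'params/Dense_0/kernel' are assumed for illustration, and strings stand in for arrays) and hand-applies the three regex rules with Python's re module, assuming full-match semantics. remap is a hypothetical helper, not how orbax is implemented.

```python
import re
from typing import Any, Dict, Optional

# Flat '/'-joined stand-ins for the two states; strings stand in for arrays.
original_state = {
    'params/Dense_0/kernel': 'orig0_kernel', 'params/Dense_0/bias': 'orig0_bias',
    'params/Dense_1/kernel': 'orig1_kernel', 'params/Dense_1/bias': 'orig1_bias',
}
new_state = {
    f'params/Dense_{i}/{w}': f'init{i}_{w}'
    for i in range(4) for w in ('kernel', 'bias')
}
# New-key pattern -> original-key template; None plays the role of
# use_fallback=True (keep the value from new_state).
rules: Dict[str, Optional[str]] = {
    r'(.*)Dense_0(.*)': None,
    r'(.*)Dense_1(.*)': r'\1Dense_0\2',
    r'(.*)Dense_2(.*)': r'\1Dense_1\2',
}


def remap(original: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]:
    result = {}
    for key, new_value in new.items():
        for pattern, template in rules.items():
            match = re.fullmatch(pattern, key)
            if match is None:
                continue
            if template is None:
                result[key] = new_value
            else:
                result[key] = original[match.expand(template)]
            break
        else:
            # No rule matched (Dense_3): a key in both trees would come from
            # original (default_to_original=True); Dense_3 exists only in new.
            result[key] = original.get(key, new_value)
    return result


restored = remap(original_state, new_state)
# restored['params/Dense_1/kernel'] == 'orig0_kernel'
# restored['params/Dense_3/bias'] == 'init3_bias'
```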
Restoring a Checkpoint
import flax.struct
@flax.struct.dataclass
class Small:
    key1: int
@flax.struct.dataclass
class Big:
    key1: int
    key2: int
to_save = Big(key1=10, key2=100)
to_restore = Small(key1=0)
path = '/tmp/my-checkpoints/'
ckptr = ocp.PyTreeCheckpointer()
ckptr.save(path, to_save)
restored1 = ckptr.restore(
    path, args=ocp.args.PyTreeRestore(
        to_restore,
        restore_args=ocp.checkpoint_utils.construct_restore_args(to_restore),
        transforms={}
    )
)
restored2 = ckptr.restore(
    path, args=ocp.args.PyTreeRestore(
        to_restore,
        restore_args=ocp.checkpoint_utils.construct_restore_args(to_restore),
        transforms={
            r'(.*)key1(.*)': ocp.Transform(original_key=r'\1key2\2')
        }
    )
)
restored1
restored2
Tips and Tricks
Regex group names
If your regex is getting complicated, you can set group names using (?P<name>...). A named group can be referenced either with the standard \N, where N is the numeric backreference, or with \g<name>, where name is the named backreference.
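Both backreference styles can be checked with the standard library alone. This sketch assumes orbax fills original_key the same way as Python's Match.expand, which matches the syntax described above; it does not call orbax.

```python
import re

pattern = r'(?P<layer>[a-z]+)_(?P<num>[0-9])_(?P<weight>[a-z]+)'
match = re.fullmatch(pattern, 'dense_1_kernel')

# Named and numeric backreferences refer to the same captured groups.
named = match.expand(r'\g<layer>_\g<num>/\g<weight>')
numeric = match.expand(r'\1_\2/\3')
print(named, numeric)  # dense_1/kernel dense_1/kernel

# \g<...> also disambiguates a group followed by a literal digit:
# r'\g<2>0' is group 2 then '0', whereas r'\20' would mean group 20.
print(match.expand(r'layer\g<2>0'))  # layer10
```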
# Example:
original_tree = {
'dense_1': {'kernel': 2, 'bias': 3},
}
transformations = {
r'(?P<layer>[a-z]+)_(?P<num>[0-9])_(?P<weight>[a-z]+)': ocp.Transform(
original_key=r'\g<layer>_\g<num>/\g<weight>'),
}
new_tree = {
'dense_1_kernel': ...,
'dense_1_bias': ...,
}
ocp.apply_transformations(original_tree, transformations, new_tree)
{'dense_1_bias': 3, 'dense_1_kernel': 2}