Model Surgery#
Oftentimes, the model we saved to disk isn’t exactly the model we wish to work with in memory. Some examples of this are:
Stacking/unstacking layers to match your training setup
Fine-tuning a multi-modal model from multiple uni-modal models
Using a frozen teacher model at each iteration of the training loop for a student model
Loading only the weights section of the PyTree, and ignoring things like optimizer state, when doing model evaluation
Model surgery is a toolset designed precisely for this kind of task.
Orbax Checkpointing currently exposes a Partial Loading API, which allows a subset of PyTree leaves (a “strict subtree”) to be loaded from the full model on disk. More general manipulation of leaves and trees, such as loading multiple trees and merging them into one, is planned for the future.
Let’s first take a look at what it’s like to restore part of a PyTree, then touch on the planned Advanced Model Surgery API.
import jax
import numpy as np
from orbax.checkpoint import v1 as ocp
from etils import epath

path = epath.Path('/tmp/model_surgery/my-checkpoints/ckpt-1')

pytree = {
    'params': {
        'layer0': {
            'kernel': np.random.uniform(size=(2, 2)),
            'bias': np.ones(2),
        },
    },
    'opt_state': {
        '0': np.random.random(size=(2,)),
        '1': [np.ones(2), np.ones(2)],
    },
    'step': np.asarray(0),
}

mesh = jax.sharding.Mesh(jax.devices(), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))

pytree = jax.tree.map(
    lambda arr: jax.make_array_from_callback(
        arr.shape,
        sharding,
        lambda idx: arr[idx],
    ),
    pytree,
)

ocp.save(path, pytree, overwrite=True)
Partial Loading#
Partial loading is a way to solve the most common use case of loading a different tree than is present in the checkpoint - where leaves or subtrees can be omitted. The canonical example is to skip loading the optimizer state when you’re doing evaluation. There are a couple of ways to do this with the Partial Loading API. Let’s take a look at both.
Placeholder#
Since we don’t need the optimizer state (opt_state) during model evaluation, we can tell Orbax to skip loading the leaves under that node by setting each of them to the ocp.PLACEHOLDER (...) value.
abstract_tree = {
    'params': {
        'layer0': {
            'kernel': np.array([]),
            'bias': np.array([]),
        },
    },
    # Skip loading 'opt_state'
    'opt_state': {
        '0': ...,
        '1': [..., ...],
    },
    'step': np.array([]),
}

ocp.load(path, abstract_tree)
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/model_surgery/my-checkpoints/ckpt-1. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
{'opt_state': {'0': Ellipsis, '1': [Ellipsis, Ellipsis]},
'params': {'layer0': {'bias': array([1., 1.]),
'kernel': array([[0.38237846, 0.30743006],
[0.43622473, 0.06146657]])}},
'step': array(0.)}
Note that ocp.PLACEHOLDER can only be used for leaves, so opt_state: ocp.PLACEHOLDER would not work. Keeping the structure consistent in this way is important for use cases like merging the original state with the restored state.
bad_abstract_tree = {
    'params': {
        'layer0': {
            'kernel': np.array([]),
            'bias': np.array([]),
        },
    },
    # Skip loading 'opt_state'
    'opt_state': ...,
    'step': np.array([]),
}

try:
  ocp.load(path, bad_abstract_tree)
except Exception as e:
  print(e)
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/model_surgery/my-checkpoints/ckpt-1. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
User-provided restore item and on-disk value metadata tree structures do not match:
opt_state:
- Source: <class 'ellipsis'>
- Target: <class 'dict'>
If this mismatch is intentional, pass `partial_restore=True` to only restore parameters found in `item`.
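Because placeholders preserve the tree structure, a partially restored tree can be merged leaf by leaf with a state that is already in memory. The following is a minimal sketch of that idea, not part of the Orbax API: `merge_restored` and both example trees are hypothetical, and the built-in `...` stands in for ocp.PLACEHOLDER, which the output above prints as Ellipsis.

```python
import numpy as np

# Stand-in for ocp.PLACEHOLDER (assumed to be the built-in Ellipsis).
PLACEHOLDER = ...


def merge_restored(original, restored):
  """Returns `restored`, with placeholder leaves filled in from `original`."""
  if isinstance(restored, dict):
    return {k: merge_restored(original[k], v) for k, v in restored.items()}
  if isinstance(restored, (list, tuple)):
    return type(restored)(
        merge_restored(o, r) for o, r in zip(original, restored)
    )
  # Leaf: keep the in-memory value wherever loading was skipped.
  return original if restored is PLACEHOLDER else restored


# Hypothetical in-memory state and a partially restored tree of the same shape.
in_memory = {
    'params': {'kernel': np.zeros((2, 2))},
    'opt_state': {'0': np.full(2, 7.0), '1': [np.ones(2), np.ones(2)]},
}
restored = {
    'params': {'kernel': np.ones((2, 2))},
    'opt_state': {'0': ..., '1': [..., ...]},
}

merged = merge_restored(in_memory, restored)
```

Here `merged` takes `params` from the restored tree and `opt_state` from the in-memory state. The recursion only works because both trees have identical structure, which is what the leaf-only placeholder rule guarantees.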
Creating an abstract tree by hand is tedious. A more convenient way is to generate it, for example with JAX’s jax.tree.map_with_path.
def _create_abstract_leaf_for_partial_load(leaf_path, _):
  leaf_path = jax.tree_util.keystr(leaf_path, simple=True, separator='/')
  if leaf_path.split('/')[0] == 'opt_state':
    return ocp.PLACEHOLDER
  else:
    return np.array([])

easy_abstract_tree = jax.tree.map_with_path(
    _create_abstract_leaf_for_partial_load,
    pytree,
)

ocp.load(path, easy_abstract_tree)
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/model_surgery/my-checkpoints/ckpt-1. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
{'opt_state': {'0': Ellipsis, '1': [Ellipsis, Ellipsis]},
'params': {'layer0': {'bias': array([1., 1.]),
'kernel': array([[0.38237846, 0.30743006],
[0.43622473, 0.06146657]])}},
'step': array(0.)}
We may not have direct access to the original PyTree when creating the abstract counterpart, and in that case, we’ll need to use the on-disk metadata.
on_disk_pytree_structure = ocp.metadata(path).metadata

abstract_tree_from_metadata = jax.tree.map_with_path(
    _create_abstract_leaf_for_partial_load,
    on_disk_pytree_structure,
)

ocp.load(path, abstract_tree_from_metadata)
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/model_surgery/my-checkpoints/ckpt-1. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/model_surgery/my-checkpoints/ckpt-1. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
{'opt_state': {'0': Ellipsis, '1': [Ellipsis, Ellipsis]},
'params': {'layer0': {'bias': array([1., 1.]),
'kernel': array([[0.38237846, 0.30743006],
[0.43622473, 0.06146657]])}},
'step': array(0.)}
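When the abstract tree is generated rather than written by hand, it can help to check which leaves were actually loaded and which were skipped. The sketch below is one way to do that in plain Python. `split_leaves` is a hypothetical helper, not an Orbax function, and the toy tree only mirrors the layout of the output above with made-up values.

```python
def split_leaves(tree, prefix=''):
  """Returns (loaded, skipped) lists of '/'-separated leaf paths."""
  loaded, skipped = [], []
  if isinstance(tree, dict):
    children = tree.items()
  elif isinstance(tree, (list, tuple)):
    children = enumerate(tree)
  else:
    # Leaf: a placeholder (Ellipsis) means loading was skipped.
    (skipped if tree is ... else loaded).append(prefix)
    return loaded, skipped
  for key, child in children:
    path = f'{prefix}/{key}' if prefix else str(key)
    sub_loaded, sub_skipped = split_leaves(child, path)
    loaded += sub_loaded
    skipped += sub_skipped
  return loaded, skipped


# Toy tree with the same layout as the restored output (values are made up).
restored = {
    'opt_state': {'0': ..., '1': [..., ...]},
    'params': {'layer0': {'bias': 1.0, 'kernel': 0.5}},
    'step': 0,
}

loaded, skipped = split_leaves(restored)
print('loaded: ', loaded)
print('skipped:', skipped)
```

The helper treats dicts, lists, and tuples as containers and everything else as a leaf, which is a simplification of how JAX defines PyTrees.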
Omission#
Alternatively, we can enable the partial_load option and avoid marking skipped nodes explicitly. Instead, we simply leave those nodes out when constructing the abstract PyTree.
abstract_tree = {
    'params': {
        'layer0': {
            'kernel': np.array([]),
            'bias': np.array([]),
        },
    },
    # Note: omit 'opt_state' to avoid loading it
    'step': 0,
}

# Loading PyTrees with certain leaves missing is unsafe
try:
  ocp.load(path, abstract_tree)
except ValueError as e:
  print(e)

# So partial_load must be opted into
with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(
        loading=ocp.options.PyTreeOptions.Loading(
            partial_load=True,
        ),
    ),
):
  ocp.load(path, abstract_tree)
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/model_surgery/my-checkpoints/ckpt-1. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/model_surgery/my-checkpoints/ckpt-1. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
User-provided restore item and on-disk value metadata tree structures do not match:
opt_state:
- Source: MISSING
- Target: {'0': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(2,)), '1': [ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(2,)), ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(2,))]}
If this mismatch is intentional, pass `partial_restore=True` to only restore parameters found in `item`.
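As with placeholders, the omission-style abstract tree does not have to be written by hand. The sketch below prunes named subtrees from a nested-dict structure, such as one derived from the on-disk metadata. `omit_subtrees` is a hypothetical helper, not part of Orbax, and the example tree uses `None` as a dummy leaf value.

```python
def omit_subtrees(tree, skip, prefix=''):
  """Returns a copy of a nested-dict `tree` without the subtrees in `skip`.

  `skip` holds '/'-separated paths, e.g. {'opt_state', 'params/layer0/bias'}.
  Lists are copied over as-is, since dropping elements would shift indices.
  """
  pruned = {}
  for key, value in tree.items():
    path = f'{prefix}/{key}' if prefix else key
    if path in skip:
      continue
    pruned[key] = (
        omit_subtrees(value, skip, path) if isinstance(value, dict) else value
    )
  return pruned


# Hypothetical full structure; None is a dummy leaf value.
full_structure = {
    'params': {'layer0': {'kernel': None, 'bias': None}},
    'opt_state': {'0': None, '1': [None, None]},
    'step': None,
}

eval_structure = omit_subtrees(full_structure, skip={'opt_state'})
print(sorted(eval_structure))
```

The pruned tree would then be passed to ocp.load inside the partial_load=True context shown above, after replacing the dummy leaves with abstract values.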
Model Surgery#
While partial loading is useful for omitting parts of a PyTree, it does not allow for more complex manipulations. In contrast, the planned Model Surgery API is a powerful toolset where the user can manipulate trees and leaves in arbitrary ways. This includes restructuring trees, modifying values, and even loading and merging multiple distinct checkpoints into a single model in memory.
The core of this API will be user-defined transformation functions that are applied to checkpoints during the loading process.
Single-Model Transformations#
A common use case for model surgery is transforming a single checkpoint into a different structure. For example, you might want to stack model layers that were saved individually. This can be accomplished with a transform_fn that takes the PyTree from the source checkpoint and returns a new, modified PyTree.
A potential API for this could look like:
# Stub so that this example runs: `load_and_transform` is a planned API
# and does not exist yet.
ocp.load_and_transform = lambda *args: None

def stack_layers_transform(source_tree):
  params = source_tree['params']
  # Assumes layers are named 'layer0', 'layer1', etc. Sort by layer index so
  # that 'layer10' comes after 'layer2'.
  layer_keys = sorted(
      (k for k in params if k.startswith('layer')),
      key=lambda k: int(k[len('layer'):]),
  )
  stacked_layers = jax.numpy.stack([params[k]['kernel'] for k in layer_keys])
  new_params = {'stacked_layers': stacked_layers}
  # Bring over any other parameters that are not part of the stacking.
  for k in params:
    if k not in layer_keys:
      new_params[k] = params[k]
  source_tree['params'] = new_params
  return source_tree

abstract_tree = ...

# The API would apply the transformation during loading.
restored_tree = ocp.load_and_transform(path, stack_layers_transform, abstract_tree)
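Since the loading API is still only planned, the stacking logic itself can be tried out on an in-memory tree. The sketch below uses plain NumPy and hypothetical helper names (`stack_layer_kernels`, `unstack_layer_kernels`), and assumes layers are named 'layer0', 'layer1', and so on. It sorts by numeric layer index, because a plain lexicographic sort would place 'layer10' before 'layer2'.

```python
import numpy as np


def stack_layer_kernels(params):
  """Stacks per-layer kernels along a new leading axis, in layer order."""
  layer_keys = sorted(
      (k for k in params if k.startswith('layer')),
      key=lambda k: int(k[len('layer'):]),
  )
  return np.stack([params[k]['kernel'] for k in layer_keys])


def unstack_layer_kernels(stacked):
  """Inverse of stack_layer_kernels: one entry per leading-axis slice."""
  return {f'layer{i}': {'kernel': kernel} for i, kernel in enumerate(stacked)}


# Twelve toy layers; each kernel is filled with its own layer index.
params = {
    f'layer{i}': {'kernel': np.full((2, 2), float(i))} for i in range(12)
}

stacked = stack_layer_kernels(params)
unstacked = unstack_layer_kernels(stacked)
print(stacked.shape)
```

Unstacking is the reverse operation mentioned at the top of this page, and would be the natural counterpart when a checkpoint was saved with stacked layers but the training setup expects them separately.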
Multi-Model Transformations#
A more advanced use case is merging multiple checkpoints. A key example is creating a multi-modal model by combining two separately trained uni-modal models (e.g., an image model and a text model).
A transformation function for this scenario would accept multiple source trees and define how they should be combined.
def merge_models_transform(image_model_tree, text_model_tree):
  return {
      'params': {
          'image_encoder': image_model_tree['params'],
          'text_encoder': text_model_tree['params'],
          # A new fusion layer. The user can initialize it later.
          'fusion_layer': {
              'kernel': np.empty((512, 256)),
              'bias': np.empty((256,)),
          },
      },
      # Can also merge other things, like step counts etc.
      'step': image_model_tree['step'],
  }

image_model_path = ...
text_model_path = ...

# The API would take multiple paths and apply the transform.
final_model = ocp.load_and_transform(
    merge_models_transform,
    image_model_path,
    text_model_path,
)
This example also highlights an important feature: parameters in the target structure that are not populated from a source checkpoint (like 'fusion_layer') can be created inside the transform. In the example above they are allocated with np.empty, so their values are arbitrary until the user initializes them. This makes it easy to combine pre-trained components with new, untrained ones.
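One way the user might initialize the new layer after merging is sketched below. This is an illustration only: `init_fusion_layer` is a hypothetical helper, the merged tree is a toy stand-in for the result of the transform, and the scheme (scaled normal kernel, zero bias) is an arbitrary choice rather than anything prescribed by Orbax.

```python
import numpy as np


def init_fusion_layer(tree, seed=0):
  """Returns a copy of `tree` whose fusion layer has initialized values."""
  rng = np.random.default_rng(seed)
  fusion = tree['params']['fusion_layer']
  fan_in = fusion['kernel'].shape[0]
  new_fusion = {
      # Scaled normal init; an assumed scheme chosen for illustration.
      'kernel': rng.normal(
          scale=1.0 / np.sqrt(fan_in), size=fusion['kernel'].shape
      ),
      'bias': np.zeros_like(fusion['bias']),
  }
  # Pre-trained subtrees are reused as-is; only the fusion layer is replaced.
  return {**tree, 'params': {**tree['params'], 'fusion_layer': new_fusion}}


# Toy stand-in for the tree produced by a merge transform.
merged = {
    'params': {
        'image_encoder': {'kernel': np.ones((4, 4))},
        'text_encoder': {'kernel': np.ones((4, 4))},
        'fusion_layer': {
            'kernel': np.empty((512, 256)),
            'bias': np.empty((256,)),
        },
    },
    'step': 0,
}

initialized = init_fusion_layer(merged)
```

The pre-trained encoder subtrees are shared with the input rather than copied, so only the new layer allocates fresh memory.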
This planned API aims to provide maximum flexibility, making complex restoration and fine-tuning workflows more straightforward to implement.