Checkpointing PyTrees of Arrays
A PyTree is the most common way of representing a training state in JAX. While Orbax is designed to be as generic as possible, and provides customization options for all manner of checkpointable objects, PyTrees naturally have pride of place. Furthermore, the standard object used to represent large, sharded arrays is the jax.Array. This, too, has extensive first-class support.
CheckpointHandler Support
Orbax provides essentially two options for working with PyTrees:
- StandardCheckpointHandler: applicable in the majority of use cases.
- PyTreeCheckpointHandler: useful when advanced customization is desired.
import numpy as np
import orbax.checkpoint as ocp
import jax
sharding = jax.sharding.NamedSharding(
jax.sharding.Mesh(jax.devices(), ('model',)),
jax.sharding.PartitionSpec(
'model',
),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
state = {
'a': np.arange(16),
'b': np.ones(16),
}
state = jax.tree_util.tree_map(create_sharded_array, state)
abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)
Basic Checkpointing
Let’s use StandardCheckpointHandler to work with PyTrees of jax.Array.
path = ocp.test_utils.erase_and_create_empty('/tmp/basic/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(path / '1', args=ocp.args.StandardSave(state))
ckptr.wait_until_finished()
We specify the abstract_state in order to restore with the given dtypes, shapes, and shardings for each leaf.
restored = ckptr.restore(path / '1', args=ocp.args.StandardRestore(abstract_state))
restored
{'a': Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int32),
'b': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)}
restored['a'].sharding
NamedSharding(mesh=Mesh('model': 1, axis_types=(Auto,)), spec=P('model',), memory_kind=device)
You can do exactly the same thing with a “concrete” target rather than an “abstract” target. However, this requires fully initializing the target train state before restoring from the checkpoint, which is inefficient because those initial values are immediately overwritten. It is better practice to initialize only the metadata, either by manually creating jax.ShapeDtypeStructs or by using jax.eval_shape.
ckptr.restore(path / '1', args=ocp.args.StandardRestore(state))
{'a': Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int32),
'b': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)}
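As a sketch of the metadata-only approach, the block below builds an abstract target with jax.eval_shape, which returns shapes and dtypes without materializing the initial values. The init_state function is a hypothetical stand-in for your own initializer, and the sharding mirrors the 'model' sharding from the setup cell; the resulting tree could be passed to ocp.args.StandardRestore in place of abstract_state.

```python
import jax
import jax.numpy as jnp
import numpy as np

# One mesh axis named 'model' over all devices, as in the setup cell above.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('model',))
model_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec('model')
)


def init_state():
  # Hypothetical initializer; substitute your real model/optimizer init.
  return {'a': jnp.arange(16), 'b': jnp.ones(16)}


# eval_shape traces init_state abstractly and returns a PyTree of
# jax.ShapeDtypeStruct (shape and dtype only) instead of concrete arrays.
shapes = jax.eval_shape(init_state)

# Attach the sharding each leaf should have once it is restored.
eval_shape_abstract_state = jax.tree_util.tree_map(
    lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype, sharding=model_sharding),
    shapes,
)
```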
Customizing Restored Properties
Array dtype
def set_restore_dtype(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  return x.update(dtype=np.int16)
cast_dtype_abstract_state = jax.tree_util.tree_map(
set_restore_dtype, abstract_state)
ckptr.restore(
path / '1',
args=ocp.args.StandardRestore(cast_dtype_abstract_state),
)
{'a': Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int16),
'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int16)}
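The same mechanism can target only some leaves. As a sketch, the function below requests bfloat16 only for floating-point leaves and leaves integer leaves untouched; the tree it produces could be passed to ocp.args.StandardRestore just like cast_dtype_abstract_state. The choice of bfloat16 and the stand-in example_abstract_state are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in abstract tree with one integer and one floating-point leaf.
example_abstract_state = {
    'a': jax.ShapeDtypeStruct((16,), np.int32),
    'b': jax.ShapeDtypeStruct((16,), np.float32),
}


def cast_floating_leaves(x: jax.ShapeDtypeStruct, dtype=jnp.bfloat16):
  """Requests `dtype` for floating-point leaves; other leaves are unchanged."""
  if jnp.issubdtype(x.dtype, jnp.floating):
    # Keep shape and sharding; only the requested dtype changes.
    return jax.ShapeDtypeStruct(x.shape, dtype, sharding=x.sharding)
  return x


mixed_dtype_abstract_state = jax.tree_util.tree_map(
    cast_floating_leaves, example_abstract_state
)
```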
Pad / truncate shape
different_shape_abstract_state = {
'a': jax.ShapeDtypeStruct(
shape=(8,),
dtype=abstract_state['a'].dtype,
sharding=abstract_state['a'].sharding
),
'b': jax.ShapeDtypeStruct(
shape=(32,),
dtype=abstract_state['b'].dtype,
sharding=abstract_state['b'].sharding
),
}
Ordinarily, specifying a target array with a different shape than in the checkpoint results in an error.
try:
  ckptr.restore(
      path / '1',
      args=ocp.args.StandardRestore(different_shape_abstract_state),
  )
except Exception as e:
  print(e)
Requested shape: (8,) is not compatible with the stored shape: (16,). Truncating/padding is disabled. To enable it, set `strict=False` in `ArrayRestoreArgs` for any array in v0 API or `enable_padding_and_truncation=True` in `ArrayOptions.Loading` in v1 API.
We can pad or truncate arrays as they are loaded by specifying strict=False.
ckptr.restore(
path / '1',
args=ocp.args.StandardRestore(different_shape_abstract_state, strict=False),
)
{'a': Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32),
'b': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}
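The output above suggests that truncation keeps the leading elements and padding fills the trailing positions with zeros. The hypothetical NumPy helper below emulates that observed behaviour to make it concrete; it is not Orbax's implementation, and extending it to every axis of a multi-dimensional array is an assumption beyond the 1-D example shown.

```python
import numpy as np


def pad_or_truncate(arr: np.ndarray, target_shape: tuple) -> np.ndarray:
  """Hypothetical emulation of the strict=False behaviour shown above.

  Along each axis the leading elements are kept (truncation) and any extra
  positions are zero-filled (padding). This is not Orbax's implementation.
  """
  if arr.ndim != len(target_shape):
    raise ValueError('Rank must match; only per-axis sizes may differ.')
  out = np.zeros(target_shape, dtype=arr.dtype)
  # The region that exists in both the stored and the requested shape.
  overlap = tuple(slice(0, min(s, t)) for s, t in zip(arr.shape, target_shape))
  out[overlap] = arr[overlap]
  return out


print(pad_or_truncate(np.arange(16, dtype=np.int32), (8,)))
print(pad_or_truncate(np.ones(16, dtype=np.float32), (32,)))
```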
Change sharding
NOTE: This can often be a particularly sharp edge.
Sharding commonly needs to be changed when a checkpoint saved on one topology is loaded on a different topology.
If changing topologies, you MUST specify sharding when restoring.
Orbax does not make any decisions about sharding on your behalf unless you are loading on the exact same topology. Only in that case is it possible to omit the sharding when loading, as demonstrated here:
restored = ckptr.restore(path / '1')
restored['a'].sharding
NamedSharding(mesh=Mesh('model': 1, axis_types=(Auto,)), spec=P('model',), memory_kind=device)
In the example below, we alter the sharding while loading.
sharding = jax.sharding.NamedSharding(
jax.sharding.Mesh(jax.devices(), ('x',)),
jax.sharding.PartitionSpec(),
)
def set_sharding(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  return x.update(sharding=sharding)
change_sharding_abstract_state = jax.tree_util.tree_map(
set_sharding, abstract_state)
restored = ckptr.restore(
path / '1',
args=ocp.args.StandardRestore(change_sharding_abstract_state),
)
restored['a'].sharding
NamedSharding(mesh=Mesh('x': 1, axis_types=(Auto,)), spec=P(), memory_kind=device)
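On a new topology a single sharding for every leaf is often not what you want. As a sketch, the hypothetical rule below splits a leaf's leading axis over a 'model' mesh axis when it divides evenly and replicates the leaf otherwise (for example, scalars such as a step counter). The rule itself and example_abstract_state are illustrative assumptions; the resulting tree would be passed to ocp.args.StandardRestore like change_sharding_abstract_state.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A one-axis mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), ('model',))


def shard_if_divisible(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  """Hypothetical rule: split axis 0 over 'model' if it divides evenly."""
  n = mesh.shape['model']
  if len(x.shape) >= 1 and x.shape[0] % n == 0:
    spec = PartitionSpec('model')
  else:
    spec = PartitionSpec()  # Scalars and uneven leading axes are replicated.
  return jax.ShapeDtypeStruct(
      x.shape, x.dtype, sharding=NamedSharding(mesh, spec)
  )


example_abstract_state = {
    'a': jax.ShapeDtypeStruct((16,), np.int32),
    'step': jax.ShapeDtypeStruct((), np.int32),
}
resharded_abstract_state = jax.tree_util.tree_map(
    shard_if_divisible, example_abstract_state
)
```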
Partial Restore
You may wish to restore part of a PyTree contained within a saved checkpoint. For example, consider the following item:
original_item = {
'params': {
'layer1': {
'kernel': np.arange(8),
'bias': np.arange(8),
},
'layer2': {
'kernel': np.arange(8),
'bias': np.arange(8),
},
},
'opt_state': [np.arange(8), np.arange(8)],
'step': 101,
}
path = ocp.test_utils.erase_and_create_empty('/tmp/partial/')
ckptr = ocp.PyTreeCheckpointer()
ckptr.save(path / '1', args=ocp.args.PyTreeSave(original_item))
If we want to restore only a subset of PyTree nodes (params.layer2 and step, for example), we can use Placeholder values.
Placeholder
To restore part of a PyTree item, we can specify which nodes to ignore during restoration by using ocp.PLACEHOLDER.
reference_item = {
'params': {
'layer1': {
'kernel': ocp.PLACEHOLDER,
'bias': ocp.PLACEHOLDER,
},
'layer2': {
'kernel': np.arange(8),
'bias': np.arange(8),
},
},
'opt_state': [ocp.PLACEHOLDER, ocp.PLACEHOLDER],
'step': 101,
}
ckptr.restore(
path / '1',
args=ocp.args.PyTreeRestore(
item=reference_item,
),
)
{'opt_state': [Ellipsis, Ellipsis],
'params': {'layer1': {'bias': Ellipsis, 'kernel': Ellipsis},
'layer2': {'bias': array([0, 1, 2, 3, 4, 5, 6, 7]),
'kernel': array([0, 1, 2, 3, 4, 5, 6, 7])}},
'step': 101}
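Writing every placeholder by hand gets tedious for large trees. As a sketch, the hypothetical keep_only helper below derives such a reference item from a full tree and a list of dotted paths to keep, using jax.tree_util.tree_map_with_path. Its default placeholder is Ellipsis, which is how placeholders print in the output above; when calling ckptr.restore, pass ocp.PLACEHOLDER explicitly rather than relying on that default.

```python
import jax
import numpy as np


def _dotted(path) -> str:
  """Turns a JAX key path into a dotted string such as 'params.layer2.bias'."""
  parts = []
  for key in path:
    if isinstance(key, jax.tree_util.DictKey):
      parts.append(str(key.key))
    elif isinstance(key, jax.tree_util.SequenceKey):
      parts.append(str(key.idx))
    else:
      parts.append(str(key))
  return '.'.join(parts)


def keep_only(tree, keep_prefixes, placeholder=...):
  """Hypothetical helper: replaces every leaf outside `keep_prefixes` with
  `placeholder` (pass ocp.PLACEHOLDER when restoring with Orbax)."""

  def select(path, leaf):
    name = _dotted(path)
    keep = any(name == p or name.startswith(p + '.') for p in keep_prefixes)
    return leaf if keep else placeholder

  return jax.tree_util.tree_map_with_path(select, tree)


original_item = {
    'params': {
        'layer1': {'kernel': np.arange(8), 'bias': np.arange(8)},
        'layer2': {'kernel': np.arange(8), 'bias': np.arange(8)},
    },
    'opt_state': [np.arange(8), np.arange(8)],
    'step': 101,
}
placeholder_reference_item = keep_only(original_item, ['params.layer2', 'step'])
```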
Omission
Alternatively, we can enable the partial_restore option so that we do not have to explicitly mark the nodes to be skipped. Instead, we simply omit those nodes when constructing the reference PyTree.
reference_item = {
'params': {
# Note: omit 'layer1' to avoid loading it (or, more accurately, to avoid
# loading 'layer1.kernel' and 'layer1.bias')
'layer2': {
'kernel': np.arange(8),
'bias': np.arange(8),
},
},
# Omit 'opt_state' to avoid loading 'opt_state[0]' and 'opt_state[1]'
'step': 101,
}
# Loading PyTrees with certain leaves missing is unsafe
try:
  ckptr.restore(
      path / '1',
      args=ocp.args.PyTreeRestore(
          item=reference_item,
      ),
  )
except ValueError as e:
  print(e)
# So Partial Restore Omission mode must be opted into using partial_restore=True
ckptr.restore(
    path / '1',
    args=ocp.args.PyTreeRestore(
        item=reference_item,
        partial_restore=True,
    ),
)
User-provided restore item and on-disk value metadata tree structures do not match:
opt_state:
- Source: MISSING
- Target: [ValueMetadataEntry(value_type='np.ndarray', skip_deserialize=False, write_shape=None), ValueMetadataEntry(value_type='np.ndarray', skip_deserialize=False, write_shape=None)]
params.layer1:
- Source: MISSING
- Target: {'bias': ValueMetadataEntry(value_type='np.ndarray', skip_deserialize=False, write_shape=None), 'kernel': ValueMetadataEntry(value_type='np.ndarray', skip_deserialize=False, write_shape=None)}
If this mismatch is intentional, pass `partial_restore=True` to only restore parameters found in `item`.
{'params': {'layer2': {'bias': array([0, 1, 2, 3, 4, 5, 6, 7]),
'kernel': array([0, 1, 2, 3, 4, 5, 6, 7])}},
'step': 101}
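The omission-style reference item can also be generated. As a sketch, the hypothetical select_subtrees helper below copies only the subtrees named by dotted paths into a fresh nested dict; the result would then be passed as item together with partial_restore=True. It only walks dictionaries, so list-valued nodes such as opt_state can be kept or omitted only as a whole, which matches the example above.

```python
import numpy as np


def select_subtrees(tree: dict, paths) -> dict:
  """Hypothetical helper: builds a nested dict holding only the subtrees named
  by dotted `paths` (e.g. 'params.layer2'); everything else is omitted."""
  result = {}
  for path in paths:
    keys = path.split('.')
    src, dst = tree, result
    # Walk down to the parent of the selected node, creating dicts as needed.
    for key in keys[:-1]:
      src = src[key]
      dst = dst.setdefault(key, {})
    dst[keys[-1]] = src[keys[-1]]
  return result


original_item = {
    'params': {
        'layer1': {'kernel': np.arange(8), 'bias': np.arange(8)},
        'layer2': {'kernel': np.arange(8), 'bias': np.arange(8)},
    },
    'opt_state': [np.arange(8), np.arange(8)],
    'step': 101,
}
omission_reference_item = select_subtrees(original_item, ['params.layer2', 'step'])
```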
Advanced Options
StandardCheckpointHandler does not expose some advanced options. These can be specified by using PyTreeCheckpointHandler instead.
Saving
For example, PyTreeCheckpointHandler can be used to customize the on-disk type used to save individual arrays. First, let’s save and restore as normal.
path = ocp.test_utils.erase_and_create_empty('/tmp/advanced/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
ckptr.save(path / '1', args=ocp.args.PyTreeSave(state))
ckptr.wait_until_finished()
restored = ckptr.restore(path / '1')
/home/docs/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.12/site-packages/orbax/checkpoint/_src/serialization/jax_array_handlers.py:736: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
restored['a'].dtype
dtype('int32')
restored['b'].dtype
dtype('float32')
Now, let’s set the on-disk dtype of array “a” when saving.
ckptr.save(
path / '2',
args=ocp.args.PyTreeSave(
state,
save_args={
# We must set one ocp.SaveArgs per leaf.
'a': ocp.SaveArgs(dtype=np.dtype(np.int16)),
'b': ocp.SaveArgs()
}
),
)
ckptr.wait_until_finished()
restored = ckptr.restore(path / '2')
restored['a'].dtype
dtype('int16')
restored['b'].dtype
dtype('float32')
Restoring
Options similar to the above are available, where we can customize shape, dtype, and sharding when restoring.
ckptr.restore(
path / '2',
args=ocp.args.PyTreeRestore(
restore_args={
# RestoreArgs is the parent class for ArrayRestoreArgs.
# We must set one RestoreArgs per leaf.
'a': ocp.RestoreArgs(restore_type=np.ndarray),
'b': ocp.ArrayRestoreArgs(dtype=np.dtype(np.int16), sharding=sharding)
}
),
)
{'a': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
dtype=int16),
'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int16)}
Note that “a” was restored as np.ndarray rather than jax.Array.
PyTreeCheckpointHandler also provides options to perform transformations when restoring via the transforms argument. This is useful when your target tree has a different structure than your checkpoint tree. For example, it can be used to rename keys from the checkpoint (original_key) to match the target structure, as seen in the example below.
ckptr.restore(
path / '2',
args=ocp.args.PyTreeRestore(
# `item` serves as a guide to what the result tree structure should look
# like.
item={
# Value doesn't really matter here, as long as it's not None.
'c': ...,
# Can add in extra keys.
'd': np.arange(8)
},
# `restore_args` must be relative to the result tree, not the
# checkpoint.
restore_args={
'c': ocp.RestoreArgs(restore_type=np.ndarray),
},
transforms={
'c': ocp.Transform(original_key='a')
},
),
)
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'c': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
dtype=int16),
'd': array([0, 1, 2, 3, 4, 5, 6, 7])}