Working with PyTree Checkpoints
A PyTree is the most common way of representing a training state in JAX. While Orbax is designed to be as generic as possible, and provides customization options for all manner of checkpointable objects, PyTrees naturally have pride of place. Furthermore, the standard object used to represent large, sharded arrays is the jax.Array. This, too, has extensive first-class support.
Exclusive APIs to checkpoint PyTrees
The following APIs work exclusively with PyTrees.
To save:
- ocp.save(...)
- ocp.save_async(...)
- training.Checkpointer.save(...)
- training.Checkpointer.save_async(...)
To load:
- ocp.load(...)
- ocp.load_async(...)
- training.Checkpointer.load(...)
- training.Checkpointer.load_async(...)
Of course, the save_checkpointables(...) and load_checkpointables(...) APIs can also be used to save and load a PyTree.
Let’s set up a PyTree of jax.Array leaves to try out these APIs.
from etils import epath
import jax
import numpy as np
import orbax.checkpoint.experimental.v1 as ocp
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(
        'model',
    ),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
pytree = {
    'a': np.arange(16, dtype=np.int32),
    'b': np.ones(16, dtype=np.int32),
}
pytree = jax.tree_util.tree_map(create_sharded_array, pytree)
pytree
{'a': Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int32),
'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Basic Checkpointing
Let’s use ocp.save_*/ocp.load_* to work with the pytree created earlier.
path = epath.Path('/tmp/checkpointing-pytrees/basic/')
path.rmtree(missing_ok=True)
# Simple save using default options:
ocp.save(path, pytree)
We can restore the checkpoint without specifying a target, as in the following snippet.
Warning: do not load this way in production-sensitive code.
loaded = ocp.load(path)
loaded
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/checkpointing-pytrees/basic. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
/home/docs/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.12/site-packages/orbax/checkpoint/_src/serialization/jax_array_handlers.py:736: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
{'a': Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int32),
'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Loading this way is not recommended for production-sensitive cases because it gives you no guarantees about what is being loaded. If the shapes of some arrays in the model have changed since the checkpoint was saved, you may see errors when building the model from the loaded values. If the device topology has changed, you will see errors when arrays are placed on devices.
It is therefore recommended to always specify an abstract PyTree when loading.
Understanding Abstract Trees and Leaves
An abstract PyTree is just a normal PyTree, but with abstract leaves. An abstract leaf is a cheap representation of a leaf type (such as an array) that contains only metadata, and does not represent the real values. (Contrast with a concrete PyTree, which contains real data in the form of large arrays, and other types.)
Let’s create an abstract PyTree matching the structure of the PyTree we originally saved.
abstract_state = {
    'a': jax.ShapeDtypeStruct(shape=(16,), dtype=np.int32, sharding=sharding),
    'b': jax.ShapeDtypeStruct(shape=(16,), dtype=np.int32, sharding=sharding),
}
abstract_state
{'a': ShapeDtypeStruct(shape=(16,), dtype=int32, sharding=NamedSharding(mesh=Mesh('model': 1, axis_types=(Auto,)), spec=P('model',), memory_kind=device)),
'b': ShapeDtypeStruct(shape=(16,), dtype=int32, sharding=NamedSharding(mesh=Mesh('model': 1, axis_types=(Auto,)), spec=P('model',), memory_kind=device))}
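Writing every leaf by hand does not scale to a real training state. The following is a minimal sketch that derives a matching abstract tree from any concrete tree of jax.Array leaves using jax.tree_util.tree_map. The helper name to_abstract is hypothetical, and the concrete tree here is a stand-in placed on the default device rather than the sharded pytree above.

```python
import jax
import numpy as np

def to_abstract(tree):
  """Replaces each jax.Array leaf with a metadata-only jax.ShapeDtypeStruct."""
  return jax.tree_util.tree_map(
      lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding),
      tree,
  )

# Stand-in concrete tree on the default device (hypothetical example data).
concrete = {
    'a': jax.device_put(np.arange(16, dtype=np.int32)),
    'b': jax.device_put(np.ones(16, dtype=np.int32)),
}
abstract = to_abstract(concrete)
print(abstract['a'])
```

Each abstract leaf keeps the shape, dtype, and sharding of the concrete leaf, but holds no array data.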
# Load using abstract_state.
loaded = ocp.load(path, abstract_state)
loaded
{'a': Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int32),
'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
(loaded['a'].sharding, loaded['b'].sharding)
(NamedSharding(mesh=Mesh('model': 1, axis_types=(Auto,)), spec=P('model',), memory_kind=device),
NamedSharding(mesh=Mesh('model': 1, axis_types=(Auto,)), spec=P('model',), memory_kind=device))
The ocp.metadata function returns a CheckpointMetadata object with a number of properties. Its core metadata property is just an abstract PyTree, which can also be used for loading as shown below.
metadata = ocp.metadata(path).metadata
metadata
{'a': ArrayMetadata(shape=(16,), dtype=dtype('int32'), sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['model'], axis_types=(Auto,), partition_spec=('model',)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]), storage_metadata=StorageMetadata(chunk_shape=(16,), write_shape=(16,))),
'b': ArrayMetadata(shape=(16,), dtype=dtype('int32'), sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['model'], axis_types=(Auto,), partition_spec=('model',)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]), storage_metadata=StorageMetadata(chunk_shape=(16,), write_shape=(16,)))}
loaded = ocp.load(path, metadata)
loaded
{'a': Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int32),
'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
(loaded['a'].sharding, loaded['b'].sharding)
(NamedSharding(mesh=Mesh('model': 1, axis_types=(Auto,)), spec=P('model',), memory_kind=device),
NamedSharding(mesh=Mesh('model': 1, axis_types=(Auto,)), spec=P('model',), memory_kind=device))
Note that it is also valid to provide a “concrete” PyTree for loading rather than an “abstract” target, since by definition the concrete leaves contain all the properties provided by the abstract leaves.
However, this requires fully initializing the target train state before loading from the checkpoint, which is inefficient or impractical in real use cases. It is better practice to initialize only metadata, either by manually creating jax.ShapeDtypeStruct objects or by using jax.eval_shape.
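The following is a minimal sketch of the jax.eval_shape route. The function init_params is a hypothetical stand-in for a model initializer. eval_shape traces it abstractly, so no parameter memory is allocated, and returns a tree of jax.ShapeDtypeStruct leaves. The sketch then attaches a sharding to each leaf. Here that is a single-device sharding named placeholder_sharding, which stands in for the NamedSharding you would actually load with.

```python
import jax
import jax.numpy as jnp

def init_params(key):
  # Hypothetical initializer; a real model would build many more arrays.
  k1, _ = jax.random.split(key)
  return {
      'kernel': jax.random.normal(k1, (128, 64), dtype=jnp.float32),
      'bias': jnp.zeros((64,), dtype=jnp.float32),
  }

# Trace init_params without running it: the result holds shapes and dtypes only.
abstract_params = jax.eval_shape(init_params, jax.random.PRNGKey(0))

# Attach the sharding each leaf should have once loaded (placeholder choice).
placeholder_sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
abstract_params = jax.tree_util.tree_map(
    lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype, sharding=placeholder_sharding),
    abstract_params,
)
print(abstract_params)
```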
ocp.load(path, pytree)
{'a': Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int32),
'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Standard Leaf Types
The following standard leaf types are supported by Orbax by default. Each concrete leaf type has a corresponding abstract leaf type. Most abstract types are implemented as Protocols, so any object implementing the required properties is accepted as a valid abstract leaf.
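| Concrete leaf type | Abstract leaf type | Properties |
|---|---|---|
| jax.Array | jax.ShapeDtypeStruct | shape, dtype, sharding |
| np.ndarray | np.ndarray | shape, dtype |
| int, float | int, float | type only |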
None is always a valid abstract leaf; it serves as an indication that the leaf should be restored using metadata stored in the checkpoint.
Type[AbstractLeaf] is also always a valid abstract leaf; it again serves as an indication that the leaf should be restored using the metadata, but with the additional constraint to load as the indicated type. For example, instead of specifying jax.ShapeDtypeStruct(shape=..., dtype=..., sharding=...), it is sufficient to pass jax.ShapeDtypeStruct. Similarly, instead of passing 0 to restore as an int, the type itself may be passed.
To summarize, here are the ways you can load a PyTree using abstract leaves, with the way we most recommend at the top, and the way we least recommend at the bottom.
1. Fully-specified abstract values
This provides the most validation during loading and requires the fewest unnecessary metadata reads.
abstract_state = {
    'a': jax.ShapeDtypeStruct(shape=..., dtype=..., sharding=jax.sharding.NamedSharding(...)),
}
2. Only types specified
This guarantees that each leaf will be loaded with the indicated type, but metadata will be used to restore specific properties for each leaf.
abstract_state = {
    'a': jax.ShapeDtypeStruct,
    'b': int,
    'c': np.ndarray,
}
3. None specified (per-leaf)
This is essentially the same as (2), but metadata will also be used to decide which type each leaf should be loaded as.
abstract_state = {
    'a': None,
    'b': None,
}
4. None specified
This loads the PyTree structure without any checks, and can lead to errors later in your code if the checkpoint does not have the structure you expect.
abstract_state = None
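Styles (2) and (3) can be derived mechanically from a tree you already hold. The following is a minimal sketch that assumes leaves are jax.Array, np.ndarray, or Python scalars. The helper names types_only and none_per_leaf are hypothetical.

```python
import jax
import numpy as np

def types_only(tree):
  """Style (2): keep only the type of each leaf."""
  def leaf_type(x):
    # jax.Array leaves map to jax.ShapeDtypeStruct, as in the example above.
    if isinstance(x, jax.Array):
      return jax.ShapeDtypeStruct
    return type(x)
  return jax.tree_util.tree_map(leaf_type, tree)

def none_per_leaf(tree):
  """Style (3): keep the structure, but defer every leaf to checkpoint metadata."""
  return jax.tree_util.tree_map(lambda _: None, tree)

concrete = {
    'a': jax.device_put(np.arange(4, dtype=np.int32)),
    'b': 3,
    'c': np.zeros(2),
}
print(types_only(concrete))
print(none_per_leaf(concrete))
```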
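Customizing Loaded Properties for Arrays
Array dtype
To load an array with a different dtype than the one it was saved with, change the dtype of the corresponding abstract leaf. Below, both arrays are loaded as int16.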
def set_loading_dtype(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  return x.update(dtype=np.int16)
cast_dtype_abstract_state = jax.tree_util.tree_map(
    set_loading_dtype, abstract_state
)
ocp.load(path, cast_dtype_abstract_state)
{'a': Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int16),
'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int16)}
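The same tree_map pattern can be made selective. The following is a minimal sketch for a common case, such as loading float32 parameters as bfloat16. Only floating-point leaves change, while integer leaves such as a step counter keep their saved dtype. The helper cast_floats is hypothetical. It rebuilds each jax.ShapeDtypeStruct explicitly so that the shape and sharding carry over unchanged.

```python
import jax
import jax.numpy as jnp
import numpy as np

def cast_floats(abstract_tree, dtype=jnp.bfloat16):
  """Changes the dtype of floating-point abstract leaves only."""
  def cast(x):
    if jnp.issubdtype(x.dtype, jnp.floating):
      return jax.ShapeDtypeStruct(x.shape, dtype, sharding=x.sharding)
    return x  # Integer (and other) leaves keep their saved dtype.
  return jax.tree_util.tree_map(cast, abstract_tree)

abstract = {
    'kernel': jax.ShapeDtypeStruct((4, 4), np.float32),
    'step': jax.ShapeDtypeStruct((), np.int32),
}
print(cast_floats(abstract))
```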
Change sharding
NOTE: This can often be a particularly sharp edge.
Sharding commonly needs to be changed when loading a checkpoint saved on one topology to a different topology.
If changing topologies, you MUST specify sharding when loading.
Unless you are loading on the exact same topology, Orbax does not make any decisions about shardings on your behalf. If you have the exact same topology, however, it is possible to avoid specifying the sharding when loading. This is demonstrated below:
loaded = ocp.load(path)
(loaded['a'].sharding, loaded['b'].sharding)
(NamedSharding(mesh=Mesh('model': 1, axis_types=(Auto,)), spec=P('model',), memory_kind=device),
NamedSharding(mesh=Mesh('model': 1, axis_types=(Auto,)), spec=P('model',), memory_kind=device))
In the example below, we alter the sharding while loading.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)),
    jax.sharding.PartitionSpec(),
)
def set_sharding(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  return x.update(sharding=sharding)
change_sharding_abstract_state = jax.tree_util.tree_map(
    set_sharding, abstract_state
)
loaded = ocp.load(path, change_sharding_abstract_state)
(loaded['a'].sharding, loaded['b'].sharding)
(NamedSharding(mesh=Mesh('x': 1, axis_types=(Auto,)), spec=P(), memory_kind=device),
NamedSharding(mesh=Mesh('x': 1, axis_types=(Auto,)), spec=P(), memory_kind=device))
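When the target topology is only known at load time, the shardings can be computed from the devices that are currently visible. The following is a minimal sketch that assumes a one-dimensional mesh named 'x' over all visible devices. It applies a hypothetical policy: split each array's leading dimension when it divides evenly across the devices, and replicate the array otherwise. Real models usually need per-parameter partitioning rules instead. The helper reshard_for_mesh is hypothetical.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

target_mesh = Mesh(np.array(jax.devices()), ('x',))

def reshard_for_mesh(abstract_tree, mesh):
  """Assigns a NamedSharding on `mesh` to every abstract leaf."""
  n = mesh.shape['x']
  def choose(x):
    if x.shape and x.shape[0] % n == 0:
      spec = PartitionSpec('x')  # Split the leading dimension across devices.
    else:
      spec = PartitionSpec()  # Fall back to full replication.
    return jax.ShapeDtypeStruct(
        x.shape, x.dtype, sharding=NamedSharding(mesh, spec)
    )
  return jax.tree_util.tree_map(choose, abstract_tree)

saved_layout = {
    'a': jax.ShapeDtypeStruct((16,), np.int32),
    'b': jax.ShapeDtypeStruct((3,), np.int32),
}
target = reshard_for_mesh(saved_layout, target_mesh)
print(target)
```

The resulting tree can be passed as the abstract target for loading, in the same way as change_sharding_abstract_state.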
Alternatively, we can build the target from the checkpoint metadata instead of the abstract PyTree defined earlier.
metadata = ocp.metadata(path).metadata
change_sharding_metadata = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(
        shape=x.shape, dtype=x.dtype, sharding=sharding
    ),
    metadata,
)
loaded = ocp.load(path, change_sharding_metadata)
(loaded['a'].sharding, loaded['b'].sharding)
(NamedSharding(mesh=Mesh('x': 1, axis_types=(Auto,)), spec=P(), memory_kind=device),
NamedSharding(mesh=Mesh('x': 1, axis_types=(Auto,)), spec=P(), memory_kind=device))
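Change leaf type
The abstract leaf type dictates the loaded type of each leaf. If we save a value as a jax.Array but provide an abstract leaf without the required sharding property, Orbax loads it as an np.ndarray. Similarly, we can save an integer and load it as a float by specifying float as the abstract leaf. We can also save a float and load it as an int; below, 13.5 is loaded as 13. Because np.empty((8,)) has dtype float64, 'c' is also loaded as floating-point values.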
path = epath.Path('/tmp/checkpointing-pytrees/change-type/')
path.rmtree(missing_ok=True)
pytree_with_scalars = {
    'a': np.asarray(12),
    'b': 13.5,
    'c': create_sharded_array(np.arange(8)),
}
ocp.save(path, pytree_with_scalars)
abstract_state_with_scalars = {
    'a': float,
    'b': int,
    'c': np.empty((8,)),
}
ocp.load(path, abstract_state_with_scalars)
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/checkpointing-pytrees/change-type. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
{'a': 12.0, 'b': 13, 'c': array([0., 1., 2., 3., 4., 5., 6., 7.])}
Partial Loading
You may wish to load part of a PyTree contained within a saved checkpoint. For example, consider the following item:
original_item = {
    'params': {
        'layer1': {
            'kernel': np.arange(8),
            'bias': np.arange(8),
        },
        'layer2': {
            'kernel': np.arange(8),
            'bias': np.arange(8),
        },
    },
    'opt_state': [np.arange(8), np.arange(8)],
    'step': 101,
}
path = epath.Path('/tmp/checkpointing-pytrees/partial/')
path.rmtree(missing_ok=True)
ocp.save(path / '1', original_item)
If we want to load only a subset of PyTree nodes (params.layer2 and step, for example), we can use Placeholder values.
Placeholder
To load part of a PyTree item, we can specify which nodes to ignore during loading by using ... (ocp.PLACEHOLDER).
reference_item = {
    'params': {
        'layer1': {
            'kernel': ...,
            'bias': ...,
        },
        'layer2': {
            'kernel': np.arange(8),
            'bias': np.arange(8),
        },
    },
    'opt_state': [..., ...],
    'step': 101,
}
ocp.load(path / '1', reference_item)
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/checkpointing-pytrees/partial/1. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
{'opt_state': [Ellipsis, Ellipsis],
'params': {'layer1': {'bias': Ellipsis, 'kernel': Ellipsis},
'layer2': {'bias': array([0, 1, 2, 3, 4, 5, 6, 7]),
'kernel': array([0, 1, 2, 3, 4, 5, 6, 7])}},
'step': 101}
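For larger trees, writing placeholders by hand is tedious. The following is a minimal sketch that builds the same kind of reference item with jax.tree_util.tree_map_with_path. It keeps the leaves under params.layer2 and step, and replaces every other leaf with ... (the placeholder described above). The selection is expressed as dict-key prefixes, and both the KEEP list and the keep_or_placeholder helper are hypothetical.

```python
import jax
import numpy as np

item = {
    'params': {
        'layer1': {'kernel': np.arange(8), 'bias': np.arange(8)},
        'layer2': {'kernel': np.arange(8), 'bias': np.arange(8)},
    },
    'opt_state': [np.arange(8), np.arange(8)],
    'step': 101,
}

# Hypothetical selection: dict-key prefixes of the subtrees to load.
KEEP = [('params', 'layer2'), ('step',)]

def keep_or_placeholder(path, leaf):
  # Collect the dict keys along the path; list positions are recorded as None.
  keys = tuple(
      k.key if isinstance(k, jax.tree_util.DictKey) else None for k in path
  )
  if any(keys[: len(prefix)] == prefix for prefix in KEEP):
    return leaf
  return ...  # Ellipsis marks the leaf as a placeholder.

generated_reference = jax.tree_util.tree_map_with_path(keep_or_placeholder, item)
print(generated_reference)
```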
Advanced Customizations
ocp.Context enables further customization.
For customized save/load behavior, invoke these APIs within an ocp.Context instance. The context can in turn be configured with a number of options, such as Saving, Loading, and FileOptions.
The usage pattern is as follows:
with ocp.Context(
    pytree_options=PyTreeOptions(...),
    file_options=FileOptions(...),
):
  ocp.save(path, pytree)
Let’s explore a few examples. See the API Reference for details on specific options.
Saving
Customizing Array dtype
We can customize the on-disk dtype used to save individual arrays. First, let’s save and load as normal.
path = epath.Path('/tmp/checkpointing-pytrees/advanced/')
path.rmtree(missing_ok=True)
ocp.save(path / '1', pytree)
loaded = ocp.load(path / '1')
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/checkpointing-pytrees/advanced/1. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
(loaded['a'].dtype, loaded['b'].dtype)
(dtype('int32'), dtype('int32'))
Now, let’s set the dtype of a single selected array, 'a', when saving.
def scoped_storage_options_creator(keypath, value):
  del value  # Unused: the decision depends only on the keypath.
  last_key = keypath[-1]
  # Override 'a' to int16. For a dict-based pytree like this one, the last key
  # is expected to be a DictKey; GetAttrKey covers attribute-style nodes.
  is_a = (
      isinstance(last_key, jax.tree_util.DictKey) and last_key.key == 'a'
  ) or (
      isinstance(last_key, jax.tree_util.GetAttrKey) and last_key.name == 'a'
  )
  if is_a:
    return ocp.options.ArrayOptions.Saving.StorageOptions(
        dtype=np.dtype(np.int16)
    )
  # Return None to use global default storage_options for other leaves.
  return None
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        saving=ocp.options.ArrayOptions.Saving(
            scoped_storage_options_creator=scoped_storage_options_creator,
        )
    )
):
  ocp.save(path / '2', pytree, overwrite=True)
loaded = ocp.load(path / '2')
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/checkpointing-pytrees/advanced/2. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
# Expected: 'a' is stored as int16 while 'b' keeps its original int32 dtype.
(loaded['a'].dtype, loaded['b'].dtype)
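We assume that the keypath handed to scoped_storage_options_creator is built from JAX's key classes, as produced by jax.tree_util.tree_flatten_with_path. Under that assumption, a dict-based tree such as pytree yields DictKey entries (read via .key), while attribute-style nodes such as namedtuples yield GetAttrKey entries (read via .name). The hypothetical leaf_name helper below handles both, which makes it easy to check which names a creator function will see for a given tree.

```python
from collections import namedtuple
import jax

def leaf_name(keypath):
  """Returns the last dict key or attribute name in a keypath, else None."""
  if not keypath:
    return None
  last = keypath[-1]
  if isinstance(last, jax.tree_util.DictKey):
    return last.key
  if isinstance(last, jax.tree_util.GetAttrKey):
    return last.name
  return None  # e.g. SequenceKey for list or tuple positions.

def leaf_names(tree):
  """Lists leaf_name for every leaf, in flattening order."""
  flat, _ = jax.tree_util.tree_flatten_with_path(tree)
  return [leaf_name(path) for path, _ in flat]

Pair = namedtuple('Pair', ['a', 'b'])
print(leaf_names({'a': 0, 'b': 1, 'c': [2, 3]}))
print(leaf_names(Pair(a=0, b=1)))
```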
Now, let’s set the dtype of all arrays when saving.
scoped_storage_options_creator = (
    lambda k, v: ocp.options.ArrayOptions.Saving.StorageOptions(
        dtype=np.dtype(np.int16)
    )
)
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        saving=ocp.options.ArrayOptions.Saving(
            scoped_storage_options_creator=scoped_storage_options_creator
        )
    )
):
  ocp.save(path / '3', pytree, overwrite=True)
loaded = ocp.load(path / '3')
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/checkpointing-pytrees/advanced/3. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
(loaded['a'].dtype, loaded['b'].dtype)
(dtype('int16'), dtype('int16'))
High Throughput with ocdbt option
For high throughput, and to avoid creating a separate subdirectory for each leaf, enable use_ocdbt. Note that it is enabled by default.
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        saving=ocp.options.ArrayOptions.Saving(
            use_ocdbt=True,
        )
    )
):
  ocp.save(path / '4', pytree, overwrite=True)
A checkpoint created with this option enabled can be identified by the presence of a manifest.ocdbt file and subdirectories like ocdbt.process_*.
!ls -R /tmp/checkpointing-pytrees/advanced/4
However, for use cases like large stacked models, disabling this option may be more efficient.
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        saving=ocp.options.ArrayOptions.Saving(
            use_ocdbt=False,
        )
    )
):
  ocp.save(path / '5', pytree, overwrite=True)
WARNING:absl:[process=0][thread=Thread-136 (_event_loop_runner)] Skipping merge of OCDBT checkpoints: No per-process OCDBT checkpoint subdirs found in /tmp/checkpointing-pytrees/advanced/5.orbax-checkpoint-tmp/state,
!ls -R /tmp/checkpointing-pytrees/advanced/5
Note that each leaf is written to its own subdirectory when use_ocdbt=False.
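The following is a minimal sketch for checking the layout programmatically instead of reading ls output. It relies only on the layout described above, namely that a manifest.ocdbt file appears under the checkpoint directory when OCDBT is enabled. The helper uses_ocdbt is hypothetical, and the exact directory layout may vary between Orbax versions.

```python
from pathlib import Path

def uses_ocdbt(checkpoint_dir) -> bool:
  """Returns True if a manifest.ocdbt file exists anywhere under checkpoint_dir."""
  return any(Path(checkpoint_dir).rglob('manifest.ocdbt'))
```

If the layout matches the description above, uses_ocdbt('/tmp/checkpointing-pytrees/advanced/4') would be expected to return True, and uses_ocdbt('/tmp/checkpointing-pytrees/advanced/5') to return False.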
Loading
Pad / truncate shape
Ordinarily, specifying a target array with a different shape than in the checkpoint results in an error.
# Original shape.
loaded = ocp.load(path / '1')
(loaded['a'].shape, loaded['b'].shape)
((16,), (16,))
different_shape_abstract_state = {
    'a': jax.ShapeDtypeStruct(
        shape=(8,),
        dtype=abstract_state['a'].dtype,
        sharding=abstract_state['a'].sharding,
    ),
    'b': jax.ShapeDtypeStruct(
        shape=(32,),
        dtype=abstract_state['b'].dtype,
        sharding=abstract_state['b'].sharding,
    ),
}
try:
  ocp.load(path / '1', different_shape_abstract_state)
except Exception as e:  # Print the error instead of stopping the notebook.
  print(e)
Requested shape: (32,) is not compatible with the stored shape: (16,). Truncating/padding is disabled. To enable it, set `strict=False` in `ArrayRestoreArgs` for any array in v0 API or `enable_padding_and_truncation=True` in `ArrayOptions.Loading` in v1 API.
We can pad or truncate arrays as they are loaded by specifying enable_padding_and_truncation=True.
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        loading=ocp.options.ArrayOptions.Loading(
            enable_padding_and_truncation=True
        )
    )
):
  loaded = ocp.load(path / '1', different_shape_abstract_state)
(loaded['a'].shape, loaded['b'].shape)
((8,), (32,))
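As a mental model only, the shape adjustment above resembles the following NumPy sketch. It assumes that truncation keeps the leading elements along each axis and that padding appends zeros at the end. The output above shows only the resulting shapes, so Orbax's actual fill values are not confirmed here. pad_or_truncate is a hypothetical helper, not an Orbax API.

```python
import numpy as np

def pad_or_truncate(x: np.ndarray, shape: tuple) -> np.ndarray:
  """Crops or zero-pads x along every axis so that it has the given shape."""
  if len(shape) != x.ndim:
    raise ValueError('Padding/truncation cannot change the rank.')
  # Truncate: keep the leading slice along each axis.
  cropped = x[tuple(slice(0, min(s, d)) for s, d in zip(shape, x.shape))]
  # Pad: append zeros at the end of each axis that is still too short.
  pad_width = [(0, s - c) for s, c in zip(shape, cropped.shape)]
  return np.pad(cropped, pad_width)

saved = np.arange(16, dtype=np.int32)
print(pad_or_truncate(saved, (8,)))
print(pad_or_truncate(saved, (32,)).shape)
```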