Checkpointing with Orbax#

This page gives a brief overview of common tasks that you may wish to accomplish with Orbax. For more in-depth documentation of the APIs, see API Overview.

Saving and Restoring#

The following example shows how you can synchronously save and restore a PyTree. See Checkpointing PyTrees for more detail.

import numpy as np
import orbax.checkpoint as ocp
import jax

Ensure that the top-level directory already exists before saving.

path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')

Create a basic PyTree. This is simply a nested tree-like structure, which can include dicts, lists, or more complicated objects. For the leaves of the tree, Orbax is capable of handling many different types. For our purposes, we will simply use a nested dict of some simple arrays.

my_tree = {
    'a': np.arange(8),
    'b': {
        'c': 42,
        'd': np.arange(16),
    },
}
abstract_my_tree = jax.tree_util.tree_map(
    ocp.utils.to_shape_dtype_struct, my_tree)
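
To make the two objects above concrete, the following sketch uses only JAX and NumPy to list the leaves of the same nested dict and to build a shape-and-dtype description of each leaf by hand. The helper `describe` is hypothetical and is only assumed to approximate what `ocp.utils.to_shape_dtype_struct` returns for array leaves; treat it as an illustration rather than Orbax's implementation.

```python
import jax
import numpy as np

tree = {
    'a': np.arange(8),
    'b': {
        'c': 42,
        'd': np.arange(16),
    },
}

# Flattening separates the leaf values from the tree structure.
leaves, treedef = jax.tree_util.tree_flatten(tree)


def describe(leaf):
    """Hypothetical helper: keep only the shape and dtype of a leaf."""
    arr = np.asarray(leaf)
    return jax.ShapeDtypeStruct(arr.shape, arr.dtype)


# Same structure as `tree`, but every leaf is a shape/dtype placeholder.
abstract_tree = jax.tree_util.tree_map(describe, tree)
print(len(leaves))
print(abstract_tree['b']['d'].shape)
```

An abstract tree of this kind carries no array values, which is why it can describe a checkpoint that has not been loaded yet.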

To save and restore, we create a Checkpointer object. The Checkpointer must be constructed with a CheckpointHandler, which essentially acts as a configuration that provides the Checkpointer with the logic needed to save and restore your object.

For PyTrees, the most common checkpointable object, we can use the convenient shorthand of StandardCheckpointer, which is the same as Checkpointer(StandardCheckpointHandler()) (see docs for more info).

checkpointer = ocp.StandardCheckpointer()
# 'checkpoint_name' must not already exist.
checkpointer.save(path / 'checkpoint_name', my_tree)
checkpointer.restore(
    path / 'checkpoint_name/',
    args=ocp.args.StandardRestore(abstract_my_tree)
)
{'a': array([0, 1, 2, 3, 4, 5, 6, 7]),
 'b': {'c': 42,
  'd': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])}}

Metadata about the checkpoint can be retrieved with the metadata function, making it easy to gather information about an arbitrary checkpoint, or to manually inspect certain properties.

checkpointer.metadata(path / 'checkpoint_name')
{'a': ArrayMetadata(name='a', directory=PosixGPath('/tmp/my-checkpoints/checkpoint_name'), shape=(8,), sharding=None, dtype=dtype('int64')),
 'b': {'c': ScalarMetadata(name='b.c', directory=PosixGPath('/tmp/my-checkpoints/checkpoint_name'), shape=(), sharding=None, dtype=dtype('int64')),
  'd': ArrayMetadata(name='b.d', directory=PosixGPath('/tmp/my-checkpoints/checkpoint_name'), shape=(16,), sharding=None, dtype=dtype('int64'))}}

Multiple Objects#

It is often necessary to deal with multiple distinct checkpointable objects at once, frequently of different types. A Checkpointer combined with a CompositeCheckpointHandler (docs) can be used to represent a single checkpoint consisting of multiple sub-items, each represented by a sub-directory within the checkpoint.

However, when you have a particular object that you’re saving, Orbax needs to know how you want to save it. After all, if you provide a nested dict to save, there’s no way to tell whether it should be saved in a simple JSON representation sufficient for basic metadata, or whether it requires more advanced logic suitable for sharded jax.Arrays. This information can be provided via the orbax.checkpoint.args module.
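
The ambiguity described above can be seen with plain Python. In this sketch, which uses only the standard `json` module and NumPy and does not call Orbax, a small dict of strings and numbers round-trips through JSON, while a dict holding a NumPy array cannot be encoded at all. The variable names are illustrative.

```python
import json

import numpy as np

small_metadata = {'version': 1.0, 'lang': 'en'}
array_state = {'a': np.arange(8)}

# Strings and numbers survive a JSON round trip unchanged.
roundtrip = json.loads(json.dumps(small_metadata))

# NumPy arrays are not JSON serializable, so a different strategy is needed.
try:
    json.dumps(array_state)
    arrays_are_json_serializable = True
except TypeError:
    arrays_are_json_serializable = False

print(roundtrip == small_metadata)
print(arrays_are_json_serializable)
```

Both inputs are nested dicts, so the container type alone does not determine a suitable storage format; the caller has to state the intent.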

In the example below, we can imagine that state is a PyTree consisting of large sharded arrays. In contrast, metadata contains only a few strings and numbers, and can easily be saved using JSON.

metadata = {
    'version': 1.0,
    'lang': 'en',
}
checkpointer = ocp.Checkpointer(
    ocp.CompositeCheckpointHandler('state', 'metadata')
)
checkpointer.save(
    path / 'composite_checkpoint',
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(my_tree),
        metadata=ocp.args.JsonSave(metadata),
    ),
)
restored = checkpointer.restore(path / 'composite_checkpoint')
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
restored.state
{'a': array([0, 1, 2, 3, 4, 5, 6, 7]),
 'b': {'c': 42,
  'd': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])}}
restored.metadata
{'version': 1.0, 'lang': 'en'}

Inspecting the checkpoint directory, we can see that it has sub-directories for state and metadata.

list((path / 'composite_checkpoint').iterdir())
[PosixGPath('/tmp/my-checkpoints/composite_checkpoint/state'),
 PosixGPath('/tmp/my-checkpoints/composite_checkpoint/_CHECKPOINT_METADATA'),
 PosixGPath('/tmp/my-checkpoints/composite_checkpoint/metadata')]

Managing Checkpoints#

In the context of training a model, it is helpful to work with a series of steps. The CheckpointManager lets you save steps sequentially at a given interval, clean up old checkpoints once a certain number are stored, and more.

Beware: CheckpointManager.save(...) happens in a background thread by default. See Asynchronous Checkpointing for more details.
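
The consequence of a background save is that a call to save can return before the data is on disk. The sketch below illustrates the pattern with a hypothetical `BackgroundSaver` class built on the standard library; it is not Orbax's implementation, and only mimics the idea behind `wait_until_finished`.

```python
import concurrent.futures
import time


class BackgroundSaver:
    """Hypothetical stand-in for an asynchronous checkpoint writer."""

    def __init__(self):
        # A single worker keeps writes in submission order.
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self._pending = []
        self.written = []

    def _write(self, step):
        time.sleep(0.01)  # Pretend that writing takes a while.
        self.written.append(step)

    def save(self, step):
        # Returns immediately; the write happens on the worker thread.
        self._pending.append(self._executor.submit(self._write, step))

    def wait_until_finished(self):
        # Block until every submitted write has completed.
        for future in self._pending:
            future.result()
        self._pending.clear()

    def close(self):
        self._executor.shutdown(wait=True)


saver = BackgroundSaver()
for step in range(3):
    saver.save(step)
saver.wait_until_finished()  # Only now is it safe to rely on the writes.
written_steps = list(saver.written)
saver.close()
print(written_steps)
```

Reading `saver.written` before the wait could observe a partial list, which is the situation the warning above is about.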

path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint_manager')
state = {
    'a': np.arange(8),
    'b': np.arange(16),
}
extra_params = [42, 43]
# Keeps a maximum of 3 checkpoints, and only saves every other step.
options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=2)
mngr = ocp.CheckpointManager(
    path, options=options, item_names=('state', 'extra_params')
)

for step in range(11):  # [0, 1, ..., 10]
  mngr.save(
      step,
      args=ocp.args.Composite(
          state=ocp.args.StandardSave(state),
          extra_params=ocp.args.JsonSave(extra_params),
      ),
  )
mngr.wait_until_finished()
restored = mngr.restore(10)
restored_state, restored_extra_params = restored.state, restored.extra_params
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
mngr.all_steps()
[6, 8, 10]
mngr.latest_step()
10
mngr.should_save(11)
False
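
The retained steps printed above can be reproduced with a simplified model of the two options. The function `simulate_retention` below is hypothetical and assumes that a step is saved when it is a multiple of `save_interval_steps` and that the oldest checkpoint is dropped once more than `max_to_keep` exist; Orbax's real policy handles more cases than this sketch.

```python
def simulate_retention(steps, save_interval_steps, max_to_keep):
    """Hypothetical model of interval-based saving with a retention cap."""
    kept = []
    for step in steps:
        if step % save_interval_steps != 0:
            continue  # Not on the save interval; skip this step.
        kept.append(step)
        if len(kept) > max_to_keep:
            kept.pop(0)  # Drop the oldest checkpoint.
    return kept


kept = simulate_retention(range(11), save_interval_steps=2, max_to_keep=3)
print(kept)  # [6, 8, 10]
```

Under the same assumption, step 11 is not a multiple of 2, which is consistent with `should_save(11)` returning False.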

A Standard Recipe#

In most cases, users will wish to save and restore a PyTree representing a model state over the course of many training steps. Many users will also wish to do this in a multi-host, multi-device environment.

First, we will create a PyTree state with sharded jax.Array as leaves.

import jax

path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint_manager_sharded')

sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(
        'model',
    ),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
train_state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
train_state = jax.tree_util.tree_map(create_sharded_array, train_state)
jax.tree_util.tree_map(lambda x: x.sharding, train_state)
{'a': NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec('model',)),
 'b': NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec('model',))}
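
To see what the `PartitionSpec('model')` sharding does to the data itself, the sketch below places a small array on the same kind of one-dimensional mesh and inspects the per-device shards. It uses only JAX and NumPy; the array length is chosen as a multiple of the device count so that the example also works with more than one device, which is an assumption of this sketch rather than something the text above requires.

```python
import jax
import numpy as np

devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), ('model',))
model_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec('model')
)

# Four elements per device, split along the only array axis.
x = jax.device_put(np.arange(4 * len(devices)), model_sharding)

# Each addressable shard holds one contiguous slice of the array.
shard_shapes = [shard.data.shape for shard in x.addressable_shards]
print(x.shape, shard_shapes)
```

With a single device the mesh has size 1, so the one shard is the whole array, matching the `Mesh('model': 1)` shown in the output above.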
num_steps = 10
options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=2)
mngr = ocp.CheckpointManager(path, options=options)


@jax.jit
def train_fn(state):
  return jax.tree_util.tree_map(lambda x: x + 1, state)


for step in range(num_steps):
  train_state = train_fn(train_state)
  mngr.save(step, args=ocp.args.StandardSave(train_state))
mngr.wait_until_finished()
mngr.restore(mngr.latest_step())
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
{'a': Array([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],      dtype=int32),
 'b': Array([10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.,
        10., 10., 10.], dtype=float32)}
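
The restored values follow from the loop above: `train_fn` runs before each save, and with `save_interval_steps=2` the last step in `range(10)` that gets saved is step 8, by which point the state has been incremented nine times. A small NumPy check of that arithmetic, assuming that steps which are multiples of the interval are the ones saved:

```python
import numpy as np

num_steps = 10
save_interval_steps = 2

# Steps that fall on the save interval (an assumption of this sketch).
saved_steps = [s for s in range(num_steps) if s % save_interval_steps == 0]
latest = saved_steps[-1]

# The state is incremented once per step, before the save at that step.
increments = latest + 1
expected_a = np.arange(16) + increments
expected_b = np.ones(16) + increments

print(latest)
print(expected_a[0], expected_a[-1])
print(expected_b[0])
```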

Let’s imagine now that we are starting a new training run and would like to restore the previously saved checkpoint. In this case, we know only the tree structure of the checkpoint, not the actual array values. We would also like to load the arrays with a sharding different from the one they were originally saved with.

train_state = jax.tree_util.tree_map(np.zeros_like, train_state)
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(
        None,
    ),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
train_state = jax.tree_util.tree_map(create_sharded_array, train_state)
abstract_train_state = jax.tree_util.tree_map(
    ocp.utils.to_shape_dtype_struct, train_state
)
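
Allocating zero-filled arrays is one way to obtain the abstract tree. As an alternative sketch, and assuming a JAX version in which `jax.ShapeDtypeStruct` accepts a `sharding` keyword, the same kind of description can be written down directly from known shapes and dtypes without materializing any arrays. The name `abstract_state_direct` and the hard-coded shapes are illustrative assumptions.

```python
import jax
import numpy as np

mesh = jax.sharding.Mesh(np.array(jax.devices()), ('model',))
replicated = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec(None)
)

# Shapes and dtypes assumed to be known ahead of time for this sketch.
known_leaves = {
    'a': ((16,), np.int32),
    'b': ((16,), np.float32),
}

# Each leaf records shape, dtype, and the sharding wanted after restore.
abstract_state_direct = {
    name: jax.ShapeDtypeStruct(shape, dtype, sharding=replicated)
    for name, (shape, dtype) in known_leaves.items()
}
print(abstract_state_direct['a'].shape, abstract_state_direct['b'].dtype)
```

A tree built this way should be usable wherever a tree produced by `ocp.utils.to_shape_dtype_struct` is expected, although that is worth verifying against the Orbax version in use.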

Construct arguments needed for restoration.

restored = mngr.restore(
    mngr.latest_step(),
    args=ocp.args.StandardRestore(abstract_train_state),
)
restored
{'a': Array([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],      dtype=int32),
 'b': Array([10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.,
        10., 10., 10.], dtype=float32)}
jax.tree_util.tree_map(lambda x: x.sharding, restored)
{'a': NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec(None,)),
 'b': NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec(None,))}