Checkpointing PyTrees of Arrays

A PyTree is the most common way of representing a training state in JAX. While Orbax is designed to be as generic as possible, and provides customization options for all manner of checkpointable objects, PyTrees naturally have pride of place. Furthermore, the standard object used to represent large, sharded arrays is the jax.Array. This, too, has extensive first-class support.

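CheckpointHandler Support

Orbax provides two main options for working with PyTrees: StandardCheckpointHandler, which covers the common case, and PyTreeCheckpointHandler, which exposes additional, more advanced options. The examples below use the following sharded state.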

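import numpy as np
import orbax.checkpoint as ocp
import jax

# A 1-D device mesh with a single 'model' axis spanning all devices; each
# array is split along its first dimension across that axis.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec('model'),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
# The training state: a PyTree (here a dict) of sharded jax.Arrays.
state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
state = jax.tree_util.tree_map(create_sharded_array, state)
# Metadata only (shape, dtype, sharding) for each leaf, used as a restore target.
abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)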
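For intuition about what `abstract_state` holds, here is a sketch that builds an equivalent tree with plain JAX. It assumes that `ocp.utils.to_shape_dtype_struct` simply records each leaf's shape, dtype, and sharding; the helper name `to_abstract` is hypothetical.

```python
import jax
import numpy as np

# Same setup as above: a 1-D mesh over all devices with a 'model' axis.
mesh = jax.sharding.Mesh(jax.devices(), ('model',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('model'))
state = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, sharding),
    {'a': np.arange(16), 'b': np.ones(16)},
)

def to_abstract(x: jax.Array) -> jax.ShapeDtypeStruct:
  # Hypothetical helper: keep only the metadata of a leaf, not its data.
  return jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding)

abstract_state = jax.tree_util.tree_map(to_abstract, state)
```

Each leaf of the result is a `jax.ShapeDtypeStruct`, which is enough for the restore calls below to know what to produce.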

Basic Checkpointing

Let’s use StandardCheckpointHandler to work with PyTrees of jax.Array.

path = ocp.test_utils.erase_and_create_empty('/tmp/basic/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(path / '1', args=ocp.args.StandardSave(state))
ckptr.wait_until_finished()

We specify the abstract_state in order to restore with the given dtypes, shapes, and shardings for each leaf.

restored = ckptr.restore(path / '1', args=ocp.args.StandardRestore(abstract_state))
restored
{'a': Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],      dtype=int32),
 'b': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32)}
restored['a'].sharding
NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec('model',))

You can do the exact same with a “concrete” target rather than an “abstract” target. However, this requires that you fully initialize the target train state before restoring from the checkpoint, which is inefficient. It is better practice to only initialize metadata (either by manually creating jax.ShapeDtypeStructs or using jax.eval_shape).
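A minimal sketch of the `jax.eval_shape` route, assuming a hypothetical `init_state` function that stands in for your real initialization code. `jax.eval_shape` traces the function without allocating any arrays and returns a PyTree of `jax.ShapeDtypeStruct`; the sharding to restore with is then attached explicitly.

```python
import jax
import jax.numpy as jnp

def init_state():
  # Hypothetical initializer standing in for a real model/optimizer init.
  return {'a': jnp.arange(16), 'b': jnp.ones(16)}

# Traced abstractly: no device memory is allocated for the arrays.
abstract_state = jax.eval_shape(init_state)

# Attach the sharding each leaf should be restored with.
mesh = jax.sharding.Mesh(jax.devices(), ('model',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('model'))
abstract_state = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding),
    abstract_state,
)
```

The resulting tree can be passed as `ocp.args.StandardRestore(abstract_state)` in the same way as the tree derived from concrete arrays.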

ckptr.restore(path / '1', args=ocp.args.StandardRestore(state))
{'a': Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],      dtype=int32),
 'b': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32)}

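Customizing Restored Properties

Array dtype

Each leaf is restored with the dtype given in the abstract target, so specifying a different dtype casts the array on restore.

def set_restore_dtype(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  # Return a new struct instead of mutating `x` in place; mutating it would
  # also change the leaves of `abstract_state`, which later examples reuse.
  return jax.ShapeDtypeStruct(x.shape, np.int16, sharding=x.sharding)

cast_dtype_abstract_state = jax.tree_util.tree_map(
    set_restore_dtype, abstract_state)
ckptr.restore(
    path / '1',
    args=ocp.args.StandardRestore(cast_dtype_abstract_state),
)
{'a': Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],      dtype=int16),
 'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int16)}

Pad / truncate shape

Restoring to a smaller shape truncates the saved array, and restoring to a larger shape pads it with zeros.

different_shape_abstract_state = {
    'a': jax.ShapeDtypeStruct(
        shape=(8,),
        # Reuse the int16 dtype from the previous example.
        dtype=cast_dtype_abstract_state['a'].dtype,
        sharding=abstract_state['a'].sharding
    ),
    'b': jax.ShapeDtypeStruct(
        shape=(32,),
        dtype=cast_dtype_abstract_state['b'].dtype,
        sharding=abstract_state['b'].sharding
    ),
}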
ckptr.restore(
    path / '1',
    args=ocp.args.StandardRestore(different_shape_abstract_state),
)
{'a': Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int16),
 'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int16)}
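The output shows truncation for 'a' and zero-padding for 'b'. The following is a NumPy model of that observed behavior, not Orbax's implementation; `pad_or_truncate` is hypothetical, and extending the rule axis-by-axis to multi-dimensional arrays is an assumption, since the example above only exercises 1-D arrays.

```python
import numpy as np

def pad_or_truncate(x: np.ndarray, shape: tuple) -> np.ndarray:
  """Hypothetical model: cut each axis down to `shape`, zero-fill the rest."""
  out = np.zeros(shape, dtype=x.dtype)
  # The region shared by the saved shape and the requested shape, per axis.
  overlap = tuple(slice(0, min(a, b)) for a, b in zip(x.shape, shape))
  out[overlap] = x[overlap]
  return out

a = pad_or_truncate(np.arange(16, dtype=np.int16), (8,))
b = pad_or_truncate(np.ones(16, dtype=np.int16), (32,))
```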

Change sharding

NOTE: This can often be a particularly sharp edge.

Sharding commonly needs to be changed when a checkpoint saved on one topology is loaded onto a different topology.

If changing topologies, you MUST specify sharding when restoring.

Unless you are loading on the exact same topology, Orbax does not make any decisions about sharding on your behalf. If you do have the exact same topology, however, you can omit the sharding when loading, and the sharding saved with the checkpoint is used. This is demonstrated below:

restored = ckptr.restore(path / '1')
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
restored['a'].sharding
NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec('model',))

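In the example below, we change the sharding while loading: each array is restored fully replicated (an empty PartitionSpec) on a mesh whose single axis is named 'x'.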

sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)),
    jax.sharding.PartitionSpec(),
)
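def set_sharding(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  # Build a new struct instead of mutating `x` in place, which would also
  # modify the leaves of `abstract_state`.
  return jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding)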

change_sharding_abstract_state = jax.tree_util.tree_map(
    set_sharding, abstract_state)
restored = ckptr.restore(
    path / '1',
    args=ocp.args.StandardRestore(change_sharding_abstract_state),
)
restored['a'].sharding
NamedSharding(mesh=Mesh('x': 1), spec=PartitionSpec())
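To see what the two shardings used in this section mean in practice, you can place the same data with each one and inspect the per-device shards. This sketch uses only JAX; on a single device both layouts hold the full array, and it assumes the device count divides 16.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh_model = Mesh(jax.devices(), ('model',))
mesh_x = Mesh(jax.devices(), ('x',))

# Split along the 'model' axis: each device holds a slice of the array.
split = jax.device_put(
    np.arange(16), NamedSharding(mesh_model, PartitionSpec('model')))
# Empty PartitionSpec: every device holds a full copy of the array.
replicated = jax.device_put(
    np.arange(16), NamedSharding(mesh_x, PartitionSpec()))

split_shapes = {shard.data.shape for shard in split.addressable_shards}
replicated_shapes = {shard.data.shape for shard in replicated.addressable_shards}
```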

Advanced Options

StandardCheckpointHandler does not provide some advanced options. These can be specified by using PyTreeCheckpointHandler instead.

Saving

For example, PyTreeCheckpointHandler can be used to customize the on-disk type used to save individual arrays. First, let’s save and restore as normal.

path = ocp.test_utils.erase_and_create_empty('/tmp/advanced/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
ckptr.save(path / '1', args=ocp.args.PyTreeSave(state))
ckptr.wait_until_finished()
restored = ckptr.restore(path / '1')
UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
restored['a'].dtype
dtype('int32')
restored['b'].dtype
dtype('float32')

Now, let’s set the dtype of the array when saving.

ckptr.save(
    path / '2',
    args=ocp.args.PyTreeSave(
        state,
        save_args={
          # We must set one ocp.SaveArgs per leaf.
          'a': ocp.SaveArgs(dtype=np.int16),
          'b': ocp.SaveArgs()
        }
    ),
)
ckptr.wait_until_finished()
restored = ckptr.restore(path / '2')
restored['a'].dtype
dtype('int16')
restored['b'].dtype
dtype('float32')
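Because one ocp.SaveArgs is needed per leaf, writing the dict by hand becomes tedious for large states. One way to build such a per-leaf tree is `jax.tree_util.tree_map_with_path`. In this sketch, `choose_save_dtype` is a hypothetical rule, and plain dtypes stand in for `ocp.SaveArgs` objects so that the example runs with JAX and NumPy alone.

```python
import jax
import numpy as np

state = {'a': np.arange(16), 'b': np.ones(16)}

def choose_save_dtype(path, leaf):
  # `path` is the tuple of keys leading to `leaf`; for a flat dict it holds a
  # single jax.tree_util.DictKey whose `.key` is the dict key.
  # Stand-in rule: store 'a' as int16 on disk and leave other leaves as-is.
  # With Orbax, return ocp.SaveArgs(dtype=np.int16) / ocp.SaveArgs() instead.
  return np.int16 if path[-1].key == 'a' else None

# One entry per leaf, in a tree with the same structure as `state`.
save_dtypes = jax.tree_util.tree_map_with_path(choose_save_dtype, state)
```

The same pattern would apply to the `restore_args` tree passed to ocp.args.PyTreeRestore, with the rule returning ocp.RestoreArgs or ocp.ArrayRestoreArgs objects.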

Restoring

Similar options are available when restoring: we can customize the shape, dtype, and sharding of each leaf, as well as the type it is restored as.

ckptr.restore(
    path / '2',
    args=ocp.args.PyTreeRestore(
        restore_args={
          # RestoreArgs is the parent class for ArrayRestoreArgs.
          # We must set one RestoreArgs per leaf.
          'a': ocp.RestoreArgs(restore_type=np.ndarray),
          'b': ocp.ArrayRestoreArgs(dtype=np.int16, sharding=sharding)
        }
    ),
)
{'a': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
       dtype=int16),
 'b': Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int16)}

Note that “a” was restored as np.ndarray rather than jax.Array.

PyTreeCheckpointHandler also provides options to perform transformations when restoring. These are useful when your target tree has a different structure from your checkpoint tree, for example to skip loading some keys or to rename others. Full details are available on the Transformations page.

ckptr.restore(
    path / '2',
    args=ocp.args.PyTreeRestore(
        # `item` serves as a guide to what the result tree structure should look
        # like.
        item={
            # Value doesn't really matter here, as long as it's not None.
            'c': ...,
            # Can add in extra keys.
            'd': np.arange(8)
        },
        # `restore_args` must be relative to the result tree, not the
        # checkpoint.
        restore_args={
          'c': ocp.RestoreArgs(restore_type=np.ndarray),
        },
        transforms={
            'c': ocp.Transform(original_key='a')
        },
    ),
)
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'c': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
       dtype=int16),
 'd': array([0, 1, 2, 3, 4, 5, 6, 7])}
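As a rough mental model of the example above, here is a simplified, hypothetical `rename_restore` function in plain Python and NumPy. It is not Orbax's implementation. It covers only the cases the example exercises: 'c' is read from the checkpoint under its original key 'a', 'd' keeps the value supplied in `item`, and 'b' is dropped because `item` does not mention it. Orbax's rules for other cases, such as keys present in both trees without a transform, may differ; see the Transformations page.

```python
import numpy as np

def rename_restore(checkpoint: dict, item: dict, renames: dict) -> dict:
  """Hypothetical, simplified model of the rename shown above.

  Keys listed in `renames` are read from `checkpoint` under their original
  name; every other key in `item` keeps the value given in `item`. Keys that
  appear only in `checkpoint` are not part of the result.
  """
  return {
      key: checkpoint[renames[key]] if key in renames else value
      for key, value in item.items()
  }

checkpoint = {'a': np.arange(16, dtype=np.int16), 'b': np.ones(16)}
result = rename_restore(checkpoint, {'c': ..., 'd': np.arange(8)}, {'c': 'a'})
```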