Checkpoint Saving#

Defines exported symbols for orbax.checkpoint.experimental.v1.

The preferred import style is:

import orbax.checkpoint.experimental.v1 as ocp

Saving functions#

orbax.checkpoint.experimental.v1.save(path, state, *, checkpointable_name='state', overwrite=False, custom_metadata=None)#

Saves a PyTree.

The operation blocks until complete. For improved performance, consider using save_async() instead. This function should be called on all available controller processes.

Example usage:

Simple save of a dictionary containing JAX arrays:

state = {
    'params': {
        'w': jnp.ones((8, 8)),
        'b': jnp.zeros(8),
    },
    'step': 100
}
# Saves to /tmp/my_checkpoint/
ocp.save('/tmp/my_checkpoint', state)

Parameters:
  • path (UnionType[Path, str]) – The path to save the checkpoint to.

  • state (PyTreeOf[UnionType[Array, ndarray, int, float, number, bytes, bool, str]]) – The PyTree to save. This may be any JAX PyTree (including custom objects registered as PyTrees) consisting of supported leaf types. See orbax.checkpoint.experimental.v1.tree for a table of standard supported leaf types.

  • checkpointable_name (str) – The name of the checkpointable to save the PyTree under. Defaults to 'state'.

  • overwrite (bool) – If True, fully overwrites an existing checkpoint in path. Otherwise, raises an error if the checkpoint already exists.

  • custom_metadata (UnionType[list[JsonValue], dict[str, JsonValue], None]) – User-provided custom metadata. An arbitrary JSON-serializable dictionary or list that the user can use to store additional information. The field is treated as opaque by Orbax.
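Because custom_metadata must be JSON-serializable, it can be useful to validate it before starting a save. The following is a minimal sketch that uses only the standard json module. The helper name check_custom_metadata is hypothetical and is not part of Orbax.

```python
import json


def check_custom_metadata(metadata):
    """Hypothetical helper (not part of Orbax).

    Returns `metadata` unchanged if it is None, or a dict or list that can be
    serialized to JSON. Raises TypeError otherwise.
    """
    if metadata is not None and not isinstance(metadata, (dict, list)):
        raise TypeError('custom_metadata must be a dict, a list, or None.')
    # json.dumps raises TypeError for values such as sets or arbitrary objects.
    json.dumps(metadata)
    return metadata


metadata = check_custom_metadata(
    {'experiment': 'baseline', 'learning_rate': 1e-3, 'tags': ['a', 'b']}
)
```

The validated value could then be passed as custom_metadata=metadata to any of the saving functions.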

orbax.checkpoint.experimental.v1.save_checkpointables(path, checkpointables, *, overwrite=False, custom_metadata=None)#

Saves a dictionary of checkpointables.

A “checkpointable” refers to a logical piece of the checkpoint that is distinct in some way from other pieces. Checkpointables are separable; they may or may not be loaded concurrently and some may be omitted from the checkpoint entirely. Checkpointables are often represented by different types, and have different representations on disk. The quintessential example is model params vs. dataset.

For example, one might do:

ocp.save_checkpointables(
    path,
    {
        'params': pytree_of_arrays,
        'dataset': pygrain.DatasetIterator(...),
    }
)

It is also possible to do:

train_state = {
    'params': params_pytree_of_arrays,
    'opt_state': opt_state_pytree_of_arrays,
    'step': step,
    ...
}
ocp.save_checkpointables(path, train_state)

This approach is not ideal, because it makes it difficult to run transformations that involve the entire train state (see the load_and_transform API).

This function should be called on all available controller processes.

Parameters:
  • path (UnionType[Path, str]) – The path to save the checkpoint to.

  • checkpointables (dict[str, Checkpointable]) – A dictionary of checkpointables. Dictionary keys represent the names of the checkpointables, while the values are the checkpointable objects themselves.

  • overwrite (bool) – If True, fully overwrites an existing checkpoint in path. Otherwise, raises an error if the checkpoint already exists.

  • custom_metadata (UnionType[list[JsonValue], dict[str, JsonValue], None]) – User-provided custom metadata. An arbitrary JSON-serializable dictionary or list that the user can use to store additional information. The field is treated as opaque by Orbax.

Return type:

None
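To make the idea of separable checkpointables concrete, here is a toy sketch using only the standard library. It is not Orbax's on-disk format or implementation. The names save_parts and load_parts are hypothetical, and JSON stands in for whatever representation each checkpointable would really use. It only illustrates that named pieces stored separately can be loaded independently, and what "fully overwrites" means for an existing path.

```python
import json
import pathlib
import shutil
import tempfile


def save_parts(path, parts, *, overwrite=False):
    """Toy stand-in for saving named checkpointables (not Orbax's format)."""
    root = pathlib.Path(path)
    if root.exists():
        if not overwrite:
            raise FileExistsError(f'{root} already exists.')
        shutil.rmtree(root)  # A full overwrite drops every previously saved part.
    for name, value in parts.items():
        part_dir = root / name  # One subdirectory per named part.
        part_dir.mkdir(parents=True)
        (part_dir / 'data.json').write_text(json.dumps(value))


def load_parts(path, names):
    """Loads only the requested parts, leaving the others untouched."""
    root = pathlib.Path(path)
    return {n: json.loads((root / n / 'data.json').read_text()) for n in names}


with tempfile.TemporaryDirectory() as tmp:
    ckpt = pathlib.Path(tmp) / 'ckpt'
    save_parts(ckpt, {'params': {'w': [1.0, 2.0]}, 'dataset': {'next_index': 30}})
    # Only the dataset state is read back; the params are never touched.
    only_dataset = load_parts(ckpt, ['dataset'])
```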

orbax.checkpoint.experimental.v1.save_async(path, state, *, checkpointable_name='state', overwrite=False, custom_metadata=None)#

Saves a PyTree asynchronously.

Unlike save(), this function returns immediately after the save operation is scheduled (except for certain operations, like device-to-host copying of on-device arrays, which must happen on the main thread). Further writing operations continue in a background thread. An AsyncResponse is returned that can be used to block until the save is complete (using response.result()). Make sure to wait for completion before attempting to load the checkpoint or exiting the program. This function should be called on all available controller processes.

Example usage:

Simple save of a dictionary containing JAX arrays asynchronously:

state = {
    'params': {
        'w': jnp.ones((8, 8)),
        'b': jnp.zeros(8),
    },
    'step': 100
}
# Saves to /tmp/my_checkpoint/
future = ocp.save_async(
    '/tmp/my_checkpoint', state
)

# Perform other work here...

# Wait for completion only when necessary
future.result()

Parameters:
  • path (UnionType[Path, str]) – The path to save the checkpoint to.

  • state (PyTreeOf[UnionType[Array, ndarray, int, float, number, bytes, bool, str]]) – The PyTree to save. This may be any JAX PyTree (including custom objects registered as PyTrees) consisting of supported leaf types. See orbax.checkpoint.experimental.v1.tree for a table of standard supported leaf types.

  • checkpointable_name (str) – The name of the checkpointable to save the PyTree under. Defaults to 'state'.

  • overwrite (bool) – If True, fully overwrites an existing checkpoint in path. Otherwise, raises an error if the checkpoint already exists.

  • custom_metadata (UnionType[list[JsonValue], dict[str, JsonValue], None]) – User-provided custom metadata. An arbitrary JSON-serializable dictionary or list that the user can use to store additional information. The field is treated as opaque by Orbax.

Return type:

AsyncResponse[None]

Returns:

An AsyncResponse that can be used to block until the save is complete. Blocking can be done using response.result(), which returns None.
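The two phases described above (a blocking step on the calling thread, then background writing) can be illustrated with a standard-library analogue. This is a sketch of the general pattern only, not Orbax's implementation. toy_save_async is a hypothetical name, and a concurrent.futures.Future stands in for AsyncResponse.

```python
import concurrent.futures
import json
import pathlib
import tempfile
import time

_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)


def toy_save_async(path, state):
    """Toy analogue of an asynchronous save (not Orbax's implementation)."""
    # Blocking phase: snapshot the state on the calling thread, loosely
    # analogous to the device-to-host copy, so later mutations of `state`
    # cannot affect what is written.
    snapshot = json.dumps(state)

    def _write():
        time.sleep(0.05)  # Simulate slow storage.
        pathlib.Path(path).write_text(snapshot)

    # Background phase: the returned Future's result() blocks until done.
    return _executor.submit(_write)


with tempfile.TemporaryDirectory() as tmp:
    target = pathlib.Path(tmp) / 'state.json'
    state = {'step': 100}
    response = toy_save_async(target, state)
    state['step'] = 101  # Safe: the snapshot was taken before returning.
    response.result()  # Wait for completion before reading the file back.
    saved = json.loads(target.read_text())
```

Here saved holds the value captured at call time, which is why waiting on the response before loading matters: reading the file before result() returns could observe a missing or partial write.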

orbax.checkpoint.experimental.v1.save_checkpointables_async(path, checkpointables, *, overwrite=False, custom_metadata=None)#

Saves a dictionary of checkpointables asynchronously.

See save_checkpointables() documentation.

Unlike save_checkpointables(), this function returns immediately after the save operation is scheduled (except for certain operations, like device-to-host copying of on-device arrays, which must happen on the main thread). Further writing operations continue in a background thread. An AsyncResponse is returned that can be used to block until the save is complete (using response.result()). Make sure to wait for completion before attempting to load the checkpoint or exiting the program. This function should be called on all available controller processes.

Example usage:

Saving multiple distinct components (e.g. model parameters and dataset iterator) asynchronously:

path = '/tmp/my_checkpoint_step_100'

# Setup components
params = {'w': jnp.ones((8, 8)), 'b': jnp.zeros(8)}

# Setup Grain iterator (Stateful Checkpointable)
import grain
dataset_iter = iter(
    grain.MapDataset.range(30)
    .batch(3)
    .map(lambda x: x.tolist())
)

# Save multiple components
checkpointables = {
    'model': params,
    'dataset': dataset_iter,
}

# Start the async save
response = ocp.save_checkpointables_async(path, checkpointables)

# Perform other operations here...

# Wait for the save to finish
response.result()

Parameters:
  • path (UnionType[Path, str]) – The path to save the checkpoint to.

  • checkpointables (dict[str, Checkpointable]) – A dictionary of checkpointables. Dictionary keys represent the names of the checkpointables, while the values are the checkpointable objects themselves.

  • overwrite (bool) – If True, fully overwrites an existing checkpoint in path. Otherwise, raises an error if the checkpoint already exists.

  • custom_metadata (UnionType[list[JsonValue], dict[str, JsonValue], None]) – User-provided custom metadata. An arbitrary JSON-serializable dictionary or list that the user can use to store additional information. The field is treated as opaque by Orbax.

Return type:

AsyncResponse[None]

Returns:

An AsyncResponse that can be used to block until the save is complete. Blocking can be done using response.result(), which returns None.
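Since a save must be allowed to complete before the program exits, a training loop that saves periodically needs to keep track of the outstanding response. The sketch below shows one possible pattern, keeping at most one save in flight. That policy is a design choice made for this example, not a documented Orbax requirement. run_with_periodic_saves and start_async_save are hypothetical names, and a thread pool stands in for the real save so that the sketch runs without Orbax installed.

```python
import concurrent.futures


def run_with_periodic_saves(num_steps, save_every, train_step, start_async_save):
    """Runs a loop that keeps at most one asynchronous save in flight.

    `start_async_save(step)` is a hypothetical callback assumed to return an
    object with a blocking `result()` method, such as the response returned
    by save_checkpointables_async().
    """
    pending = None
    for step in range(1, num_steps + 1):
        train_step(step)
        if step % save_every == 0:
            if pending is not None:
                pending.result()  # Finish the previous save first.
            pending = start_async_save(step)
    if pending is not None:
        pending.result()  # Never return with a save still in flight.


# Stand-in for a real asynchronous save: records which steps were "saved".
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
completed_steps = []
run_with_periodic_saves(
    num_steps=10,
    save_every=5,
    train_step=lambda step: None,
    start_async_save=lambda step: executor.submit(completed_steps.append, step),
)
executor.shutdown()
```

In real use, the callback would wrap a call to save_checkpointables_async() with a path chosen per step. How those paths are named is left to the caller.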