Introduction to Checkpointing with Orbax
The Orbax library provides multiple loosely related packages geared towards JAX model persistence; checkpointing is a core Orbax component. You can install the checkpointing package with:
pip install orbax-checkpoint
Be sure to check out our PyPI page and GitHub page for more information.
This tutorial (and others in the Orbax documentation) generally assumes a basic level of familiarity with the JAX library.
Now, let’s get started with some usage examples. First, we need to set up a simple PyTree containing JAX arrays. This represents our JAX model.
### Setup ###
import itertools
from etils import epath
import jax
import numpy as np
directory = epath.Path('/tmp/101/my-checkpoints')
pytree = {
    'a': np.arange(64).reshape((8, 8)),
    'b': np.arange(16),
    'c': np.asarray(4.5),
}

mesh = jax.sharding.Mesh(jax.devices(), ('x',))
shardings = {
    'a': jax.sharding.NamedSharding(
        mesh, jax.sharding.PartitionSpec('x', None)
    ),
    'b': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
    'c': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
}
pytree = jax.tree.map(
    lambda arr, sharding: jax.make_array_from_callback(
        arr.shape,
        sharding,
        lambda idx: arr[idx],
    ),
    pytree,
    shardings,
)

_checkpoint_name = itertools.count()


def next_checkpoint_name() -> str:
  return f'ckpt{next(_checkpoint_name)}'
Reading and Writing
First, import the checkpointing package. For the v1 API, use exactly the import statement below; an incorrect import can lead to errors or unexpected behavior.
from orbax.checkpoint import v1 as ocp
Using the tree of jax.Array created above, let’s save a checkpoint.
checkpoint_name = next_checkpoint_name()
ocp.save(directory / checkpoint_name, pytree)
Loading yields the original PyTree of arrays.
ocp.load(directory / checkpoint_name)
{'a': Array([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15],
[16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31],
[32, 33, 34, 35, 36, 37, 38, 39],
[40, 41, 42, 43, 44, 45, 46, 47],
[48, 49, 50, 51, 52, 53, 54, 55],
[56, 57, 58, 59, 60, 61, 62, 63]], dtype=int32),
'b': Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int32),
'c': Array(4.5, dtype=float32)}
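To confirm that a save/load round trip preserved the tree, it can help to compare the original and restored PyTrees leaf by leaf. The following is a minimal sketch that uses only jax.tree and NumPy; trees_equal is a hypothetical helper written for illustration, not part of Orbax. Comparing dtypes as well as values catches silent casts that a value-only comparison would miss.

```python
import jax
import numpy as np


def trees_equal(expected, actual) -> bool:
  """Returns True if both PyTrees have the same structure and identical leaves.

  Leaves are converted to NumPy and compared by shape, dtype and value.
  """
  if jax.tree.structure(expected) != jax.tree.structure(actual):
    return False
  for x, y in zip(jax.tree.leaves(expected), jax.tree.leaves(actual)):
    x, y = np.asarray(x), np.asarray(y)
    if x.shape != y.shape or x.dtype != y.dtype or not np.array_equal(x, y):
      return False
  return True
```

For the checkpoint saved above, trees_equal(pytree, ocp.load(directory / checkpoint_name)) would be expected to return True.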
We can inspect the tree structure and array properties using metadata.
ocp.metadata(directory / checkpoint_name).metadata
{'a': ArrayMetadata(shape=(8, 8), dtype=dtype('int32'), sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['x'], axis_types=(Auto,), partition_spec=('x', None)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]), storage_metadata=StorageMetadata(chunk_shape=(8, 8), write_shape=(8, 8))),
'b': ArrayMetadata(shape=(16,), dtype=dtype('int32'), sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['x'], axis_types=(Auto,), partition_spec=()) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]), storage_metadata=StorageMetadata(chunk_shape=(16,), write_shape=(16,))),
'c': ArrayMetadata(shape=(), dtype=dtype('float32'), sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['x'], axis_types=(Auto,), partition_spec=()) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]), storage_metadata=StorageMetadata(chunk_shape=(), write_shape=()))}
Note that we access the metadata property of the object returned by ocp.metadata(...), i.e. ocp.metadata(...).metadata. This property holds the metadata specific to the PyTree itself; other properties describe the checkpoint as a whole, such as timestamps.
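The shape and dtype fields in this metadata, combined with the shardings you want after restoring, are enough to describe a restore target without reading any array data. When no target is supplied, Orbax instead populates shardings from the sharding information stored in the checkpoint, which requires an extra file read and is unsafe when restoring on a different topology than the one the checkpoint was saved with. The sketch below builds such a target from jax.ShapeDtypeStruct values. It assumes that ocp.load accepts this kind of tree as its abstract_state argument (the API overview uses ocp.load(path, abstract_state)); specs is a hypothetical stand-in that hard-codes the printed metadata rather than reading it programmatically.

```python
import jax
import numpy as np

# Hypothetical stand-in: (shape, dtype) pairs copied from the printed metadata.
specs = {
    'a': ((8, 8), np.dtype('int32')),
    'b': ((16,), np.dtype('int32')),
    'c': ((), np.dtype('float32')),
}

# Target shardings mirroring the setup cell (same mesh axis and partition specs).
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('x',))
target_shardings = {
    'a': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', None)),
    'b': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
    'c': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
}

# One ShapeDtypeStruct per leaf: describes the array without holding data.
abstract_state = {
    name: jax.ShapeDtypeStruct(shape, dtype, sharding=target_shardings[name])
    for name, (shape, dtype) in specs.items()
}

# Assumed usage, mirroring ocp.load(path, abstract_state) from the API overview:
# restored = ocp.load(directory / checkpoint_name, abstract_state)
```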
Be sure to check out additional documentation on Working with PyTrees.
Checkpointing in a Training Loop
When training an ML model, checkpoints are commonly used to record progress for later recovery in case of failure, to perform evaluations, or to distribute the model to downstream consumers after the experiment completes. Typically, a checkpoint is saved every n steps.
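The train() function below saves at every step to keep the example short. Saving every n steps instead comes down to a small predicate. Here is a minimal sketch, where should_save and save_interval_steps are hypothetical names chosen for illustration; consult the Training documentation for any interval options that Checkpointer may provide itself. Always saving the final step ensures the finished model is persisted even when total_steps is not a multiple of the interval.

```python
def should_save(step: int, save_interval_steps: int, total_steps: int) -> bool:
  """Hypothetical policy: save every `save_interval_steps` steps, plus the last step."""
  is_interval_step = step % save_interval_steps == 0
  is_final_step = step == total_steps - 1
  return is_interval_step or is_final_step


# Steps that would be checkpointed in a 10-step run with an interval of 4.
saved = [s for s in range(10) if should_save(s, save_interval_steps=4, total_steps=10)]
```

Inside a training loop, the ckptr.save(step, train_state) call would then be guarded by if should_save(step, save_interval_steps, total_steps):.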
@jax.jit
def train_step(state):
  """Fake train step. This applies a function to `state` in some way."""
  return jax.tree.map(lambda x: x + 1, state)


def initialize_state():
  """Initializes the state, typically given some random number generator."""
  return {'step': 0, **pytree}


def init_or_restore(
    source_checkpoint_path: str | epath.Path | None,
):
  # If provided, restore initial checkpoint (e.g. for fine-tuning).
  # This can be referred to as a "source" checkpoint. Note the distinction drawn
  # between this "source checkpoint" and the "latest checkpoint". The source
  # checkpoint comes from a different experiment entirely, and is just used
  # to initialize the current experiment. The latest checkpoint comes from this
  # experiment, and allows us to resume after interruption.
  if source_checkpoint_path is not None:
    return ocp.load(source_checkpoint_path)
  # Otherwise, init from scratch.
  else:
    return initialize_state()


def train():
  total_steps = 10
  with ocp.training.Checkpointer(directory / 'experiment') as ckptr:
    # If checkpoints exist in the root directory, we are recovering after a
    # restart, and should resume from the latest checkpoint.
    # Otherwise, init from scratch or load the source checkpoint.
    if ckptr.latest is None:
      train_state = init_or_restore(directory / checkpoint_name)
      start_step = 0
    else:
      train_state = ckptr.load()
      # The latest checkpoint already holds the state after its own step,
      # so resume from the following step instead of repeating it.
      start_step = ckptr.latest.step + 1
    for step in range(start_step, total_steps):
      train_state = train_step(train_state)
      ckptr.save(step, train_state)


train()
!ls {directory / 'experiment'}
0 1 2 3 4 5 6 7 8 9
To summarize, a typical training workflow (from a checkpoint-focused perspective) consists of the following steps:
1. Identify the latest checkpoint, if any.
2. If no latest checkpoint is found:
   a. Restore from the source checkpoint if one is provided, or
   b. Initialize the model from scratch.
3. If a latest checkpoint is found, restore it and resume training from the step after the latest saved step.
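The decision logic in these steps can be simulated in plain Python to check that an interrupted run ends in the same state as an uninterrupted one. This is a minimal sketch: InMemoryCheckpointer and run_training are hypothetical stand-ins that keep checkpoints in a dict rather than on disk, and they are not Orbax APIs. Resuming starts at latest_step + 1 because the checkpoint saved at step s already contains the result of step s.

```python
from typing import Optional


class InMemoryCheckpointer:
  """Hypothetical stand-in for a step-based checkpointer: maps step -> state."""

  def __init__(self):
    self.checkpoints: dict[int, dict] = {}

  @property
  def latest_step(self) -> Optional[int]:
    return max(self.checkpoints) if self.checkpoints else None

  def save(self, step: int, state: dict) -> None:
    self.checkpoints[step] = dict(state)

  def load(self, step: int) -> dict:
    return dict(self.checkpoints[step])


def run_training(ckptr, total_steps, source_state=None, stop_after=None):
  """Follows the workflow above; `stop_after` simulates an interruption."""
  if ckptr.latest_step is None:
    # No checkpoint from this experiment: use the source or init from scratch.
    state = dict(source_state) if source_state is not None else {'w': 0}
    start_step = 0
  else:
    # Resume after the latest saved step; its update is already in the state.
    state = ckptr.load(ckptr.latest_step)
    start_step = ckptr.latest_step + 1
  for step in range(start_step, total_steps):
    state = {'w': state['w'] + 1}  # Fake train step.
    ckptr.save(step, state)
    if stop_after is not None and step == stop_after:
      return state  # Simulated crash right after saving.
  return state
```

Calling run_training(ckptr, 10, stop_after=6) and then run_training(ckptr, 10) on the same checkpointer should yield the same final state as a single run_training(InMemoryCheckpointer(), 10).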
Be sure to check out additional documentation on Training.
API Overview
Orbax’s main API entry points for users are divided into two levels, both of which appear above:
- Higher level: the sequence-of-steps API (training.Checkpointer).
- Lower level: the individual-path API (free functions).
We recommend that users take advantage of the higher level API when working with a training loop that repeatedly saves (and occasionally loads) checkpoints with a given step interval. The free functions, which work on the basis of arbitrary checkpoint paths, are more useful when working with individual paths in isolation.
Both API levels share the same core operations:
- Saving: save / save_checkpointables
- Loading: load / load_checkpointables
- Metadata: metadata / checkpointables_metadata
These perform conceptually the same tasks at both API levels but are accessed slightly differently. In a training loop, we create a training.Checkpointer object whose methods take a step (an integer) as the identifier for each checkpoint. The free functions, in contrast, take the checkpoint path directly.
For example:
Higher level (sequence-of-steps), where ckptr is an ocp.training.Checkpointer as in the training loop above:
ckptr.save(100, state)
ckptr.load(100, abstract_state)
ckptr.metadata(100)
Lower level (individual paths):
ocp.save('/tmp/my/checkpoint', state)
ocp.load('/tmp/my/checkpoint', abstract_state)
ocp.metadata('/tmp/my/checkpoint')
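The ls output earlier shows one directory per step (0 through 9) under the experiment root, which suggests how the two levels relate: at the higher level, a step identifies a path under the root directory. The sketch below illustrates that mapping under the assumption of this one-directory-per-step layout. step_path and latest_step are hypothetical helpers, not Orbax functions, and Checkpointer may also handle incomplete or temporary directories in ways this sketch does not. Steps are compared as integers, since a lexicographic sort of directory names would rank "9" above "10".

```python
import pathlib
from typing import Optional


def step_path(root: pathlib.Path, step: int) -> pathlib.Path:
  """Hypothetical: path of the checkpoint for `step`, one directory per step."""
  return root / str(step)


def latest_step(root: pathlib.Path) -> Optional[int]:
  """Largest step among the integer-named subdirectories of `root`, if any."""
  steps = [int(p.name) for p in root.iterdir() if p.is_dir() and p.name.isdigit()]
  return max(steps, default=None)
```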
See additional documentation on Training for the high-level API and see Working with PyTrees and Checkpointables for information on the low-level API.
What’s Next?
So far, we have seen some simple and common patterns of Orbax usage. This represents just the tip of the checkpointing iceberg. We encourage the reader to explore additional topics.
PyTrees of arrays are a fundamental representation of ML models in JAX. Working with PyTrees examines PyTree checkpointing in greater detail, showing how to reshard, cast, and manipulate other array properties. It also demonstrates multiple mechanisms for partially restoring a PyTree. Further advanced options for saving and restoring PyTrees and arrays are also shown.
Compute efficiency is crucial for training ML models. Async checkpointing shows how to save and load in a background thread, minimizing the performance impact of checkpointing on the training job.
PyTrees of arrays are not the only type of object that needs to be checkpointed. Orbax introduces the concept of a Checkpointable to represent other objects, like dataset iterators or special metadata, that must be saved alongside the main model. More advanced mechanisms for supporting user-customized objects are also shown.
The step-based training loop is a common concept across many ML workflows. Checkpointing in a Training Loop expands on the training module provided by Orbax, whose primary entry point is Checkpointer.
Interacting directly with the file format of the checkpoint on disk is useful in a variety of circumstances. We provide details on the file format, contributing to a deeper grasp of Orbax concepts, debugging strategies, and advanced options.