Checkpoint Format Guide#
What is an Orbax checkpoint?#
An Orbax checkpoint is a directory containing an empty file named
orbax.checkpoint. All Orbax checkpoints saved with the V1 API will include
this file, and any directories not including the file are not valid checkpoints
(note that they could still be valid older checkpoints saved with the V0 API).
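The marker-file rule above lends itself to a quick sanity check. The following is a minimal sketch using only the standard library; `looks_like_v1_checkpoint` is a hypothetical helper name, not part of the Orbax API, and it only tests for the presence of the marker file rather than validating the checkpoint contents.

```python
import pathlib
import tempfile

# Name of the marker file described above.
MARKER_FILENAME = 'orbax.checkpoint'


def looks_like_v1_checkpoint(path) -> bool:
  """Returns True if `path` is a directory containing the marker file.

  Hypothetical helper for illustration; it does not validate the contents.
  """
  path = pathlib.Path(path)
  return path.is_dir() and (path / MARKER_FILENAME).is_file()


# Demonstration on throwaway directories (no real checkpoint data).
with tempfile.TemporaryDirectory() as tmp:
  with_marker = pathlib.Path(tmp) / 'with_marker'
  without_marker = pathlib.Path(tmp) / 'without_marker'
  with_marker.mkdir()
  without_marker.mkdir()
  (with_marker / MARKER_FILENAME).touch()

  has_marker = looks_like_v1_checkpoint(with_marker)
  lacks_marker = looks_like_v1_checkpoint(without_marker)

print(has_marker, lacks_marker)
```

A directory that fails this check may still be an older checkpoint saved with the V0 API, as noted above.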
Overview#
Consider the following directory tree:
/path/to/my/checkpoints/
  0/
    pytree/
      ...
    dataset/
      ...
  100/
    ...
  200/
    ...
What does each level represent?
The top-level directory is called a root directory.
Within the root directory is a sequence of individual checkpoints. In a training context, each of these checkpoints corresponds to an integer step.
Within each checkpoint is a set of checkpointables corresponding to individual elements like the PyTree train state, the dataset iterator, and so on.
Let’s take a closer look at these elements.
Singular Checkpoints#
A checkpoint is a persistent representation of an ML model present in a storage location, typically on disk. When a model is saved using Orbax, it becomes a checkpoint. When a checkpoint is loaded using Orbax, it becomes a model.
Concretely, in Orbax, a checkpoint is composed of a collection of **checkpointables**. That means if we save using the following:
ocp.save_checkpointables(
    '/path/to/my/checkpoint/',
    dict(pytree=..., dataset=..., other_checkpointable=...),
)
We get a checkpoint on disk with a structure similar to the following:
/path/to/my/checkpoint/     # The checkpoint path.
  pytree/                   # A directory containing the PyTree piece of the checkpoint.
  dataset/                  # A directory containing the dataset piece of the checkpoint.
  other_checkpointable/     # Another checkpointable.
Each checkpointable is represented by a subdirectory.
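Because each checkpointable is a subdirectory, the names of the checkpointables in a checkpoint can be recovered by listing directories. This is a minimal sketch on a mock layout; `list_checkpointable_names` is a hypothetical helper, and it assumes that every subdirectory of the checkpoint is a checkpointable.

```python
import pathlib
import tempfile


def list_checkpointable_names(checkpoint_path) -> list[str]:
  """Returns the sorted names of the subdirectories of a checkpoint.

  Hypothetical helper: assumes every subdirectory is a checkpointable.
  """
  checkpoint_path = pathlib.Path(checkpoint_path)
  return sorted(p.name for p in checkpoint_path.iterdir() if p.is_dir())


# Build a mock checkpoint layout (empty directories only, no real data).
with tempfile.TemporaryDirectory() as tmp:
  checkpoint = pathlib.Path(tmp) / 'checkpoint'
  for name in ('pytree', 'dataset', 'other_checkpointable'):
    (checkpoint / name).mkdir(parents=True)
  (checkpoint / 'orbax.checkpoint').touch()  # Marker file, not a directory.
  names = list_checkpointable_names(checkpoint)

print(names)
```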
Similarly, we can use a different API (see Checkpointing PyTrees):
ocp.save(
    '/path/to/my/checkpoint/',
    pytree_of_arrays,
)
This produces a checkpoint where pytree is the only subdirectory.
/path/to/my/checkpoint/     # The checkpoint path.
  pytree/                   # A directory containing the PyTree piece of the checkpoint.
Sequence of Checkpoints#
Make sure not to confuse a “checkpoint” with a “sequence of checkpoints”. When using training.Checkpointer, multiple checkpoints, each representing a step, are saved to a root directory.
For example, if we save a sequence of steps using the following:
with ocp.training.Checkpointer('/path/to/my/root_directory/') as ckptr:
  for step in range(start_step, num_steps):
    ckptr.save_checkpointables(step, ...)
Our root directory will look like the following, where each integer-numbered subdirectory represents a single checkpoint, corresponding to a step.
/path/to/my/root_directory/
  0/
  100/
  200/
  ...
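When inspecting a root directory by hand, the steps can be recovered from the integer-named subdirectories. The sketch below uses only the standard library on a mock layout; `list_steps` is a hypothetical helper, and it assumes that steps are stored in directories whose names consist solely of digits, as in the tree above.

```python
import pathlib
import tempfile


def list_steps(root_directory) -> list[int]:
  """Returns the sorted integer steps found directly under `root_directory`.

  Hypothetical helper: subdirectories with all-digit names count as steps;
  everything else is ignored.
  """
  root_directory = pathlib.Path(root_directory)
  return sorted(
      int(p.name)
      for p in root_directory.iterdir()
      if p.is_dir() and p.name.isdecimal()
  )


# Mock root directory with three steps and one unrelated directory.
with tempfile.TemporaryDirectory() as tmp:
  root = pathlib.Path(tmp)
  for name in ('200', '0', '100', 'scratch'):
    (root / name).mkdir()
  steps = list_steps(root)

print(steps)
```

Sorting as integers rather than strings keeps the steps in training order even when they have different numbers of digits.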
Format Details#
Now that we understand the checkpoint format abstractly, let’s get to some concrete details.
First, some setup:
import json
import pprint

from etils import epath
import jax
import numpy as np
from orbax.checkpoint import v1 as ocp

directory = epath.Path('/tmp/checkpoint-format/my-checkpoints')

mesh = jax.sharding.Mesh(jax.devices(), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))

pytree = {
    'params': {
        'layer0': {
            'kernel': np.random.uniform(size=(2, 2)),
            'bias': np.ones(2),
        }
    },
    'opt_state': {'0': np.random.random(size=(2,))},
}
pytree = jax.device_put(pytree, sharding)


def print_directory(directory: epath.PathLike, level: int = 0):
  """Prints a directory tree for debugging purposes."""
  directory = epath.Path(directory)
  assert directory.is_dir()
  level_str = '..' * level
  if level == 0:
    print(f'{directory}/')
  else:
    print(f'{level_str}{directory.name}/')

  level_str = '..' * (level + 1)
  for p in directory.iterdir():
    if p.is_dir():
      print_directory(p, level=level + 1)
    else:
      print(f'{level_str}{p.name}')
Generic Checkpoints#
Let’s create a checkpoint with two checkpointables, pytree and
extra_properties. Let’s also pass some custom metadata, which allows users to
provide JSON-serializable properties. For demonstration purposes, let’s save
extra_properties as a JSON checkpointable.
# Note that the example would work even without the extra step of forcing
# `extra_properties` to be handled by `JsonHandler`. We just want to ensure it
# gets JSON-encoded for demonstration purposes.
with ocp.Context(
    checkpointables_options=ocp.options.CheckpointablesOptions.create_with_handlers(
        extra_properties=ocp.handlers.JsonHandler
    )
):
  ocp.save_checkpointables(
      directory / 'ckpt-0',
      dict(pytree=pytree, extra_properties={'foo': 'bar'}),
      custom_metadata={'version': 1.0},
  )
!ls {directory / 'ckpt-0'}
_CHECKPOINT_METADATA extra_properties orbax.checkpoint pytree
As we expected, each checkpointable gets its own subdirectory. There is also a
_CHECKPOINT_METADATA file created, which contains JSON-encoded metadata.
pprint.pp(
    json.loads((directory / 'ckpt-0' / '_CHECKPOINT_METADATA').read_text())
)
{'item_handlers': {'pytree': 'orbax.checkpoint.experimental.v1._src.handlers.pytree_handler.PyTreeHandler',
'extra_properties': 'orbax.checkpoint.experimental.v1._src.handlers.json_handler.JsonHandler'},
'metrics': {},
'performance_metrics': {},
'init_timestamp_nsecs': 1779312196647094272,
'commit_timestamp_nsecs': 1779312196728874548,
'custom_metadata': {'version': 1.0}}
This file contains a number of internal properties recorded by Orbax. The most
important of these is item_handlers, which records the handler used to save
each checkpointable, to facilitate later loading.
Notice that our custom_metadata is also stored in this file.
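Since the file is plain JSON, the handler recorded for each checkpointable can be summarized with the standard library alone. The sketch below runs against a mock file that mirrors part of the output above; `read_item_handlers` is a hypothetical helper, and because the file layout is internal to Orbax, the field names are assumptions taken from that output and may change.

```python
import json
import pathlib
import tempfile


def read_item_handlers(checkpoint_path) -> dict[str, str]:
  """Maps each checkpointable name to the short class name of its handler."""
  metadata_path = pathlib.Path(checkpoint_path) / '_CHECKPOINT_METADATA'
  metadata = json.loads(metadata_path.read_text())
  return {
      name: handler.rsplit('.', 1)[-1]  # Keep only the class name.
      for name, handler in metadata.get('item_handlers', {}).items()
  }


# Mock metadata mirroring part of the output shown above.
mock_metadata = {
    'item_handlers': {
        'pytree': (
            'orbax.checkpoint.experimental.v1._src.handlers'
            '.pytree_handler.PyTreeHandler'
        ),
        'extra_properties': (
            'orbax.checkpoint.experimental.v1._src.handlers'
            '.json_handler.JsonHandler'
        ),
    },
    'custom_metadata': {'version': 1.0},
}

with tempfile.TemporaryDirectory() as tmp:
  (pathlib.Path(tmp) / '_CHECKPOINT_METADATA').write_text(
      json.dumps(mock_metadata)
  )
  handlers = read_item_handlers(tmp)

print(handlers)
```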
PyTree Checkpointables#
Using the same checkpoint, let’s dig into the pytree subdirectory.
print_directory(directory / 'ckpt-0' / 'pytree')
/tmp/checkpoint-format/my-checkpoints/ckpt-0/pytree/
..manifest.ocdbt
..d/
....262b45de16eb55af3ba251e3fbb15565
..ocdbt.process_0/
....manifest.ocdbt
....d/
......70de16254939129be13f412c7e6b5cba
......d22291039d0678732b387a0309817880
......811fe3b569393ba8e4a438c2ad012182
......9121a55ea1342af2e6dec0586c5b921b
..array_metadatas/
....process_0
.._sharding
.._METADATA
The _METADATA file provides a complete description of the PyTree structure,
including custom and empty nodes.
The tree is represented as a flattened dictionary in which each key is a tuple
whose successive elements denote successive levels of nesting. For
example, for the dict {'a': {'b': [1, 2]}} the metadata file would contain two
entries with keys ('a', 'b', '0') and ('a', 'b', '1').
Keys at each level of nesting also encode what type they are: i.e. whether they are a dict key or a sequential key.
Finally, metadata about the value type is stored (e.g. jax.Array,
np.ndarray, etc.) in order to allow for later reconstruction without
explicitly requiring the object type to be provided.
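The flattening scheme described above can be illustrated with a short, self-contained sketch. `flatten_with_keys` is a hypothetical function, not Orbax code: it handles only dicts, lists, and tuples, labels each key as 'dict' or 'sequence' instead of using Orbax's own encoding, and does not record empty nodes.

```python
def flatten_with_keys(tree, keys=(), kinds=()):
  """Yields (key_tuple, key_kinds, leaf) for each leaf of a nested container.

  Illustrative only: every key is converted to a string, and each level is
  labeled 'dict' or 'sequence' depending on the container it came from.
  """
  if isinstance(tree, dict):
    for key, value in tree.items():
      yield from flatten_with_keys(
          value, keys + (str(key),), kinds + ('dict',)
      )
  elif isinstance(tree, (list, tuple)):
    for index, value in enumerate(tree):
      yield from flatten_with_keys(
          value, keys + (str(index),), kinds + ('sequence',)
      )
  else:
    yield keys, kinds, tree


flat = list(flatten_with_keys({'a': {'b': [1, 2]}}))
for key_tuple, key_kinds, leaf in flat:
  print(key_tuple, key_kinds, leaf)
```

The two resulting key tuples match the ('a', 'b', '0') and ('a', 'b', '1') entries from the example in the text.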
pprint.pp(
    json.loads((directory / 'ckpt-0' / 'pytree' / '_METADATA').read_text())
)
{'tree_metadata': {"('opt_state', '0')": {'key_metadata': [{'key': 'opt_state',
'key_type': 2},
{'key': '0',
'key_type': 2}],
'value_metadata': {'value_type': 'jax.Array',
'skip_deserialize': False,
'write_shape': [2]}},
"('params', 'layer0', 'bias')": {'key_metadata': [{'key': 'params',
'key_type': 2},
{'key': 'layer0',
'key_type': 2},
{'key': 'bias',
'key_type': 2}],
'value_metadata': {'value_type': 'jax.Array',
'skip_deserialize': False,
'write_shape': [2]}},
"('params', 'layer0', 'kernel')": {'key_metadata': [{'key': 'params',
'key_type': 2},
{'key': 'layer0',
'key_type': 2},
{'key': 'kernel',
'key_type': 2}],
'value_metadata': {'value_type': 'jax.Array',
'skip_deserialize': False,
'write_shape': [2,
2]}}},
'use_ocdbt': True,
'use_zarr3': True,
'store_array_data_equal_to_fill_value': True,
'custom_metadata': None}
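The flattened entries printed above can be folded back into a nested view for easier reading. This is a minimal sketch that works on a trimmed copy of that output; `nested_value_types` is a hypothetical helper, and it treats every key as a dict key, which holds for this example but is a simplifying assumption in general.

```python
import json

# A trimmed copy of the structure printed above (two entries kept).
metadata_json = '''
{"tree_metadata": {
  "('opt_state', '0')": {
    "key_metadata": [{"key": "opt_state", "key_type": 2},
                     {"key": "0", "key_type": 2}],
    "value_metadata": {"value_type": "jax.Array", "write_shape": [2]}},
  "('params', 'layer0', 'kernel')": {
    "key_metadata": [{"key": "params", "key_type": 2},
                     {"key": "layer0", "key_type": 2},
                     {"key": "kernel", "key_type": 2}],
    "value_metadata": {"value_type": "jax.Array", "write_shape": [2, 2]}}
}}
'''


def nested_value_types(metadata: dict) -> dict:
  """Rebuilds a nested dict mapping each leaf to its recorded value type."""
  nested = {}
  for entry in metadata['tree_metadata'].values():
    keys = [k['key'] for k in entry['key_metadata']]
    node = nested
    for key in keys[:-1]:
      node = node.setdefault(key, {})  # Descend, creating levels as needed.
    node[keys[-1]] = entry['value_metadata']['value_type']
  return nested


summary = nested_value_types(json.loads(metadata_json))
print(summary)
```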
While the exact structure of the metadata is an internal implementation detail and is subject to change, it can still be useful to manually inspect the tree structure. In most cases, however, it is still preferable to rely on public methods intended for obtaining metadata.
pprint.pp(ocp.metadata(directory / 'ckpt-0').metadata)
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/checkpoint-format/my-checkpoints/ckpt-0. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
{'opt_state': {'0': ArrayMetadata(shape=(2,),
dtype=dtype('float32'),
sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['x'], axis_types=(Auto,), partition_spec=(None,)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]),
storage_metadata=StorageMetadata(chunk_shape=(2,),
write_shape=(2,)))},
'params': {'layer0': {'bias': ArrayMetadata(shape=(2,),
dtype=dtype('float32'),
sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['x'], axis_types=(Auto,), partition_spec=(None,)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]),
storage_metadata=StorageMetadata(chunk_shape=(2,),
write_shape=(2,))),
'kernel': ArrayMetadata(shape=(2, 2),
dtype=dtype('float32'),
sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['x'], axis_types=(Auto,), partition_spec=(None,)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]),
storage_metadata=StorageMetadata(chunk_shape=(2,
2),
write_shape=(2,
2)))}}}
print_directory(directory / 'ckpt-0' / 'pytree')
/tmp/checkpoint-format/my-checkpoints/ckpt-0/pytree/
..manifest.ocdbt
..d/
....262b45de16eb55af3ba251e3fbb15565
..ocdbt.process_0/
....manifest.ocdbt
....d/
......70de16254939129be13f412c7e6b5cba
......d22291039d0678732b387a0309817880
......811fe3b569393ba8e4a438c2ad012182
......9121a55ea1342af2e6dec0586c5b921b
..array_metadatas/
....process_0
.._sharding
.._METADATA
Aside from the _METADATA file, most other files are not human-readable.
The _sharding file stores information about the shardings used when saving
jax.Arrays in the tree. Similarly, array_metadatas records array properties
separately on each process, so that these properties may later be compared and
validated.
Orbax uses the TensorStore library to
save individual arrays. Actual array data is stored within the d/ subdirectory,
while TensorStore metadata is recorded in the manifest.ocdbt file. These files
are not human-readable and require TensorStore APIs to parse (see below).
Finally, you’ll notice the presence of the directory ocdbt.process_0/, which
also has a manifest.ocdbt and its own d/ subdirectory. One such folder
exists for every process on which the checkpoint was saved. This exists because
each process first writes its own data independently to its corresponding
subdirectory.
When all processes have finished, Orbax runs a finalization pass to cheaply merge the metadatas from all per-process subdirectories into a global view (note that this still references data in the original subdirectories). This allows for scalability in checkpoint saving as the number of concurrent processes increases.
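The idea of merging per-process metadata without moving any data can be pictured with a toy model. The sketch below is not how OCDBT manifests are actually encoded or merged: it stands in plain dictionaries for manifests, uses placeholder data file names, and assumes each key is written by exactly one process. `merge_process_manifests` is a hypothetical function.

```python
def merge_process_manifests(
    per_process: dict[str, dict[str, str]],
) -> dict[str, str]:
  """Merges per-process key -> data-file maps into one global map.

  Values in the result are paths relative to the checkpointable directory,
  so they still point into the per-process subdirectories; nothing is copied.
  """
  merged = {}
  for process_dir, manifest in sorted(per_process.items()):
    for key, data_file in manifest.items():
      if key in merged:
        raise ValueError(f'Key {key!r} was written by more than one process.')
      merged[key] = f'{process_dir}/d/{data_file}'
  return merged


# Two hypothetical processes, each holding one key; file names are placeholders.
per_process = {
    'ocdbt.process_0': {'params.layer0.kernel/c/0/0': 'file_a'},
    'ocdbt.process_1': {'params.layer0.bias/c/0': 'file_b'},
}
merged_view = merge_process_manifests(per_process)
print(merged_view)
```

The merge touches only the small key maps, which is why the finalization pass stays cheap relative to the size of the array data.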
Working with TensorStore#
Sometimes, it is helpful to work directly with the TensorStore API to debug individual parameters in a checkpoint.
import tensorstore as ts
pytree_path = directory / 'ckpt-0' / 'pytree'
We can verify which keys are present in the checkpoint; they match the
information we gathered earlier from the Orbax metadata API.
ts.KvStore.open(
    {"driver": "ocdbt", "base": f"file://{pytree_path.as_posix()}"}
).result().list().result()
[b'opt_state.0/c/0',
b'opt_state.0/zarr.json',
b'params.layer0.bias/c/0',
b'params.layer0.bias/zarr.json',
b'params.layer0.kernel/c/0/0',
b'params.layer0.kernel/zarr.json']
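In the listing above, the part of each key before the first `/` is the parameter name, formed from the nested tree keys joined with dots. The following sketch groups the listed keys by parameter using only the standard library; `group_keys_by_parameter` is a hypothetical helper, and the key list is copied from the output above rather than read from a checkpoint.

```python
from collections import defaultdict

# Keys copied from the listing above.
keys = [
    b'opt_state.0/c/0',
    b'opt_state.0/zarr.json',
    b'params.layer0.bias/c/0',
    b'params.layer0.bias/zarr.json',
    b'params.layer0.kernel/c/0/0',
    b'params.layer0.kernel/zarr.json',
]


def group_keys_by_parameter(raw_keys) -> dict[str, list[str]]:
  """Groups key suffixes by the parameter name before the first '/'."""
  grouped = defaultdict(list)
  for raw_key in raw_keys:
    name, _, suffix = raw_key.decode('utf-8').partition('/')
    grouped[name].append(suffix)
  return dict(grouped)


grouped_keys = group_keys_by_parameter(keys)
for name, suffixes in grouped_keys.items():
  print(name, suffixes)
```

Each parameter has one `zarr.json` entry plus one entry per stored chunk under `c/`.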
To read using TensorStore, we need to construct a TensorStore Spec. For this, we
can use Orbax APIs, though here we write the spec out by hand. The spec points to
a base path, as well as a particular parameter name (params.layer0.kernel in this
case). It contains further options related to the checkpoint format.
tspec = {
    'driver': 'zarr3',
    'kvstore': {
        'driver': 'ocdbt',
        'base': {'driver': 'file', 'path': pytree_path.as_posix()},
        'path': 'params.layer0.kernel',
    },
}
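The hand-written spec above generalizes to any parameter in the checkpoint. The sketch below wraps the same dictionary layout in a hypothetical helper, `build_tspec`; it assumes the checkpoint was saved with both `use_ocdbt` and `use_zarr3` enabled, as the `_METADATA` file reported earlier, and it only builds the dictionary without opening anything.

```python
import pathlib


def build_tspec(pytree_path, parameter_name: str) -> dict:
  """Builds a spec dict with the same layout as the hand-written one above.

  Hypothetical helper: assumes an OCDBT + Zarr3 checkpoint on a local file
  system.
  """
  return {
      'driver': 'zarr3',
      'kvstore': {
          'driver': 'ocdbt',
          'base': {
              'driver': 'file',
              'path': pathlib.PurePosixPath(pytree_path).as_posix(),
          },
          'path': parameter_name,
      },
  }


bias_spec = build_tspec(
    '/tmp/checkpoint-format/my-checkpoints/ckpt-0/pytree',
    'params.layer0.bias',
)
print(bias_spec)
```

Only the innermost `path` entry changes from one parameter to the next; the drivers and base path stay the same.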
Finally, we can directly restore the array using TensorStore.
t = ts.open(ts.Spec(tspec), open=True).result()
result = t.read().result()
result
array([[0.20867518, 0.5380874 ],
[0.3181005 , 0.42343944]], dtype=float32)
Other Checkpointables#
Finally, let’s return to the other checkpointable in our example, called
extra_properties. Since we explicitly required the use of JsonHandler to save this object, this piece of the checkpoint is easily human-readable.
print_directory(directory / 'ckpt-0' / 'extra_properties')
/tmp/checkpoint-format/my-checkpoints/ckpt-0/extra_properties/
..data.json
pprint.pp(
    json.loads(
        (directory / 'ckpt-0' / 'extra_properties' / 'data.json').read_text()
    )
)
{'foo': 'bar'}