Checkpoint Format Guide

What is an Orbax checkpoint?

An Orbax checkpoint is a directory containing an empty file named orbax.checkpoint. All Orbax checkpoints saved with the V1 API will include this file, and any directories not including the file are not valid checkpoints (note that they could still be valid older checkpoints saved with the V0 API).
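The marker-file rule above can be checked with a few lines of standard-library Python. The following is a minimal sketch, not part of the Orbax API: the helper name `has_orbax_marker` is hypothetical, and it only tests for the presence of the orbax.checkpoint file, so it reports False for older V0 checkpoints.

```python
import pathlib
import tempfile

# Marker file name taken from the text above.
MARKER_NAME = 'orbax.checkpoint'


def has_orbax_marker(path) -> bool:
  """Returns True if `path` is a directory containing the marker file.

  Hypothetical helper for illustration. It validates nothing else about the
  checkpoint and returns False for checkpoints saved with the V0 API.
  """
  path = pathlib.Path(path)
  return path.is_dir() and (path / MARKER_NAME).is_file()


# Demonstration on throwaway directories.
with tempfile.TemporaryDirectory() as tmp:
  marked = pathlib.Path(tmp) / 'marked'
  unmarked = pathlib.Path(tmp) / 'unmarked'
  marked.mkdir()
  unmarked.mkdir()
  (marked / MARKER_NAME).touch()
  marker_results = (has_orbax_marker(marked), has_orbax_marker(unmarked))
```

A check like this is only a quick sanity test before attempting a load; it says nothing about whether the contents of the checkpoint are complete.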

Overview

Consider the following directory tree:

/path/to/my/checkpoints/
  0/
    pytree/
      ...
    dataset/
      ...
  100/
    ...
  200/
    ...

What does each level represent?

The top-level directory is called a root directory.

Within the root directory is a sequence of individual checkpoints. In a training context, each of these checkpoints corresponds to an integer step.

Within each checkpoint is a set of checkpointables corresponding to individual elements like the PyTree train state, the dataset iterator, and so on.

Let’s take a closer look at these elements.

Singular Checkpoints

A checkpoint is a persistent representation of an ML model present in a storage location, typically on disk. When a model is saved using Orbax, it becomes a checkpoint. When a checkpoint is loaded using Orbax, it becomes a model.

Concretely, in Orbax, a checkpoint is composed of a collection of **checkpointables**. That means if we save using the following:

ocp.save_checkpointables(
  '/path/to/my/checkpoint/',
  dict(pytree=..., dataset=..., other_checkpointable=...),
)

We get a checkpoint on disk with a structure similar to the following:

/path/to/my/checkpoint/  # The checkpoint path.
  pytree/  # A directory containing the PyTree piece of the checkpoint.
  dataset/  # A directory containing the dataset piece of the checkpoint.
  other_checkpointable/  # Another checkpointable

Each checkpointable is represented by a subdirectory.

Similarly, we can use a different API (see Checkpointing PyTrees):

ocp.save(
  '/path/to/my/checkpoint/',
  pytree_of_arrays,
)

This produces a checkpoint where pytree is the only subdirectory.

/path/to/my/checkpoint/  # The checkpoint path.
  pytree/  # A directory containing the PyTree piece of the checkpoint.
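
Because each checkpointable is represented by a subdirectory, the checkpointables in a checkpoint can be enumerated by listing its subdirectories. The sketch below assumes exactly that and nothing more; `list_checkpointables` is a hypothetical helper, not an Orbax function, and it runs against a throwaway directory rather than a real checkpoint.

```python
import pathlib
import tempfile


def list_checkpointables(checkpoint_path) -> list[str]:
  """Returns the sorted names of the subdirectories of a checkpoint."""
  checkpoint = pathlib.Path(checkpoint_path)
  return sorted(p.name for p in checkpoint.iterdir() if p.is_dir())


with tempfile.TemporaryDirectory() as tmp:
  checkpoint = pathlib.Path(tmp)
  for name in ('pytree', 'dataset', 'other_checkpointable'):
    (checkpoint / name).mkdir()
  # A plain file, such as the marker file, is not a checkpointable.
  (checkpoint / 'orbax.checkpoint').touch()
  checkpointable_names = list_checkpointables(checkpoint)
```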

Sequence of Checkpoints

Make sure not to confuse a “checkpoint” with a “sequence of checkpoints”. When using training.Checkpointer, multiple checkpoints representing steps will be saved to a root directory.

For example, if we save a sequence of steps using the following:

with ocp.training.Checkpointer('/path/to/my/root_directory/') as ckptr:
  for step in range(start_step, num_steps):
    ckptr.save_checkpointables(step, ...)

Our root directory will look like the following, where each integer-numbered subdirectory represents a single checkpoint, corresponding to a step.

/path/to/my/root_directory/
  0/
  100/
  200/
  ...
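
Given this layout, the available steps can be recovered from the directory names alone. The following is a minimal sketch under the assumption that every subdirectory with a purely numeric name is a step; `list_steps` is a hypothetical helper used only to illustrate the layout.

```python
import pathlib
import tempfile


def list_steps(root_directory) -> list[int]:
  """Returns the sorted integer steps found directly under a root directory.

  Hypothetical helper: subdirectories with ASCII-digit names count as steps,
  and everything else is ignored.
  """
  root = pathlib.Path(root_directory)
  return sorted(
      int(p.name)
      for p in root.iterdir()
      if p.is_dir() and p.name.isascii() and p.name.isdigit()
  )


with tempfile.TemporaryDirectory() as tmp:
  root = pathlib.Path(tmp)
  for name in ('0', '100', '200', 'notes'):
    (root / name).mkdir()
  steps = list_steps(root)
  latest_step = steps[-1] if steps else None
```

Sorting numerically rather than lexically matters here: as strings, '100' would sort before '20'.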

Format Details

Now that we understand the checkpoint format abstractly, let’s get to some concrete details.

First, some setup:

import json
import pprint
from etils import epath
import jax
import numpy as np
from orbax.checkpoint import v1 as ocp

directory = epath.Path('/tmp/checkpoint-format/my-checkpoints')
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))

pytree = {
    'params': {
        'layer0': {
            'kernel': np.random.uniform(size=(2, 2)),
            'bias': np.ones(2),
        }
    },
    'opt_state': {'0': np.random.random(size=(2,))},
}
pytree = jax.device_put(pytree, sharding)


def print_directory(directory: epath.PathLike, level: int = 0):
  """Prints a directory tree for debugging purposes."""
  directory = epath.Path(directory)
  assert directory.is_dir()
  level_str = '..' * level
  if level == 0:
    print(f'{directory}/')
  else:
    print(f'{level_str}{directory.name}/')

  level_str = '..' * (level + 1)
  for p in directory.iterdir():
    if p.is_dir():
      print_directory(p, level=level + 1)
    else:
      print(f'{level_str}{p.name}')

Generic Checkpoints

Let’s create a checkpoint with two checkpointables, pytree and extra_properties. Let’s also pass some custom metadata, which allows users to provide JSON-serializable properties. For demonstration purposes, let’s save extra_properties as a JSON checkpointable.

# Note that the example would work even without the extra step of forcing
# `extra_properties` to be handled by `JsonHandler`. We just want to ensure it
# gets JSON-encoded for demonstration purposes.
with ocp.Context(
    checkpointables_options=ocp.options.CheckpointablesOptions.create_with_handlers(
        extra_properties=ocp.handlers.JsonHandler
    )
):
  ocp.save_checkpointables(
      directory / 'ckpt-0',
      dict(pytree=pytree, extra_properties={'foo': 'bar'}),
      custom_metadata={'version': 1.0},
  )
!ls {directory / 'ckpt-0'}
_CHECKPOINT_METADATA  extra_properties	orbax.checkpoint  pytree

As we expected, each checkpointable gets its own subdirectory. There is also a _CHECKPOINT_METADATA file created, which contains JSON-encoded metadata.

pprint.pp(
    json.loads((directory / 'ckpt-0' / '_CHECKPOINT_METADATA').read_text())
)
{'item_handlers': {'pytree': 'orbax.checkpoint.experimental.v1._src.handlers.pytree_handler.PyTreeHandler',
                   'extra_properties': 'orbax.checkpoint.experimental.v1._src.handlers.json_handler.JsonHandler'},
 'metrics': {},
 'performance_metrics': {},
 'init_timestamp_nsecs': 1779312196647094272,
 'commit_timestamp_nsecs': 1779312196728874548,
 'custom_metadata': {'version': 1.0}}

This file contains a number of internal properties recorded by Orbax. The most important of these is item_handlers, which records the handler used to save each checkpointable, to facilitate later loading.

Notice that our custom_metadata is also stored in this file.
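
As a small illustration of reading this file, the sketch below pulls out item_handlers and custom_metadata using only the json module. It works on an abbreviated copy of the metadata shown above rather than on a real file, and `summarize_handlers` is a hypothetical helper. Since these properties are internal to Orbax, treat the field names as what this particular output happens to contain.

```python
import json

# Abbreviated copy of the metadata printed above.
metadata_text = json.dumps({
    'item_handlers': {
        'pytree': (
            'orbax.checkpoint.experimental.v1._src.handlers'
            '.pytree_handler.PyTreeHandler'
        ),
        'extra_properties': (
            'orbax.checkpoint.experimental.v1._src.handlers'
            '.json_handler.JsonHandler'
        ),
    },
    'custom_metadata': {'version': 1.0},
})


def summarize_handlers(text: str) -> dict[str, str]:
  """Maps each checkpointable name to the short class name of its handler."""
  metadata = json.loads(text)
  return {
      name: qualified_name.rsplit('.', 1)[-1]
      for name, qualified_name in metadata['item_handlers'].items()
  }


handler_summary = summarize_handlers(metadata_text)
custom_metadata = json.loads(metadata_text)['custom_metadata']
```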

PyTree Checkpointables

Using the same checkpoint, let’s dig into the pytree subdirectory.

print_directory(directory / 'ckpt-0' / 'pytree')
/tmp/checkpoint-format/my-checkpoints/ckpt-0/pytree/
..manifest.ocdbt
..d/
....262b45de16eb55af3ba251e3fbb15565
..ocdbt.process_0/
....manifest.ocdbt
....d/
......70de16254939129be13f412c7e6b5cba
......d22291039d0678732b387a0309817880
......811fe3b569393ba8e4a438c2ad012182
......9121a55ea1342af2e6dec0586c5b921b
..array_metadatas/
....process_0
.._sharding
.._METADATA

The _METADATA file provides a complete description of the PyTree structure, including custom and empty nodes.

The tree is represented as a flattened dictionary in which each key is a tuple whose successive elements denote successive levels of nesting. For example, for the dict {'a': {'b': [1, 2]}}, the metadata file would contain two entries with keys ('a', 'b', '0') and ('a', 'b', '1').

Keys at each level of nesting also encode what type they are: i.e. whether they are a dict key or a sequential key.

Finally, metadata about the value type is stored (e.g. jax.Array, np.ndarray, etc.) in order to allow for later reconstruction without explicitly requiring the object type to be provided.
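
To make the flattening concrete, here is a minimal sketch of the idea in plain Python. It is not Orbax's implementation: `flatten_with_keys` is a hypothetical function, the 'dict' and 'sequence' labels are illustrative strings rather than Orbax's own encoding, and unlike the real metadata it drops empty nodes.

```python
def flatten_with_keys(tree, prefix=()):
  """Flattens nested dicts and lists into {tuple_of_string_keys: info}.

  Every key element is paired with a label recording whether it came from a
  dict or from a sequence. Empty containers are dropped in this sketch.
  """
  if isinstance(tree, dict):
    items = [(str(k), 'dict', v) for k, v in tree.items()]
  elif isinstance(tree, (list, tuple)):
    items = [(str(i), 'sequence', v) for i, v in enumerate(tree)]
  else:
    # A leaf: report the accumulated keys and their kinds.
    keys = tuple(key for key, _ in prefix)
    kinds = tuple(kind for _, kind in prefix)
    return {keys: {'kinds': kinds, 'value': tree}}

  flat = {}
  for key, kind, value in items:
    flat.update(flatten_with_keys(value, prefix + ((key, kind),)))
  return flat


# The example from the text: two leaves nested under 'a' and 'b'.
flat = flatten_with_keys({'a': {'b': [1, 2]}})
```

Recording the kind of each key is what lets a list be told apart from a dict with string keys '0' and '1' when the tree is rebuilt.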

pprint.pp(
    json.loads((directory / 'ckpt-0' / 'pytree' / '_METADATA').read_text())
)
{'tree_metadata': {"('opt_state', '0')": {'key_metadata': [{'key': 'opt_state',
                                                            'key_type': 2},
                                                           {'key': '0',
                                                            'key_type': 2}],
                                          'value_metadata': {'value_type': 'jax.Array',
                                                             'skip_deserialize': False,
                                                             'write_shape': [2]}},
                   "('params', 'layer0', 'bias')": {'key_metadata': [{'key': 'params',
                                                                      'key_type': 2},
                                                                     {'key': 'layer0',
                                                                      'key_type': 2},
                                                                     {'key': 'bias',
                                                                      'key_type': 2}],
                                                    'value_metadata': {'value_type': 'jax.Array',
                                                                       'skip_deserialize': False,
                                                                       'write_shape': [2]}},
                   "('params', 'layer0', 'kernel')": {'key_metadata': [{'key': 'params',
                                                                        'key_type': 2},
                                                                       {'key': 'layer0',
                                                                        'key_type': 2},
                                                                       {'key': 'kernel',
                                                                        'key_type': 2}],
                                                      'value_metadata': {'value_type': 'jax.Array',
                                                                         'skip_deserialize': False,
                                                                         'write_shape': [2,
                                                                                         2]}}},
 'use_ocdbt': True,
 'use_zarr3': True,
 'store_array_data_equal_to_fill_value': True,
 'custom_metadata': None}
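
The flattened entries can be folded back into a nested structure by walking each entry's key_metadata. The sketch below does this for an abbreviated copy of two entries from the output above and collects the write_shape of each leaf. `nested_write_shapes` is a hypothetical helper, and it assumes every key is a dict key, which holds for this example tree.

```python
# Abbreviated copy of two `tree_metadata` entries from the output above.
tree_metadata = {
    "('opt_state', '0')": {
        'key_metadata': [
            {'key': 'opt_state', 'key_type': 2},
            {'key': '0', 'key_type': 2},
        ],
        'value_metadata': {'value_type': 'jax.Array', 'write_shape': [2]},
    },
    "('params', 'layer0', 'kernel')": {
        'key_metadata': [
            {'key': 'params', 'key_type': 2},
            {'key': 'layer0', 'key_type': 2},
            {'key': 'kernel', 'key_type': 2},
        ],
        'value_metadata': {'value_type': 'jax.Array', 'write_shape': [2, 2]},
    },
}


def nested_write_shapes(entries: dict) -> dict:
  """Rebuilds a nested dict of write shapes from flattened entries.

  Assumes every key is a dict key; sequence keys are not handled here.
  """
  nested = {}
  for entry in entries.values():
    keys = [k['key'] for k in entry['key_metadata']]
    node = nested
    for key in keys[:-1]:
      # Create intermediate levels on demand.
      node = node.setdefault(key, {})
    node[keys[-1]] = tuple(entry['value_metadata']['write_shape'])
  return nested


write_shapes = nested_write_shapes(tree_metadata)
```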

While the exact structure of the metadata is an internal implementation detail and is subject to change, it can still be useful to manually inspect the tree structure. In most cases, however, it is still preferable to rely on public methods intended for obtaining metadata.

pprint.pp(ocp.metadata(directory / 'ckpt-0').metadata)
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/checkpoint-format/my-checkpoints/ckpt-0. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
{'opt_state': {'0': ArrayMetadata(shape=(2,),
                                  dtype=dtype('float32'),
                                  sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['x'], axis_types=(Auto,), partition_spec=(None,)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]),
                                  storage_metadata=StorageMetadata(chunk_shape=(2,),
                                                                   write_shape=(2,)))},
 'params': {'layer0': {'bias': ArrayMetadata(shape=(2,),
                                             dtype=dtype('float32'),
                                             sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['x'], axis_types=(Auto,), partition_spec=(None,)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]),
                                             storage_metadata=StorageMetadata(chunk_shape=(2,),
                                                                              write_shape=(2,))),
                       'kernel': ArrayMetadata(shape=(2, 2),
                                               dtype=dtype('float32'),
                                               sharding_metadata=NamedShardingMetadata(shape=[1], axis_names=['x'], axis_types=(Auto,), partition_spec=(None,)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0)]),
                                               storage_metadata=StorageMetadata(chunk_shape=(2,
                                                                                             2),
                                                                                write_shape=(2,
                                                                                             2)))}}}
print_directory(directory / 'ckpt-0' / 'pytree')
/tmp/checkpoint-format/my-checkpoints/ckpt-0/pytree/
..manifest.ocdbt
..d/
....262b45de16eb55af3ba251e3fbb15565
..ocdbt.process_0/
....manifest.ocdbt
....d/
......70de16254939129be13f412c7e6b5cba
......d22291039d0678732b387a0309817880
......811fe3b569393ba8e4a438c2ad012182
......9121a55ea1342af2e6dec0586c5b921b
..array_metadatas/
....process_0
.._sharding
.._METADATA

Aside from the _METADATA file, most other files are not human-readable.

The _sharding file stores information about the shardings used when saving jax.Arrays in the tree. Similarly, array_metadatas records array properties separately on each process, so that these properties may be later compared and validated.

Orbax uses the TensorStore library to save individual arrays. The actual array data is stored within the d/ subdirectory, and TensorStore metadata is recorded in the manifest.ocdbt file. These files are not human-readable and require TensorStore APIs to parse (see below).

Finally, you’ll notice the presence of the directory ocdbt.process_0/, which also has a manifest.ocdbt and its own d/ subdirectory. One such folder exists for every process on which the checkpoint was saved. This exists because each process first writes its own data independently to its corresponding subdirectory.

When all processes have finished, Orbax runs a finalization pass to cheaply merge the metadatas from all per-process subdirectories into a global view (note that this still references data in the original subdirectories). This allows for scalability in checkpoint saving as the number of concurrent processes increases.
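
The write-then-merge pattern can be pictured with a toy model. The sketch below is not the OCDBT format and does not use TensorStore: plain dicts stand in for the per-process manifests, the data file names and the second process directory are invented for illustration, and `merge_manifests` is a hypothetical function. It only shows why the merge is cheap: the global view records references into the per-process subdirectories and copies no array data.

```python
# Toy model only: plain dicts stand in for the per-process manifests.
# File names and `ocdbt.process_1` are made up for this illustration.
per_process_manifests = {
    'ocdbt.process_0': {'params.layer0.kernel/c/0/0': 'd/file_a'},
    'ocdbt.process_1': {'params.layer0.bias/c/0': 'd/file_b'},
}


def merge_manifests(manifests: dict) -> dict:
  """Builds a global index that points back into per-process directories."""
  global_index = {}
  for process_dir, manifest in manifests.items():
    for key, data_file in manifest.items():
      # Toy-model assumption: each key is written by exactly one process.
      if key in global_index:
        raise ValueError(f'Key written by more than one process: {key}')
      # Only the reference is recorded; no array data is copied.
      global_index[key] = f'{process_dir}/{data_file}'
  return global_index


global_index = merge_manifests(per_process_manifests)
```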

Working with TensorStore

Sometimes, it is helpful to work directly with the TensorStore API to debug individual parameters in a checkpoint.

import tensorstore as ts

pytree_path = directory / 'ckpt-0' / 'pytree'

We can verify which keys are present in the checkpoint, which matches information we gathered earlier from the Orbax metadata API.

ts.KvStore.open(
    {"driver": "ocdbt", "base": f"file://{pytree_path.as_posix()}"}
).result().list().result()
[b'opt_state.0/c/0',
 b'opt_state.0/zarr.json',
 b'params.layer0.bias/c/0',
 b'params.layer0.bias/zarr.json',
 b'params.layer0.kernel/c/0/0',
 b'params.layer0.kernel/zarr.json']
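
Judging from this listing, each parameter name joins the nested PyTree keys with '.', and each parameter holds a zarr.json entry plus chunk entries under c/. The following sketch groups the listed keys by parameter using only the standard library. `group_by_parameter` is a hypothetical helper, and the naming scheme is an observation from this output rather than a documented guarantee.

```python
from collections import defaultdict

# Keys copied from the listing above.
keys = [
    b'opt_state.0/c/0',
    b'opt_state.0/zarr.json',
    b'params.layer0.bias/c/0',
    b'params.layer0.bias/zarr.json',
    b'params.layer0.kernel/c/0/0',
    b'params.layer0.kernel/zarr.json',
]


def group_by_parameter(raw_keys) -> dict[str, list[str]]:
  """Groups kvstore keys by the parameter name before the first '/'."""
  grouped = defaultdict(list)
  for raw in raw_keys:
    name, _, entry = raw.decode('utf-8').partition('/')
    grouped[name].append(entry)
  return dict(grouped)


entries_by_parameter = group_by_parameter(keys)
```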

To read using TensorStore, we need to construct a TensorStore Spec. Orbax APIs can be used for this, but here we write the spec out by hand. The spec points to a base path, as well as a particular parameter name (params.layer0.kernel in this case). It also contains options related to the checkpoint format, such as the zarr3 and ocdbt drivers.

tspec = {
    'driver': 'zarr3',
    'kvstore': {
        'driver': 'ocdbt',
        'base': {'driver': 'file', 'path': pytree_path.as_posix()},
        'path': 'params.layer0.kernel',
    },
}
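
If several parameters need inspecting, the same spec can be produced by a small function. This is a sketch with hypothetical helper names (`make_tspec`, `parameter_name_from_keys`); it simply reproduces the dict above for a given parameter, and it assumes the checkpoint was saved with OCDBT and Zarr3 enabled, as the _METADATA output indicates for this example.

```python
import pathlib


def parameter_name_from_keys(keys) -> str:
  """Joins nested keys with '.', matching the names in the key listing."""
  return '.'.join(keys)


def make_tspec(pytree_path, parameter_name: str) -> dict:
  """Returns a spec dict shaped like the one above for `parameter_name`."""
  return {
      'driver': 'zarr3',
      'kvstore': {
          'driver': 'ocdbt',
          'base': {
              'driver': 'file',
              'path': pathlib.Path(pytree_path).as_posix(),
          },
          'path': parameter_name,
      },
  }


bias_tspec = make_tspec(
    '/tmp/checkpoint-format/my-checkpoints/ckpt-0/pytree',
    parameter_name_from_keys(('params', 'layer0', 'bias')),
)
```

The resulting dict is plain data, so it can be inspected or logged before it is handed to TensorStore.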

Finally, we can directly restore the array using TensorStore.

t = ts.open(ts.Spec(tspec), open=True).result()
result = t.read().result()
result
array([[0.20867518, 0.5380874 ],
       [0.3181005 , 0.42343944]], dtype=float32)

Other Checkpointables

Finally, let’s return to the other checkpointable in our example, called extra_properties. Since we explicitly required the use of JsonHandler to save this object, this piece of the checkpoint is easily human-readable.

print_directory(directory / 'ckpt-0' / 'extra_properties')
/tmp/checkpoint-format/my-checkpoints/ckpt-0/extra_properties/
..data.json
pprint.pp(
    json.loads(
        (directory / 'ckpt-0' / 'extra_properties' / 'data.json').read_text()
    )
)
{'foo': 'bar'}