Checkpoint Format Guide

Understanding how Orbax structures checkpoints on disk is important, particularly if you need to debug at the checkpoint level or work with specific pieces of a larger checkpoint.

First, some setup:

from etils import epath
import jax
import numpy as np
import orbax.checkpoint as ocp
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(
        'model',
    ),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
state = jax.tree_util.tree_map(create_sharded_array, state)
abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)
state['c'] = np.arange(4)
state['d'] = 5
state['e'] = 'foo'
state
{'a': Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],      dtype=int32),
 'b': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32),
 'c': array([0, 1, 2, 3]),
 'd': 5,
 'e': 'foo'}
def print_directory(directory: epath.PathLike, level: int = 0):
  """Prints a directory tree for debugging purposes."""
  directory = epath.Path(directory)
  assert directory.is_dir()
  level_str = '..' * level
  if level == 0:
    print(f'Printing directory tree: {directory}/')
  else:
    print(f'{level_str}{directory.name}/')

  level_str = '..' * (level + 1)
  for p in directory.iterdir():
    if p.is_dir():
      print_directory(p, level=level + 1)
    else:
      print(f'{level_str}{p.name}')

We will start by creating a checkpoint for step 0, consisting of two items: state and custom_data. We also attach global metadata to the CheckpointManager itself.

path = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint')
global_metadata = {'global_property': 'foo'}
with ocp.CheckpointManager(
    path, item_names=('state', 'custom_data'), metadata=global_metadata
) as mngr:
  mngr.save(
      0,
      args=ocp.args.Composite(
          state=ocp.args.PyTreeSave(state),
          custom_data=ocp.args.JsonSave({'lang': 'en', 'version': 1.2}),
      ),
  )
print_directory(path)
Printing directory tree: /tmp/checkpoint/
..0/
...._CHECKPOINT_METADATA
....custom_data/
......metadata
....state/
......manifest.ocdbt
......d/
........fb9e55817938aadc5b81918cec752943
......ocdbt.process_0/
........manifest.ocdbt
........d/
..........f169d7401933a7b4e1e4aead5a621ee5
..........4d2d5283362376a2b345d46d34059e04
..........a4ac62ada5767c139c414b0f8a78e116
..........ba649b3d43c72fdc4b88f12c47f683f2
......_strings.json
......array_metadatas/
........process_0
......_sharding
......_METADATA
..metadata/
...._ROOT_METADATA

Let’s understand each of these pieces separately.

Root Directory

The “root directory” is understood to be the directory provided when creating a CheckpointManager. It represents the parent directory where all “sequential” checkpoints will reside (see below). In the above example, this corresponds to /tmp/checkpoint/.

Within the root directory, aside from the sequential checkpoints, there may also be a metadata subdirectory (if metadata was provided when configuring the CheckpointManager).

Sequential Checkpoint

A “sequential checkpoint” is a checkpoint that represents a particular step in a longer sequence. In Orbax, this is typically a directory named with an integer step number (0/ in the above example), although options are available to customize this default format.

The sequential checkpoint has a top-level _CHECKPOINT_METADATA file that stores basic information about the checkpoint, such as its creation timestamp.
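
As a rough illustration of the default naming scheme, the sketch below finds the steps under a root directory by looking for subdirectories with purely numeric names. The helper name list_steps is hypothetical and not part of Orbax, and the sketch assumes the default step format has not been customized.

```python
import pathlib
import tempfile


def list_steps(root_dir):
  """Returns the sorted integer step numbers found directly under root_dir."""
  root = pathlib.Path(root_dir)
  # Only purely numeric directory names count as steps; other entries,
  # such as the metadata/ subdirectory, are ignored.
  return sorted(
      int(p.name) for p in root.iterdir() if p.is_dir() and p.name.isdecimal()
  )


# Build a throwaway layout that mimics a root directory with three steps.
with tempfile.TemporaryDirectory() as tmp:
  root = pathlib.Path(tmp)
  for name in ('0', '10', '2', 'metadata'):
    (root / name).mkdir()
  steps = list_steps(root)

print(steps)  # [0, 2, 10]
```

In practice, CheckpointManager already tracks the available steps; a manual scan like this is mainly useful when inspecting a directory without loading Orbax.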

Checkpoint Items

Within a sequential checkpoint directory, we have subdirectories corresponding to “items”. An “item” represents a logically distinct unit of a larger checkpoint, so these are naturally represented in separate subdirectories. In the above example, the items are state and custom_data.

This representation makes composition easier if, for instance, you want to combine the dataset from one checkpoint with the state from another. It also prevents collisions if you use the same CheckpointHandler to save two different items, such as state and embeddings.

Below this level, the format is no longer universally standard, because each CheckpointHandler customizes its own file format.
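
Because every item lives in its own subdirectory, the items in a step can be discovered by listing that step's subdirectories. The following is a minimal sketch on a mock layout; list_items is a hypothetical helper, not an Orbax API.

```python
import pathlib
import tempfile


def list_items(step_dir):
  """Returns the sorted names of item subdirectories inside one step directory."""
  step = pathlib.Path(step_dir)
  # Plain files such as _CHECKPOINT_METADATA are skipped; only directories are items.
  return sorted(p.name for p in step.iterdir() if p.is_dir())


# Mimic the layout of step 0 from the example above.
with tempfile.TemporaryDirectory() as tmp:
  step_dir = pathlib.Path(tmp) / '0'
  (step_dir / 'state').mkdir(parents=True)
  (step_dir / 'custom_data').mkdir()
  (step_dir / '_CHECKPOINT_METADATA').write_text('{}')
  items = list_items(step_dir)

print(items)  # ['custom_data', 'state']
```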

PyTree Checkpoints

Because the state item was saved with ocp.args.PyTreeSave (the same would apply if saved with ocp.args.StandardSave), it takes the following form:

print_directory(path / '0' / 'state')
Printing directory tree: /tmp/checkpoint/0/state/
..manifest.ocdbt
..d/
....fb9e55817938aadc5b81918cec752943
..ocdbt.process_0/
....manifest.ocdbt
....d/
......f169d7401933a7b4e1e4aead5a621ee5
......4d2d5283362376a2b345d46d34059e04
......a4ac62ada5767c139c414b0f8a78e116
......ba649b3d43c72fdc4b88f12c47f683f2
.._strings.json
..array_metadatas/
....process_0
.._sharding
.._METADATA

The _METADATA file provides a complete description of the PyTree structure, including custom and empty nodes.

The tree is represented as a flattened dictionary in which each key is a tuple whose successive elements denote successive levels of nesting. For example, for the dict {'a': {'b': [1, 2]}}, the metadata file would contain two entries, with keys ('a', 'b', '0') and ('a', 'b', '1').
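
The flattening described above can be sketched in plain Python. This is only an illustration of the key scheme, not the code Orbax uses; flatten_with_tuple_keys is a hypothetical name, and unlike the real _METADATA file this sketch does not record empty containers.

```python
def flatten_with_tuple_keys(tree, prefix=()):
  """Flattens nested dicts, lists, and tuples into a {tuple_key: leaf} mapping."""
  if isinstance(tree, dict):
    children = tree.items()
  elif isinstance(tree, (list, tuple)):
    children = enumerate(tree)
  else:
    # Anything that is not a container is treated as a leaf.
    return {prefix: tree}
  flat = {}
  for key, value in children:
    # Every key element is stored as a string, including sequence indices.
    flat.update(flatten_with_tuple_keys(value, prefix + (str(key),)))
  return flat


flat = flatten_with_tuple_keys({'a': {'b': [1, 2]}})
print(flat)  # {('a', 'b', '0'): 1, ('a', 'b', '1'): 2}
```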

Keys at each level of nesting also encode what type they are: i.e. whether they are a dict key or a sequential key.

Finally, metadata about the value type (e.g. jax.Array, np.ndarray) is stored so that the tree can later be reconstructed without the caller explicitly providing the object type.

import json

json.loads((path / '0' / 'state' / '_METADATA').read_text())
{'tree_metadata': {"('a',)": {'key_metadata': [{'key': 'a', 'key_type': 2}],
   'value_metadata': {'value_type': 'jax.Array',
    'skip_deserialize': False,
    'write_shape': [16]}},
  "('b',)": {'key_metadata': [{'key': 'b', 'key_type': 2}],
   'value_metadata': {'value_type': 'jax.Array',
    'skip_deserialize': False,
    'write_shape': [16]}},
  "('c',)": {'key_metadata': [{'key': 'c', 'key_type': 2}],
   'value_metadata': {'value_type': 'np.ndarray', 'skip_deserialize': False}},
  "('d',)": {'key_metadata': [{'key': 'd', 'key_type': 2}],
   'value_metadata': {'value_type': 'scalar', 'skip_deserialize': False}},
  "('e',)": {'key_metadata': [{'key': 'e', 'key_type': 2}],
   'value_metadata': {'value_type': 'string', 'skip_deserialize': False}}},
 'use_ocdbt': True,
 'use_zarr3': False,
 'store_array_data_equal_to_fill_value': True,
 'custom_metadata': None}
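
Since the keys in tree_metadata are stringified tuples, they can be turned back into real tuples with ast.literal_eval. The sketch below uses an abbreviated copy of the output above rather than reading the file; value_types_by_key is a hypothetical helper.

```python
import ast

# Abbreviated copy of the 'tree_metadata' mapping printed above.
tree_metadata = {
    "('a',)": {'value_metadata': {'value_type': 'jax.Array'}},
    "('c',)": {'value_metadata': {'value_type': 'np.ndarray'}},
    "('d',)": {'value_metadata': {'value_type': 'scalar'}},
    "('e',)": {'value_metadata': {'value_type': 'string'}},
}


def value_types_by_key(tree_metadata):
  """Maps each tuple key to the value type recorded for that leaf."""
  result = {}
  for key_str, entry in tree_metadata.items():
    key = ast.literal_eval(key_str)  # "('a',)" -> ('a',)
    result[key] = entry['value_metadata']['value_type']
  return result


types = value_types_by_key(tree_metadata)
print(types[('a',)])  # jax.Array
```

For real checkpoints, the Orbax metadata APIs are the more robust way to obtain this information, since the on-disk layout of _METADATA may change between versions.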

The _sharding file stores information about the shardings originally used when saving the jax.Arrays in the tree. It is not meant to be human-readable; to get information about shardings, use the metadata APIs.

Beyond these metadata files, which are directly managed by Orbax, we also have a manifest.ocdbt file managed by the TensorStore library. Actual array data is stored within the d/ subdirectory. Since these files are opaque to human readers, we will not go into detail on their structure.

Finally, you’ll notice the presence of the directory ocdbt.process_0/, which also has a manifest.ocdbt and its own d/ subdirectory. One such folder exists for every process on which the checkpoint was saved. These directories exist because each process first writes its own data independently to its corresponding subdirectory.

When all processes have finished, Orbax runs a finalization pass that cheaply merges the metadata from all per-process subdirectories into a global view. Note that this global view still references data in the original per-process subdirectories.
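
When debugging a multi-process save, it can help to check which per-process directories are present. The sketch below assumes the ocdbt.process_N naming shown in the listing above and runs on a mock layout; process_indices is a hypothetical helper.

```python
import pathlib
import re
import tempfile

# Matches directory names of the form ocdbt.process_<N>, as seen in the listing above.
_PROCESS_DIR = re.compile(r'ocdbt\.process_(\d+)')


def process_indices(item_dir):
  """Returns the sorted process indices that have a per-process OCDBT directory."""
  indices = []
  for p in pathlib.Path(item_dir).iterdir():
    match = _PROCESS_DIR.fullmatch(p.name)
    if p.is_dir() and match:
      indices.append(int(match.group(1)))
  return sorted(indices)


# Mimic an item directory written by two processes.
with tempfile.TemporaryDirectory() as tmp:
  item_dir = pathlib.Path(tmp)
  for name in ('d', 'array_metadatas', 'ocdbt.process_0', 'ocdbt.process_1'):
    (item_dir / name).mkdir()
  found = process_indices(item_dir)

print(found)  # [0, 1]
```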

Working with TensorStore

Sometimes, it is helpful to work directly with the TensorStore API to debug individual parameters in a checkpoint.

from etils import epath
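import jax
# Import the submodule explicitly so the attribute path used below resolves.
import jax.experimental.array_serialization.serialization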
import tensorstore as ts

ts_context = ts.Context(
    {
        # Provide cache pool for B-tree nodes to avoid repeated reads.
        # 100MB limit.
        'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    },
    parent=jax.experimental.array_serialization.serialization.TS_CONTEXT,
)

To read using TensorStore, we need to construct a TensorStore Spec. For this, we can use Orbax APIs. The spec points to a base path, as well as a particular parameter name (a in this case). It contains further options related to the checkpoint format.

ParamInfo = ocp.type_handlers.ParamInfo
state_dir = path / '0' / 'state'
param_name = 'a'
# use_zarr3 should match the 'use_zarr3' value recorded in _METADATA (False here).
info = ParamInfo(name=param_name, parent_dir=state_dir, is_ocdbt_checkpoint=True, use_zarr3=False)
tspec = ocp.type_handlers.get_json_tspec_read(info, use_ocdbt=True)
tspec
{'driver': 'zarr',
 'kvstore': {'driver': 'ocdbt',
  'base': {'driver': 'file',
   'path': '/tmp/checkpoint/0/state',
   'file_io_locking': {'mode': 'non_atomic'}},
  'manifest': {'driver': 'file', 'path': '/tmp/checkpoint/0/state'},
  'path': 'a',
  'cache_pool': 'cache_pool#ocdbt'},
 'recheck_cached_data': False,
 'recheck_cached_metadata': False,
 'fill_missing_data_reads': False}

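We can verify which keys are present in the checkpoint. They match the array and scalar leaves listed in the _METADATA file we inspected earlier. The string leaf e does not appear here; judging by the directory listing above, it is presumably stored in _strings.json instead.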

ts.KvStore.open({"driver": "ocdbt", "base": "file:///tmp/checkpoint/0/state/"}).result().list().result()
[b'a/.zarray',
 b'a/0',
 b'b/.zarray',
 b'b/0',
 b'c/.zarray',
 b'c/0',
 b'd/.zarray',
 b'd/0']
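
To make the key listing easier to scan for larger checkpoints, the keys can be grouped by parameter name. This sketch works on a hard-coded copy of the output above, so it does not need TensorStore; group_keys_by_param is a hypothetical helper.

```python
import collections

# Copy of the key listing printed above.
keys = [
    b'a/.zarray', b'a/0',
    b'b/.zarray', b'b/0',
    b'c/.zarray', b'c/0',
    b'd/.zarray', b'd/0',
]


def group_keys_by_param(keys):
  """Groups raw kvstore keys into {param_name: [entries]}."""
  grouped = collections.defaultdict(list)
  for raw in keys:
    # The text before the last '/' is the parameter name; the rest is either
    # the array metadata entry ('.zarray') or a chunk key.
    param, _, entry = raw.decode('utf-8').rpartition('/')
    grouped[param].append(entry)
  return dict(grouped)


grouped = group_keys_by_param(keys)
print(grouped['a'])  # ['.zarray', '0']
```

Each parameter here has one .zarray entry and a single chunk, which is consistent with the small arrays saved in this example.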

Finally, we can directly restore the array using TensorStore.

tspec = {'driver': 'zarr', 'kvstore': {'driver': 'ocdbt', 'base': 'file:///tmp/checkpoint/0/state/', 'path': 'a'}}
t = ts.open(ts.Spec(tspec), open=True, context=ts_context).result()
result = t.read().result()
result
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
      dtype=int32)
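
The hand-written spec above can be generalized to any parameter in the checkpoint. The following sketch only builds the JSON spec dictionary, assuming a local file path and the zarr-on-OCDBT layout used in this guide; build_read_spec is a hypothetical helper, not part of Orbax or TensorStore.

```python
def build_read_spec(state_dir, param_name):
  """Builds a TensorStore JSON spec for one zarr array stored in an OCDBT kvstore."""
  base = str(state_dir).rstrip('/')
  return {
      'driver': 'zarr',
      'kvstore': {
          'driver': 'ocdbt',
          # OCDBT base location, expressed as a file:// URL with a trailing slash.
          'base': f'file://{base}/',
          # Key prefix inside the OCDBT store, i.e. the parameter name.
          'path': param_name,
      },
  }


spec = build_read_spec('/tmp/checkpoint/0/state', 'b')
print(spec['kvstore']['base'])  # file:///tmp/checkpoint/0/state/
```

The resulting dictionary can be passed to ts.open(ts.Spec(spec), open=True, context=ts_context) in the same way as the tspec above.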