Debugging Guide

Setup

Imports

import jax
import numpy as np
from etils import epath
import orbax.checkpoint as ocp
import tensorstore as ts
import collections
import operator
import asyncio

Create Sample Checkpoint

state = {
    'a': {
        'x': np.arange(2 ** 24),
        'y': np.arange(1024),
    },
    'b': np.ones(8),
    'c': 42,
}

default_param_name = 'a.x'
default_path = epath.Path('/tmp/checkpoint')
if default_path.exists():
  default_path.rmtree()
with ocp.StandardCheckpointer() as ckptr:
  ckptr.save(default_path, state)

Checkpoint Size

Actual Size on Disk

Compute the actual size of the checkpoint on disk by summing the sizes of all files under the checkpoint directory.

import concurrent.futures

path = ""  # @param {type:"string"}
# Use the path entered above if one was given; otherwise use the sample checkpoint.
path = epath.Path(path) if path else default_path


async def disk_usage(path: epath.Path) -> int:
  """Returns the size of the checkpoint on disk, in bytes.

  Note: this uses recursion because Orbax checkpoint directories are never
  more than a few levels deep.

  Args:
    path: The path to the checkpoint.
  Returns:
    The size of the checkpoint on disk, in bytes.
  """

  async def helper(p):
    if p.is_dir():
      return await disk_usage(p)
    else:
      # Run the blocking stat call in a worker thread so files are stat'ed concurrently.
      stat = await asyncio.to_thread(p.stat)
      return stat.length

  futures = [helper(p) for p in path.iterdir()]
  return sum(await asyncio.gather(*futures))


def run_coroutine(coro):
  """Runs `coro` to completion, even if an event loop is already running."""
  try:
    asyncio.get_running_loop()
  except RuntimeError:
    # No running loop (e.g. a plain Python script): asyncio.run() is safe.
    return asyncio.run(coro)
  # A loop is already running (e.g. in Jupyter), where asyncio.run() raises
  # RuntimeError; run the coroutine on a fresh loop in a worker thread instead.
  with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
    return pool.submit(asyncio.run, coro).result()


print('{0:0.3f} GB'.format(float(run_coroutine(disk_usage(path))) / 1e9))
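If the checkpoint lives on a local filesystem, you can also avoid the event loop entirely. The sketch below uses only the standard library: it walks the directory with os.walk and sums the file sizes. The name local_disk_usage is a hypothetical helper, not part of Orbax or etils, and the sketch does not handle remote paths such as gs:// buckets.

```python
import os


def local_disk_usage(path: str) -> int:
  """Returns the total size in bytes of all regular files under `path`.

  Hypothetical helper for checkpoints on a local filesystem only.
  """
  total = 0
  for dirpath, _, filenames in os.walk(path):
    for name in filenames:
      full_path = os.path.join(dirpath, name)
      # Skip symlinks so linked files are not counted twice.
      if not os.path.islink(full_path):
        total += os.path.getsize(full_path)
  return total
```

For the sample checkpoint, this would be called as print('{0:0.3f} GB'.format(local_disk_usage(str(default_path)) / 1e9)).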

Implied Size from Checkpoint Metadata

Users sometimes find that the checkpoint size on disk is much larger or smaller than expected given the model itself. Computing the implied size of the checkpoint from its own metadata and comparing it with the actual on-disk size can help explain the discrepancy.

The actual size on disk is typically somewhat smaller than the implied size because the stored data is compressed.

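path = ""  # @param {type:"string"}
# Use the path entered above if one was given; otherwise use the sample checkpoint.
path = epath.Path(path) if path else default_path
metadata = ocp.StandardCheckpointer().metadata(path).item_metadata
size_counts = collections.defaultdict(int)

def get_arr_bytes(meta):
  """Returns the implied (uncompressed) size of one leaf and tallies its dtype."""
  dtype = meta.dtype
  shape = meta.shape
  size_counts[dtype] += 1
  # Number of elements times bytes per element; a scalar has shape () and counts as one element.
  return np.prod(shape) * np.dtype(dtype).itemsize

total_bytes = jax.tree.reduce(operator.add, jax.tree.map(get_arr_bytes, metadata))
print('{0:0.3f} GB'.format(float(total_bytes) / 1e9))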
print()
print('leaf dtype counts:')
for dtype, count in size_counts.items():
  print(f'{dtype}: {count}')
0.134 GB

leaf dtype counts:
int64: 3
float64: 1
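As a cross-check of the arithmetic, the implied size of each leaf is its number of elements multiplied by the bytes per element of its dtype. The sketch below reproduces the 0.134 GB figure by hand. The shapes and dtypes are taken from the sample state and the dtype counts printed above; implied_nbytes is a hypothetical helper name.

```python
import numpy as np


def implied_nbytes(shape, dtype) -> int:
  """Uncompressed size in bytes of an array with the given shape and dtype."""
  # np.prod of an empty shape is 1, so a scalar counts as a single element.
  return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize


# Shapes and dtypes of the sample checkpoint's leaves.
sample_leaves = {
    'a.x': ((2 ** 24,), 'int64'),
    'a.y': ((1024,), 'int64'),
    'b': ((8,), 'float64'),
    'c': ((), 'int64'),
}
total = sum(implied_nbytes(s, d) for s, d in sample_leaves.values())
print(total)  # 134225992
print('{0:0.3f} GB'.format(total / 1e9))  # 0.134 GB
```

Almost all of the total comes from a.x (2**24 elements × 8 bytes = 134217728 bytes), which is typical: a few large arrays usually dominate a checkpoint's size.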

Tree Metadata

Inspecting the tree structure of the checkpoint is crucial: it lets you verify that the checkpoint contains the parameters you expect, and lets you check the array metadata associated with each parameter.

This is useful when debugging errors where the loading code looks for a particular parameter that cannot be found. A few things could be going wrong:

  • The parameter is missing from the checkpoint. Ensure the checkpoint is what you think it is, and that it has the correct parameters.

  • If running model surgery, the transformations may be misconfigured. See below.
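To check the first possibility, it can help to compare the parameter names your loading code expects against the names actually found in the checkpoint. A minimal sketch in plain Python, assuming the expected parameters are available as a nested dict. flatten_names and diff_param_names are hypothetical helpers, and the expected tree below is made up for illustration.

```python
def flatten_names(tree, prefix=()):
  """Yields a dotted name for every leaf of a nested dict (hypothetical helper)."""
  for key, value in tree.items():
    path = prefix + (str(key),)
    if isinstance(value, dict):
      yield from flatten_names(value, path)
    else:
      yield '.'.join(path)


def diff_param_names(expected, found):
  """Returns (missing, unexpected) parameter names as sorted lists."""
  expected, found = set(expected), set(found)
  return sorted(expected - found), sorted(found - expected)


# Made-up model tree that expects 'a.z', which the sample checkpoint lacks.
expected = list(flatten_names({'a': {'x': 0, 'z': 0}, 'b': 0, 'c': 0}))
# Names found in the checkpoint, e.g. `metadata_contents` from the tree-metadata cell.
found = ['a.x', 'a.y', 'b', 'c']
missing, unexpected = diff_param_names(expected, found)
print('missing:', missing)        # missing: ['a.z']
print('unexpected:', unexpected)  # unexpected: ['a.y']
```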

path = ""  # @param {type:"string"}
# Use the path entered above if one was given; otherwise use the sample checkpoint.
path = epath.Path(path) if path else default_path
metadata = ocp.StandardCheckpointer().metadata(path).item_metadata
metadata_contents = ['.'.join(k) for k in ocp.tree.to_flat_dict(metadata)]
# Here are the parameters present in the checkpoint tree.
for p in metadata_contents:
  print(p)
a.x
a.y
b
c
# Note: instead of "file", use:
#   - "gfile" on Google-internal filesystems.
#   - "gs" on GCS (do not repeat the "gs://" prefix).
ts_contents = ts.KvStore.open({"driver": "ocdbt", "base": f"file://{path.as_posix()}"}).result().list().result()
ts_contents = [p.decode("utf-8") for p in ts_contents]
# With zarr v2, each array has a "<param_name>/.zarray" metadata key; strip the
# suffix to recover the parameter name.
ts_contents = [p.removesuffix('/.zarray') for p in ts_contents if p.endswith('/.zarray')]

# Assert that the parameters tracked by the metadata file are the same as those
# tracked by TensorStore. If there is a discrepancy, there may be a deeper
# underlying problem; the assertion message lists the names that differ.
assert sorted(metadata_contents) == sorted(ts_contents), sorted(set(metadata_contents) ^ set(ts_contents))
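The key filter in the cell above only recognizes zarr v2 metadata keys (".zarray"). The sketch below also accepts zarr v3 keys. It assumes that each array stores its metadata under "<param_name>/zarr.json" in zarr v3, and that all other keys (chunk data) can be ignored. param_names_from_kvstore_keys is a hypothetical helper, and the keys in the example are synthetic.

```python
def param_names_from_kvstore_keys(keys):
  """Maps TensorStore kvstore keys to sorted, de-duplicated parameter names.

  Assumes an array's metadata key is `<param>/.zarray` (zarr v2) or
  `<param>/zarr.json` (zarr v3); any other key is ignored.
  """
  names = set()
  for key in keys:
    if isinstance(key, bytes):
      key = key.decode('utf-8')
    for suffix in ('/.zarray', '/zarr.json'):
      if key.endswith(suffix):
        names.add(key[: -len(suffix)])
  return sorted(names)


# Synthetic keys mixing metadata and chunk entries in both formats.
keys = [b'a.x/.zarray', b'a.x/0', b'b/zarr.json', b'b/c/0', 'c/.zarray']
print(param_names_from_kvstore_keys(keys))  # ['a.x', 'b', 'c']
```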

Individual Parameters

path = ""  # @param {type:"string"}
# The `param_name` can be obtained by inspecting tree metadata (see above).
param_name = ""  # @param {type:"string"}
# Fall back to the sample checkpoint and parameter when no values are entered.
path = epath.Path(path) if path else default_path
param_name = param_name or default_param_name

Value Metadata

metadata = ocp.StandardCheckpointer().metadata(path).item_metadata
value_metadata = {'.'.join(k): v for k, v in ocp.tree.to_flat_dict(metadata).items()}[param_name]
print(f'shape: {value_metadata.shape}')
print(f'dtype: {value_metadata.dtype}')
shape: (16777216,)
dtype: int64

Array Value

It is often helpful to check the raw value of a particular parameter as saved in the checkpoint. This establishes whether the parameter was saved correctly, ruling out the possibility that saving went wrong for that parameter (or that the checkpoint has been corrupted). If the saved value is correct, you can confine debugging to the restoration path.

CAUTION: The read below loads the entire array into memory. For very large arrays, this could cause an out-of-memory (OOM) error. To load a smaller slice of the array, index into the TensorStore object (t) before reading, for example t[:10].read().result() for a 1-D array such as a.x, or t[:2, :4].read().result() for a 2-D array.

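ParamInfo = ocp.type_handlers.ParamInfo
# TensorStore context: allow up to 128 concurrent file operations and cap the
# OCDBT cache at ~100 MB.
ts_context = ts.Context({
    'file_io_concurrency': {'limit': 128},
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
})

# Describe the parameter to read; these flags match how the sample checkpoint
# was saved (OCDBT, zarr v2).
info = ParamInfo(name=param_name, parent_dir=path, is_ocdbt_checkpoint=True, use_zarr3=False)
# Build the TensorStore JSON spec that Orbax uses to read this parameter.
tspec = ocp.type_handlers.get_json_tspec_read(info, use_ocdbt=True)

t = ts.open(ts.Spec(tspec), open=True, context=ts_context).result()
# Reads the entire array into memory (see the caution above).
arr = t.read().result()
print(arr)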
[       0        1        2 ... 16777213 16777214 16777215]
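To check a saved 1-D parameter against known-good values without holding the whole array in memory, you can compare it chunk by chunk. A minimal sketch, assuming you can produce the reference values for any index range. first_mismatch is a hypothetical helper. The comparison is exact, so NaN values in floating-point arrays count as mismatches.

```python
import numpy as np


def first_mismatch(read_saved, read_expected, size, chunk_size=1 << 20):
  """Returns the index of the first element where saved != expected, else None.

  `read_saved(start, stop)` and `read_expected(start, stop)` must each return
  the values in [start, stop) as NumPy arrays. Only one chunk of each is held
  in memory at a time.
  """
  for start in range(0, size, chunk_size):
    stop = min(start + chunk_size, size)
    saved = np.asarray(read_saved(start, stop))
    want = np.asarray(read_expected(start, stop))
    if saved.shape != want.shape:
      return start  # Shape mismatch: report the start of this chunk.
    diff = np.nonzero(saved != want)[0]
    if diff.size:
      return start + int(diff[0])
  return None
```

For the sample checkpoint, where a.x was saved as np.arange(2 ** 24), the call would be first_mismatch(lambda a, b: t[a:b].read().result(), lambda a, b: np.arange(a, b), t.shape[0]), which should return None if the parameter was saved correctly.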