Customizing Checkpointing Logic#

This page is relevant if your model state contains custom leaves in a PyTree, or doesn’t use PyTree at all.

If your model uses PyTree but has custom leaves, read the TypeHandler section to see how to register the custom type with PyTreeCheckpointHandler.

If your model doesn’t use PyTree or if you want to implement different serialization/deserialization logic, skip to the CheckpointHandler section.

Setup#

If you’re running this guide in a notebook, make sure to run this cell first.

import asyncio
from concurrent import futures
from dataclasses import dataclass
import functools
import os
import time
from typing import Any, List, Optional, Sequence

from etils import epath
import numpy as np
import orbax.checkpoint as ocp

ParamInfo = ocp.pytree_checkpoint_handler.ParamInfo
Metadata = ocp.metadata.Metadata

TypeHandler#

PyTreeCheckpointHandler walks through the input PyTree and uses registered TypeHandlers to serialize/deserialize the leaves. If your custom model state is stored within the leaves of a PyTree, implement a TypeHandler and use it with PyTreeCheckpointHandler.

Standard TypeHandlers

Orbax includes pre-defined TypeHandlers for saving certain types:

  • ArrayHandler: jax.Array

  • NumpyHandler: np.ndarray

  • ScalarHandler: int, float

  • StringHandler: str

These default implementations all use TensorStore to serialize and deserialize data, except for StringHandler, which serializes to JSON.

Custom serialization / deserialization#

To implement a custom TypeHandler, we must define the async serialize and deserialize methods (the section “Async vs Non-Async” lists reasons why these methods should be asynchronous). The new TypeHandler is then registered so that the PyTreeCheckpointHandler knows to use this handler when there is a MyState leaf in the PyTree.

The inputs to the TypeHandler are batched to allow for performance optimizations in certain cases. PyTreeCheckpointHandler groups all leaves of the same type and dispatches them together, one batch per type.
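
The grouping step can be pictured with a minimal, Orbax-free sketch. The helper names `flatten` and `batch_by_type` are hypothetical, only dicts are treated as containers, and this is not how PyTreeCheckpointHandler is actually implemented; it only illustrates the idea of one batch per leaf type.

```python
from collections import defaultdict
from typing import Any, Dict, List, Tuple

KeyPath = Tuple[str, ...]


def flatten(tree: Any, prefix: KeyPath = ()) -> List[Tuple[KeyPath, Any]]:
  """Returns (key_path, leaf) pairs; only dicts are treated as containers."""
  if isinstance(tree, dict):
    pairs = []
    for key, subtree in tree.items():
      pairs.extend(flatten(subtree, prefix + (key,)))
    return pairs
  return [(prefix, tree)]


def batch_by_type(tree: Any) -> Dict[type, List[Tuple[KeyPath, Any]]]:
  """Groups leaves by their exact type so each type is handled in one call."""
  batches = defaultdict(list)
  for path, leaf in flatten(tree):
    batches[type(leaf)].append((path, leaf))
  return dict(batches)


tree = {'step': 3, 'name': 'run-1', 'opt': {'lr': 0.1, 'count': 7}}
batches = batch_by_type(tree)
for leaf_type, batch in batches.items():
  # A real handler would receive the whole batch in a single serialize() call.
  print(leaf_type.__name__, [path for path, _ in batch])
```

In this sketch the two `int` leaves end up in the same batch even though they live in different subtrees, which is what lets a handler amortize setup cost across all values of its type.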

The example below defines a TypeHandler for a custom dataclass that stores multiple numpy arrays.

@dataclass
class MyState:
  a: np.ndarray
  b: np.ndarray


# Make sure to only run this cell once, otherwise a new `MyState` dataclass will
# be created which could mess up Python issubclass/isinstance checks.

Here is a possible TypeHandler implementation for MyState:

class MyStateHandler(ocp.pytree_checkpoint_handler.TypeHandler):
  """Serializes MyState to the numpy npz format."""

  def __init__(self):
    self._executor = futures.ThreadPoolExecutor(max_workers=1)

  def typestr(self) -> str:
    return 'MyState'

  async def serialize(
      self,
      values: Sequence[MyState],
      infos: Sequence[ParamInfo],
      args: Optional[Sequence[ocp.SaveArgs]],
  ) -> List[futures.Future]:
    del args  # Unused in this example.
    # Use a distinct name so the `futures` module is not shadowed.
    write_futures = []
    for value, info in zip(values, infos):
      # Make sure the per-key directory is present, as OCDBT doesn't create one.
      info.path.mkdir(exist_ok=True)
      write_futures.append(
          self._executor.submit(
              functools.partial(_write_state, value, info.path)
          )
      )
    return write_futures

  async def deserialize(
      self,
      infos: Sequence[ParamInfo],
      args: Optional[Sequence[ocp.RestoreArgs]] = None,
  ) -> Sequence[MyState]:
    del args  # Unused in this example.
    # `_from_state` is a coroutine function, so create one coroutine per leaf
    # and let `asyncio.gather` await them together.
    return await asyncio.gather(*[_from_state(info.path) for info in infos])

  async def metadata(self, infos: Sequence[ParamInfo]) -> Sequence[Metadata]:
    # This method is explained in a separate section.
    return [Metadata(name=info.name, directory=info.path) for info in infos]


def _write_state(state: MyState, path: epath.Path) -> epath.Path:
  path = path / 'my_state.npz'
  np.savez(path, a=state.a, b=state.b)
  return path


async def _from_state(path: epath.Path) -> MyState:
  data = np.load(path / 'my_state.npz')
  return MyState(a=data['a'], b=data['b'])


ocp.type_handlers.register_type_handler(
    MyState, MyStateHandler(), override=True
)
assert ocp.type_handlers.has_type_handler(MyState)

Here is MyStateHandler in action:

my_tree = {
    'state': {'a': np.array([1, 2, 3]), 'b': np.array([4, 5, 6])},
    'my_state': MyState(a=np.array([10, 20, 30]), b=np.array([40, 50, 60])),
}


checkpointer = ocp.Checkpointer(
    ocp.PyTreeCheckpointHandler()
)
path = epath.Path('/tmp/my_checkpoints/')

# Clear older checkpoints from directory.
# Checkpointer.save will fail if path already exists, unless `force=True`
if path.exists():
  path.rmtree()
path.mkdir()

checkpointer.save(path / 'my_tree', my_tree)
!echo "Files in path:" $(ls /tmp/my_checkpoints)
!echo "Files in 'my_tree':" $(ls /tmp/my_checkpoints/my_tree)
!echo "Files in 'my_tree/my_state':" $(ls /tmp/my_checkpoints/my_tree/my_state)
Files in path: my_tree
Files in 'my_tree': _CHECKPOINT_METADATA _METADATA checkpoint d manifest.ocdbt my_state ocdbt.process_0
Files in 'my_tree/my_state': my_state.npz
checkpointer.restore(path / 'my_tree')
{'my_state': MyState(a=array([10, 20, 30]), b=array([40, 50, 60])),
 'state': {'a': array([1, 2, 3]), 'b': array([4, 5, 6])}}

Metadata#

The metadata() method is used for inspecting existing checkpoints and is generally implemented to be less costly than a full restore. Some example use cases are determining whether the restored values can fit in the available memory, getting the checkpointed PyTree structure to extract specific subtrees, or validating whether the shapes and dtypes of the values match with your model data.
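
As an illustration of the validation use case, here is a minimal sketch that compares saved shapes and dtypes against what a model expects. `LeafSpec` and `find_mismatches` are hypothetical names introduced for this example and are not part of Orbax; the flat `'state.a'`-style keys are an assumption made to keep the example short.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class LeafSpec:
  """Hypothetical stand-in for per-leaf metadata: shape and dtype name."""
  shape: Tuple[int, ...]
  dtype: str


def find_mismatches(
    saved: Dict[str, LeafSpec], expected: Dict[str, LeafSpec]
) -> List[str]:
  """Returns human-readable problems; an empty list means the trees agree."""
  problems = []
  for key in sorted(set(saved) | set(expected)):
    if key not in saved:
      problems.append(f'{key}: missing from checkpoint')
    elif key not in expected:
      problems.append(f'{key}: not expected by model')
    elif saved[key] != expected[key]:
      problems.append(f'{key}: saved {saved[key]} != expected {expected[key]}')
  return problems


saved = {
    'state.a': LeafSpec((3,), 'int64'),
    'state.b': LeafSpec((3,), 'int64'),
}
expected = {
    'state.a': LeafSpec((3,), 'int64'),
    'state.b': LeafSpec((4,), 'int64'),
    'state.c': LeafSpec((1,), 'float32'),
}
problems = find_mismatches(saved, expected)
for problem in problems:
  print(problem)
```

Because only shapes and dtypes are compared, a check like this can run before any array data is read from disk.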

In the previous example, MyStateHandler returned the default Metadata() object because the TypeHandler interface requires a metadata() implementation. However, we recommend completing this implementation, especially if the custom type targets general users.

# 'my_state' returns a default Metadata object.
checkpointer.metadata(path / 'my_tree')
{'my_state': Metadata(name='my_state', directory=PosixGPath('/tmp/my_checkpoints/my_tree/my_state')),
 'state': {'a': ArrayMetadata(name='state.a', directory=PosixGPath('/tmp/my_checkpoints/my_tree'), shape=(3,), sharding=None, dtype=dtype('int64')),
  'b': ArrayMetadata(name='state.b', directory=PosixGPath('/tmp/my_checkpoints/my_tree'), shape=(3,), sharding=None, dtype=dtype('int64'))}}

Example implementation of MyStateHandler.metadata:

# Define a metadata class.
@dataclass
class MyStateMetadata(Metadata):
  # The inherited `name` and `directory` fields come first in the generated
  # `__init__`. Because `name` is given a default here, every field after it
  # needs a default too; otherwise the dataclass fails with
  # "non-default argument 'directory' follows default argument".
  name: str = 'my_state'
  directory: Optional[epath.Path] = None
  a_shape: Optional[Sequence[int]] = None
  b_shape: Optional[Sequence[int]] = None


class MyStateHandlerWithMetdata(MyStateHandler):

  async def metadata(
      self, infos: Sequence[ParamInfo]
  ) -> Sequence[MyStateMetadata]:
    # `_read_metadata` is a coroutine function, so create one coroutine per
    # leaf and let `asyncio.gather` await them together.
    return await asyncio.gather(*[_read_metadata(info) for info in infos])


async def _read_metadata(info: ParamInfo) -> MyStateMetadata:
  # This function reads the entire state, but can be more optimally defined
  # by reading the header from the npz file. Another option is collectively
  # gathering all of the metadata info during serialization, and writing it to
  # a file. Since metadata is generally pretty small, it's better to write
  # to a single file rather than one for each value.
  result = await _from_state(info.path)
  return MyStateMetadata(
      a_shape=result.a.shape,
      b_shape=result.b.shape,
      directory=info.path,
  )


ocp.type_handlers.register_type_handler(
    MyState, MyStateHandlerWithMetdata(), override=True
)

Now check the metadata. The PyTree should now contain MyStateMetadata.

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.metadata(path / 'my_tree')

In this example, we didn’t need to re-save the checkpoint using the newly registered MyStateHandlerWithMetdata TypeHandler, because the class doesn’t write new files into the checkpoint.
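
By contrast, a handler that records metadata in its own file at save time does change the checkpoint layout, so existing checkpoints would need to be re-saved. The sketch below shows that sidecar-file idea using only the standard library. The file name `my_state_metadata.json`, the helper names, and the use of plain lists in place of numpy arrays are all assumptions made for illustration.

```python
import json
import pathlib
import tempfile
from typing import Dict, List

METADATA_FILE = 'my_state_metadata.json'  # Hypothetical file name.


def write_state(directory: pathlib.Path, a: List[float], b: List[int]) -> None:
  """Writes the payload plus a small sidecar file describing it."""
  (directory / 'my_state.json').write_text(json.dumps({'a': a, 'b': b}))
  sidecar = {'a_shape': [len(a)], 'b_shape': [len(b)]}
  (directory / METADATA_FILE).write_text(json.dumps(sidecar))


def read_metadata(directory: pathlib.Path) -> Dict[str, List[int]]:
  """Reads only the sidecar file; the payload file is never opened."""
  return json.loads((directory / METADATA_FILE).read_text())


with tempfile.TemporaryDirectory() as tmp:
  directory = pathlib.Path(tmp)
  write_state(directory, a=[1.0, 1.5], b=[3, 4, 5])
  metadata = read_metadata(directory)
  print(metadata)
```

The cost of reading metadata then depends on the size of the sidecar file rather than on the size of the saved arrays.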

CheckpointHandler#

If your state is not stored within a PyTree, or if you’d like to customize more aspects of checkpointing, implement CheckpointHandler. CheckpointHandlers operate on the entire object, so you have considerable flexibility in how the object is saved and restored.

As of orbax-checkpoint-0.5.0, the CheckpointHandler API has changed. This page shows a side-by-side comparison of the old and new APIs.

The legacy APIs are deprecated and will stop working after May 1st, 2024. Please ensure you are using the new style by then.

Example

Serializing the same dataclass used in the TypeHandler example:

@dataclass
class MyState:
  a: np.ndarray
  b: np.ndarray


state = MyState(a=np.array([1.0, 1.5]), b=np.array([3, 4, 5]))

Before#

import glob
import json


class LegacyMyStateCheckpointHandler(ocp.CheckpointHandler):

  def save(
      self,
      directory: epath.Path,
      item: MyState,
      # You can define any argument here:
      use_npz=True,
      **kwargs,
  ):
    if use_npz:
      np.savez(directory / 'my_state.npz', a=item.a, b=item.b)
    else:
      with open(os.path.join(directory, 'my_state.json'), 'w') as f:
        # Serialize the `item` argument, not the module-level `state` variable.
        f.write(json.dumps(dict(a=item.a.tolist(), b=item.b.tolist())))

  def restore(
      self,
      directory: epath.Path,
      item: Optional[Any] = None,
      # You can define any argument here as well.
      restore_as_dict=False,
      **kwargs,
  ) -> Any:
    state_file = glob.glob(os.fspath(directory / '*.*'))[0]
    # `glob` returns full paths, so compare only the file name.
    if os.path.basename(state_file) == 'my_state.npz':
      data = np.load(directory / 'my_state.npz')
    else:
      with open(state_file, 'r') as f:
        data = json.load(f)
        data['a'] = np.array(data['a'])
        data['b'] = np.array(data['b'])
    if restore_as_dict:
      return dict(a=data['a'], b=data['b'])
    return MyState(a=data['a'], b=data['b'])

  def metadata(self, directory: epath.Path) -> Optional[Any]:
    """Returns metadata about the saved item."""
    # In this example, the state is restored entirely, but this can be
    # optimized, for example by writing a `metadata` file in `self.save()`
    # and reading that file in this method.
    result = self.restore(directory)
    return MyStateMetadata(
        a_shape=result.a.shape,
        b_shape=result.b.shape,
        directory=directory / 'my_state',
    )

After#

import glob
import json


class MyStateCheckpointHandler(ocp.CheckpointHandler):

  def save(
      self,
      directory: epath.Path,
      args: 'MyStateSave',
  ):
    if args.use_npz:
      np.savez(directory / 'my_state.npz', a=args.item.a, b=args.item.b)
    else:
      with open(os.path.join(directory, 'my_state.json'), 'w') as f:
        f.write(
            json.dumps(dict(a=args.item.a.tolist(), b=args.item.b.tolist()))
        )

  def restore(
      self,
      directory: epath.Path,
      args: 'MyStateRestore',
  ) -> Any:
    state_file = glob.glob(os.fspath(directory / '*.*'))[0]
    # `glob` returns full paths, so compare only the file name.
    if os.path.basename(state_file) == 'my_state.npz':
      data = np.load(directory / 'my_state.npz')
    else:
      with open(state_file, 'r') as f:
        data = json.load(f)
        data['a'] = np.array(data['a'])
        data['b'] = np.array(data['b'])
    if args.restore_as_dict:
      return dict(a=data['a'], b=data['b'])
    return MyState(a=data['a'], b=data['b'])

  def metadata(self, directory: epath.Path) -> Optional[Any]:
    """Returns metadata about the saved item."""
    # In this example, the state is restored entirely, but this can be
    # optimized, for example by writing a `metadata` file in `self.save()`
    # and reading that file in this method.
    result = self.restore(directory, args=MyStateRestore())
    return MyStateMetadata(
        a_shape=result.a.shape,
        b_shape=result.b.shape,
        directory=directory / 'my_state',
    )


@ocp.args.register_with_handler(MyStateCheckpointHandler, for_save=True)
@dataclass
class MyStateSave(ocp.args.CheckpointArgs):
  item: MyState
  use_npz: bool = True


@ocp.args.register_with_handler(MyStateCheckpointHandler, for_restore=True)
@dataclass
class MyStateRestore(ocp.args.CheckpointArgs):
  restore_as_dict: bool = False

These classes can be passed to create a new Checkpointer, which can be used to save or restore a new checkpoint.

legacy_path2 = epath.Path('/tmp/legacy-checkpoint-handler-example/')
legacy_checkpointer = ocp.Checkpointer(LegacyMyStateCheckpointHandler())

if legacy_path2.exists():
  legacy_path2.rmtree()
legacy_path2.mkdir()

legacy_checkpointer.save(legacy_path2 / 'state', state, use_npz=False)
!echo "Files in legacy checkpoint path:" $(ls /tmp/legacy-checkpoint-handler-example/)
!echo "Files in legacy 'state' directory:" $(ls /tmp/legacy-checkpoint-handler-example/state)

print('restored state: ', legacy_checkpointer.restore(legacy_path2 / 'state'))
print('restored state as dict: ', legacy_checkpointer.restore(legacy_path2 / 'state', restore_as_dict=True))
print('metadata:', legacy_checkpointer.metadata(legacy_path2 / 'state'))
WARNING:absl:No registered CheckpointArgs found for handler type: <class '__main__.LegacyMyStateCheckpointHandler'>
Files in legacy checkpoint path: state
Files in legacy 'state' directory: _CHECKPOINT_METADATA my_state.json
restored state:  MyState(a=array([1. , 1.5]), b=array([3, 4, 5]))
restored state as dict:  {'a': array([1. , 1.5]), 'b': array([3, 4, 5])}
path2 = epath.Path('/tmp/checkpoint-handler-example/')
checkpointer = ocp.Checkpointer(MyStateCheckpointHandler())

if path2.exists():
  path2.rmtree()
path2.mkdir()

checkpointer.save(path2 / 'state', args=MyStateSave(item=state, use_npz=False))
!echo "Files in checkpoint path:" $(ls /tmp/checkpoint-handler-example/)
!echo "Files in 'state' directory:" $(ls /tmp/checkpoint-handler-example/state)

print('restored state: ', checkpointer.restore(path2 / 'state', args=MyStateRestore()))
print('restored state as dict: ', checkpointer.restore(path2 / 'state', args=MyStateRestore(restore_as_dict=True)))
print('metadata:',checkpointer.metadata(path2 / 'state'))
Files in checkpoint path: state
Files in 'state' directory: _CHECKPOINT_METADATA my_state.json
restored state:  MyState(a=array([1. , 1.5]), b=array([3, 4, 5]))
restored state as dict:  {'a': array([1. , 1.5]), 'b': array([3, 4, 5])}

Async vs Non-Async#

Asynchronous checkpointing allows training to proceed during the I/O, which prevents expensive computational resources from stalling during the CPU writes. When possible, we highly recommend implementing async handlers.

Async saving can be implemented by copying data to the corresponding worker CPU (if necessary), then parallelizing the writing tasks (e.g. by using the await keyword).

TypeHandler deserialization should be defined using async to allow multiple objects to be deserialized at a time.
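
A minimal, Orbax-free sketch of that pattern: blocking reads are scheduled on a thread pool first and awaited together afterwards. `blocking_read` and `read_all` are hypothetical names, and the `time.sleep` call stands in for real file I/O.

```python
import asyncio
from concurrent import futures
import time
from typing import List


def blocking_read(name: str) -> str:
  """Stands in for a blocking file read."""
  time.sleep(0.05)
  return f'contents of {name}'


async def read_all(
    names: List[str], executor: futures.Executor
) -> List[str]:
  loop = asyncio.get_running_loop()
  # Schedule every read before awaiting any of them, so the reads overlap.
  pending = [
      loop.run_in_executor(executor, blocking_read, name) for name in names
  ]
  # `gather` returns results in the same order as `pending`.
  return await asyncio.gather(*pending)


with futures.ThreadPoolExecutor(max_workers=4) as executor:
  results = asyncio.run(read_all(['a', 'b', 'c'], executor))
print(results)
```

Awaiting each `run_in_executor` call inside the loop instead would finish one read before starting the next, which removes the benefit of batching.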

AsyncCheckpointHandler#

The AsyncCheckpointHandler interface adds a new async_save abstract method, and should be used with AsyncCheckpointer to write checkpoints asynchronously.

Note that in the new style, AsyncCheckpointHandler’s save() and async_save() methods take args instead of the legacy item and keyword arguments. The args type must also be registered against the concrete AsyncCheckpointHandler class.

Example

class MyStateAsyncCheckpointHandler(ocp.AsyncCheckpointHandler, MyStateCheckpointHandler):
  def __init__(self):
    self._executor = futures.ThreadPoolExecutor(max_workers=1)

  def save(self, directory: epath.Path, args: MyStateSave):
    time.sleep(.5)  # Artificially inflate the time spent in this method.
    super().save(directory, args)

  async def async_save(self, directory: epath.Path, args: MyStateSave):
    return [self._executor.submit(functools.partial(
        self.save, directory, args))]

  def close(self):
    self._executor.shutdown()

# Register MyStateAsyncCheckpointHandler for MyStateSave and MyStateRestore.
# NOTE: This registration will overwrite the previous one with MyStateCheckpointHandler.
# It is just for illustrating this example and should be avoided in real world systems.
ocp.args.register_with_handler(MyStateAsyncCheckpointHandler, for_save=True)(MyStateSave)
ocp.args.register_with_handler(MyStateAsyncCheckpointHandler, for_restore=True)(MyStateRestore)

path3 = epath.Path('/tmp/checkpoint-handler-async/')
if path3.exists():
  path3.rmtree()
path3.mkdir()

async_checkpointer = ocp.AsyncCheckpointer(MyStateAsyncCheckpointHandler())
async_checkpointer.save(path3 / 'async-state', args=MyStateSave(item=state))
!echo "directory contents: "; ls /tmp/checkpoint-handler-async/
directory contents: 
async-state

After the write is complete, the temporary directory is renamed to its final name, async-state.

async_checkpointer.wait_until_finished()
async_checkpointer.close()

!ls /tmp/checkpoint-handler-async/
!ls /tmp/checkpoint-handler-async/async-state
async-state
_CHECKPOINT_METADATA  my_state.npz