Customizing Checkpointing Logic#
This page is relevant if your model state contains custom leaves in a PyTree, or doesn’t use PyTree at all.
If your model uses PyTree but has custom leaves, read the TypeHandler
section to see how to register the custom type with PyTreeCheckpointHandler
.
If your model doesn’t use PyTree or if you want to implement different serialization/deserialization logic, skip to the CheckpointHandler
section.
Setup#
If you’re running this guide in a notebook, make sure to run this cell first.
import asyncio
from concurrent import futures
from dataclasses import dataclass
import functools
import os
import time
from typing import Any, List, Optional, Sequence
from etils import epath
import numpy as np
import orbax.checkpoint as ocp
ParamInfo = ocp.pytree_checkpoint_handler.ParamInfo
Metadata = ocp.metadata.Metadata
TypeHandler#
PyTreeCheckpointHandler
walks through the input PyTree and uses registered TypeHandlers
to serialize/deserialize the leaves. If your custom model state is stored within the leaves of a PyTree, implement a TypeHandler
and use it with PyTreeCheckpointHandler
.
Standard TypeHandlers
Orbax includes pre-defined TypeHandlers
for saving certain types:
ArrayHandler
: jax.Array
NumpyHandler
: np.ndarray
ScalarHandler
: int, float
StringHandler
: str
These default implementations all use Tensorstore to serialize and deserialize data except for StringHandler
which serializes to JSON.
Custom serialization / deserialization#
To implement a custom TypeHandler
, we must define the async serialize and deserialize methods (the section “Async vs Non-Async” lists reasons why these methods should be asynchronous). The new TypeHandler
is then registered so that the PyTreeCheckpointHandler
knows to use this handler when there is a MyState
leaf in the PyTree.
The inputs to the TypeHandler
are batched to allow for performance optimizations in certain cases. PyTreeCheckpointHandler
groups all leaves of the same type and dispatches each group to its handler as a single batch (one batch per type).
The example below defines a TypeHandler
for a custom dataclass that stores multiple numpy arrays.
@dataclass
class MyState:
    """Custom PyTree leaf type that bundles two numpy arrays."""

    # `np.ndarray` is the array type; `np.array` is only the factory function
    # and is not meaningful as an annotation.
    a: np.ndarray
    b: np.ndarray


# Make sure to only run this cell once, otherwise a new `MyState` dataclass will
# be created which could mess up Python issubclass/isinstance checks.
Here is a possible TypeHandler
implementation for MyState
:
class MyStateHandler(ocp.pytree_checkpoint_handler.TypeHandler):
    """Serializes MyState to the numpy npz format."""

    def __init__(self):
        # Writes are submitted to this single-worker pool by `serialize`.
        self._executor = futures.ThreadPoolExecutor(max_workers=1)

    def typestr(self) -> str:
        """Returns the string identifier for this handler's type."""
        return 'MyState'

    async def serialize(
        self,
        values: Sequence[MyState],
        infos: Sequence[ParamInfo],
        args: Optional[Sequence[ocp.SaveArgs]],
    ) -> List[futures.Future]:
        """Schedules one npz write per value on the executor.

        Args:
          values: The MyState leaves to write.
          infos: Per-leaf info; `info.path` is the directory to write into.
          args: Ignored in this example.

        Returns:
          One future per value, resolved once `_write_state` has finished.
        """
        del args  # Unused in this example.
        # Named `write_futures` so the `futures` module is not shadowed.
        write_futures = []
        for value, info in zip(values, infos):
            # make sure the per-key directory is present as OCDBT doesn't create one
            info.path.mkdir(exist_ok=True)
            write_futures.append(
                self._executor.submit(_write_state, value, info.path)
            )
        return write_futures

    async def deserialize(
        self,
        infos: Sequence[ParamInfo],
        args: Optional[Sequence[ocp.RestoreArgs]] = None,
    ) -> Sequence[MyState]:
        """Reads back one MyState per entry of `infos`, in the same order.

        Args:
          infos: Per-leaf info; `info.path` is the directory to read from.
          args: Ignored in this example.

        Returns:
          The restored MyState objects, one per info.
        """
        del args  # Unused in this example.
        # `_from_state` is a coroutine function, so it is awaited directly.
        # Submitting it to the thread pool would only create the coroutine
        # object there without running it.
        return await asyncio.gather(
            *(_from_state(info.path) for info in infos)
        )

    async def metadata(self, infos: Sequence[ParamInfo]) -> Sequence[Metadata]:
        """Returns a default `Metadata` (name and directory) for each info."""
        # This method is explained in a separate section.
        return [Metadata(name=info.name, directory=info.path) for info in infos]
def _write_state(state: MyState, path: epath.Path) -> epath.Path:
    """Writes `state` to `my_state.npz` inside `path` and returns the file path.

    Args:
      state: Object whose `a` and `b` arrays are saved under those keys.
      path: Existing directory that receives the npz file.

    Returns:
      The path of the written file (a path object, not a `str`).
    """
    path = path / 'my_state.npz'
    np.savez(path, a=state.a, b=state.b)
    return path
async def _from_state(path: epath.Path) -> MyState:
    """Loads a MyState from `my_state.npz` inside the directory `path`.

    Kept as a coroutine function because its callers await it.
    """
    # `np.load` on an npz archive returns an object that keeps the file open;
    # the context manager closes it once both arrays have been read.
    with np.load(path / 'my_state.npz') as data:
        return MyState(a=data['a'], b=data['b'])
# Register a MyStateHandler instance for the MyState type. `override=True` is
# passed so that re-running this cell does not fail on an existing registration
# (presumably; confirm against the Orbax `register_type_handler` docs).
ocp.type_handlers.register_type_handler(
    MyState, MyStateHandler(), override=True
)
# Sanity check that the registration is visible.
assert ocp.type_handlers.has_type_handler(MyState)
Here is MyStateHandler
in action:
# A PyTree mixing ordinary numpy leaves ('state') with one custom MyState leaf
# ('my_state').
my_tree = {
    'state': {'a': np.array([1, 2, 3]), 'b': np.array([4, 5, 6])},
    'my_state': MyState(a=np.array([10, 20, 30]), b=np.array([40, 50, 60])),
}
checkpointer = ocp.Checkpointer(
    ocp.PyTreeCheckpointHandler()
)
path = epath.Path('/tmp/my_checkpoints/')
# Clear older checkpoints from directory.
# Checkpointer.save will fail if path already exists, unless `force=True`
if path.exists():
    path.rmtree()
path.mkdir()
checkpointer.save(path / 'my_tree', my_tree)
# IPython shell escapes: list what was written at each directory level.
!echo "Files in path:" $(ls /tmp/my_checkpoints)
!echo "Files in 'my_tree':" $(ls /tmp/my_checkpoints/my_tree)
!echo "Files in 'my_tree/my_state':" $(ls /tmp/my_checkpoints/my_tree/my_state)
/home/docs/.asdf/installs/python/3.9.18/lib/python3.9/pty.py:85: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
pid, fd = os.forkpty()
Files in path: my_tree
Files in 'my_tree': _CHECKPOINT_METADATA _METADATA checkpoint d manifest.ocdbt my_state ocdbt.process_0
Files in 'my_tree/my_state': my_state.npz
checkpointer.restore(path / 'my_tree')
{'my_state': MyState(a=array([10, 20, 30]), b=array([40, 50, 60])),
'state': {'a': array([1, 2, 3]), 'b': array([4, 5, 6])}}
Metadata#
The metadata()
method is used for inspecting existing checkpoints and is generally implemented to be less costly than a full restore. Some example use cases are determining whether the restored values can fit in the available memory, getting the checkpointed PyTree structure to extract specific subtrees, or validating whether the shapes and dtypes of the values match with your model data.
In the previous example, MyStateHandler
returned the default Metadata()
object since the TypeHandler
interface requires it. However, we recommend completing this implementation especially if the custom type targets general users.
# 'my_state' returns a default Metadata object.
# The result is a tree with one metadata entry per checkpointed leaf.
checkpointer.metadata(path / 'my_tree')
{'my_state': Metadata(name='my_state', directory=PosixGPath('/tmp/my_checkpoints/my_tree/my_state')),
'state': {'a': ArrayMetadata(name='state.a', directory=PosixGPath('/tmp/my_checkpoints/my_tree'), shape=(3,), sharding=None, dtype=dtype('int64')),
'b': ArrayMetadata(name='state.b', directory=PosixGPath('/tmp/my_checkpoints/my_tree'), shape=(3,), sharding=None, dtype=dtype('int64'))}}
Example implementation of MyStateHandler.metadata:
# Define a metadata class.
@dataclass
class MyStateMetadata(Metadata):
    """Metadata for a checkpointed MyState: the shapes of its two arrays.

    Instances are expected to be built with keyword arguments
    (`a_shape=...`, `b_shape=...`, `directory=...`).
    """

    # Dataclass inheritance keeps the base-class fields (`name`, `directory`)
    # first. Giving only `name` a default therefore fails with "non-default
    # argument 'directory' follows default argument", so every field gets a
    # default.
    name: str = 'my_state'
    directory: Optional[epath.Path] = None
    # Array shapes, e.g. `(3,)`; None when not provided.
    a_shape: Optional[Sequence[int]] = None
    b_shape: Optional[Sequence[int]] = None
class MyStateHandlerWithMetdata(MyStateHandler):
    """MyStateHandler whose `metadata()` reports the saved array shapes."""

    async def metadata(
        self, infos: Sequence[ParamInfo]
    ) -> Sequence[Metadata]:
        """Returns one metadata object per entry of `infos`, in input order.

        Args:
          infos: Per-leaf info passed through to `_read_metadata`.
        """
        # `_read_metadata` is a coroutine function, so it is awaited directly.
        # Submitting it to the thread pool would only create the coroutine
        # object there without running it.
        return await asyncio.gather(
            *(_read_metadata(info) for info in infos)
        )
async def _read_metadata(info: ParamInfo) -> MyStateMetadata:
    """Builds the MyStateMetadata for the state stored under `info.path`."""
    # The whole state is loaded here just to learn its shapes. A cheaper
    # variant would read only the npz header. Another would collect the
    # metadata of all values at serialization time and write it out; metadata
    # is usually small, so a single shared file beats one file per value.
    restored = await _from_state(info.path)
    shape_a = restored.a.shape
    shape_b = restored.b.shape
    return MyStateMetadata(
        a_shape=shape_a,
        b_shape=shape_b,
        directory=info.path,
    )
# Register the metadata-aware handler for MyState; `override=True` is passed
# because a handler for MyState was already registered earlier on this page.
ocp.type_handlers.register_type_handler(
    MyState, MyStateHandlerWithMetdata(), override=True
)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_896/1325065468.py in <module>
1 # Define a metadata class.
2 @dataclass
----> 3 class MyStateMetadata(Metadata):
4 a_shape: np.shape
5 b_shape: np.shape
~/.asdf/installs/python/3.9.18/lib/python3.9/dataclasses.py in dataclass(cls, init, repr, eq, order, unsafe_hash, frozen)
1019
1020 # We're called as @dataclass without parens.
-> 1021 return wrap(cls)
1022
1023
~/.asdf/installs/python/3.9.18/lib/python3.9/dataclasses.py in wrap(cls)
1011
1012 def wrap(cls):
-> 1013 return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)
1014
1015 # See if we're being called as @dataclass or @dataclass().
~/.asdf/installs/python/3.9.18/lib/python3.9/dataclasses.py in _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)
925 if f._field_type in (_FIELD, _FIELD_INITVAR)]
926 _set_new_attribute(cls, '__init__',
--> 927 _init_fn(flds,
928 frozen,
929 has_post_init,
~/.asdf/installs/python/3.9.18/lib/python3.9/dataclasses.py in _init_fn(fields, frozen, has_post_init, self_name, globals)
502 seen_default = True
503 elif seen_default:
--> 504 raise TypeError(f'non-default argument {f.name!r} '
505 'follows default argument')
506
TypeError: non-default argument 'directory' follows default argument
Now check the metadata, the PyTree should now contain MyStateMetadata
.
# Build a fresh checkpointer and query the metadata of the checkpoint that was
# saved earlier under `path / 'my_tree'`.
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.metadata(path / 'my_tree')
{'my_state': Metadata(name='my_state', directory=PosixGPath('/tmp/my_checkpoints/my_tree/my_state')),
'state': {'a': ArrayMetadata(name='state.a', directory=PosixGPath('/tmp/my_checkpoints/my_tree'), shape=(3,), sharding=None, dtype=dtype('int64')),
'b': ArrayMetadata(name='state.b', directory=PosixGPath('/tmp/my_checkpoints/my_tree'), shape=(3,), sharding=None, dtype=dtype('int64'))}}
In this example, we didn’t need to re-save the checkpoint using the newly registered MyStateHandlerWithMetdata
TypeHandler, because the class doesn’t write new files into the checkpoint.
CheckpointHandler#
If your state is not stored within a PyTree, or if you’d like to customize more aspects of checkpointing, implement CheckpointHandler
. CheckpointHandlers
operate on the entire object so you have a lot of flexibility on how to save and restore the object.
As of orbax-checkpoint-0.5.0
, the CheckpointHandler API has changed. This page shows a side-by-side comparison of the old and new APIs.
The legacy APIs are deprecated and will stop working after May 1st, 2024. Please ensure you are using the new style by then.
Example
Serializing the same dataclass used in the TypeHandler
example:
@dataclass
class MyState:
    """Example state object that bundles two numpy arrays."""

    # `np.ndarray` is the array type; `np.array` is only the factory function
    # and is not meaningful as an annotation.
    a: np.ndarray
    b: np.ndarray


# Sample instance used by the handler examples that follow.
state = MyState(a=np.array([1.0, 1.5]), b=np.array([3, 4, 5]))
Before#
import glob
import json
class LegacyMyStateCheckpointHandler(ocp.CheckpointHandler):
    """Legacy-style handler that stores a MyState as npz or JSON."""

    def save(
        self,
        directory: epath.Path,
        item: MyState,
        # You can define any argument here:
        use_npz=True,
        **kwargs,
    ):
        """Writes `item` into `directory`.

        Args:
          directory: Target directory for the state file.
          item: The MyState to save.
          use_npz: If true write `my_state.npz`, otherwise `my_state.json`.
          **kwargs: Accepted and ignored.
        """
        if use_npz:
            np.savez(directory / 'my_state.npz', a=item.a, b=item.b)
        else:
            with open(os.path.join(directory, 'my_state.json'), 'w') as f:
                # Serialize the `item` argument, not the notebook-global
                # `state`.
                f.write(json.dumps(dict(a=item.a.tolist(), b=item.b.tolist())))

    def restore(
        self,
        directory: epath.Path,
        item: Optional[Any] = None,
        # You can define any argument here as well.
        restore_as_dict=False,
        **kwargs,
    ) -> Any:
        """Reads the state back from `directory`.

        Args:
          directory: Directory previously passed to `save`.
          item: Accepted and ignored.
          restore_as_dict: If true return a dict with keys 'a' and 'b'
            instead of a MyState.
          **kwargs: Accepted and ignored.

        Returns:
          A MyState, or a dict when `restore_as_dict` is true.

        Raises:
          FileNotFoundError: If neither state file exists in `directory`.
        """
        # Test for the npz file by its path. Comparing the full path returned
        # by a glob against the bare name 'my_state.npz' never matches, which
        # would send npz checkpoints down the JSON branch.
        npz_path = directory / 'my_state.npz'
        if npz_path.exists():
            # Close the archive once both arrays have been read.
            with np.load(npz_path) as npz:
                a, b = npz['a'], npz['b']
        else:
            with open(os.path.join(directory, 'my_state.json'), 'r') as f:
                data = json.load(f)
            a, b = np.array(data['a']), np.array(data['b'])
        if restore_as_dict:
            return dict(a=a, b=b)
        return MyState(a=a, b=b)

    def metadata(self, directory: epath.Path) -> Optional[Any]:
        """Returns metadata about the saved item."""
        # In this example, the State is restored entirely, but this can be
        # optimized. For example, by writing a `metadata` file in `self.save()`,
        # and reading the file in this method.
        result = self.restore(directory)
        return MyStateMetadata(
            a_shape=result.a.shape,
            b_shape=result.b.shape,
            directory=directory / 'my_state',
        )
After#
import glob
import json
class MyStateCheckpointHandler(ocp.CheckpointHandler):
    """New-style handler that stores a MyState as npz or JSON.

    Save and restore options are carried by the `MyStateSave` and
    `MyStateRestore` argument objects.
    """

    def save(
        self,
        directory: epath.Path,
        args: 'MyStateSave',
    ):
        """Writes `args.item` into `directory`.

        Writes `my_state.npz` when `args.use_npz` is true, otherwise
        `my_state.json`.
        """
        if args.use_npz:
            np.savez(directory / 'my_state.npz', a=args.item.a, b=args.item.b)
        else:
            with open(os.path.join(directory, 'my_state.json'), 'w') as f:
                f.write(
                    json.dumps(dict(a=args.item.a.tolist(), b=args.item.b.tolist()))
                )

    def restore(
        self,
        directory: epath.Path,
        args: 'MyStateRestore',
    ) -> Any:
        """Reads the state back from `directory`.

        Returns:
          A MyState, or a dict with keys 'a' and 'b' when
          `args.restore_as_dict` is true.

        Raises:
          FileNotFoundError: If neither state file exists in `directory`.
        """
        # Test for the npz file by its path. Comparing the full path returned
        # by a glob against the bare name 'my_state.npz' never matches, which
        # would send npz checkpoints down the JSON branch.
        npz_path = directory / 'my_state.npz'
        if npz_path.exists():
            # Close the archive once both arrays have been read.
            with np.load(npz_path) as npz:
                a, b = npz['a'], npz['b']
        else:
            with open(os.path.join(directory, 'my_state.json'), 'r') as f:
                data = json.load(f)
            a, b = np.array(data['a']), np.array(data['b'])
        if args.restore_as_dict:
            return dict(a=a, b=b)
        return MyState(a=a, b=b)

    def metadata(self, directory: epath.Path) -> Optional[Any]:
        """Returns metadata about the saved item."""
        # In this example, the State is restored entirely, but this can be
        # optimized. For example, by writing a `metadata` file in `self.save()`,
        # and reading the file in this method.
        result = self.restore(directory, args=MyStateRestore())
        return MyStateMetadata(
            a_shape=result.a.shape,
            b_shape=result.b.shape,
            directory=directory / 'my_state',
        )
@ocp.args.register_with_handler(MyStateCheckpointHandler, for_save=True)
@dataclass
class MyStateSave(ocp.args.CheckpointArgs):
    """Save arguments registered with MyStateCheckpointHandler."""

    # The state object to write.
    item: MyState
    # Selects the npz format when true; the handler writes JSON otherwise.
    use_npz: bool = True
@ocp.args.register_with_handler(MyStateCheckpointHandler, for_restore=True)
@dataclass
class MyStateRestore(ocp.args.CheckpointArgs):
    """Restore arguments registered with MyStateCheckpointHandler."""

    # When true the handler returns a dict instead of a MyState.
    restore_as_dict: bool = False
These classes can be passed to create a new Checkpointer
, which can be used to save or restore a new checkpoint.
# Save and restore `state` with the legacy-style handler, starting from an
# empty directory.
legacy_path2 = epath.Path('/tmp/legacy-checkpoint-handler-example/')
legacy_checkpointer = ocp.Checkpointer(LegacyMyStateCheckpointHandler())
if legacy_path2.exists():
    legacy_path2.rmtree()
legacy_path2.mkdir()
# Extra keyword arguments such as `use_npz` reach the handler's `save`.
legacy_checkpointer.save(legacy_path2 / 'state', state, use_npz=False)
# IPython shell escapes: list what was written.
!echo "Files in legacy checkpoint path:" $(ls /tmp/legacy-checkpoint-handler-example/)
!echo "Files in legacy 'state' directory:" $(ls /tmp/legacy-checkpoint-handler-example/state)
print('restored state: ', legacy_checkpointer.restore(legacy_path2 / 'state'))
print('restored state as dict: ', legacy_checkpointer.restore(legacy_path2 / 'state', restore_as_dict=True))
# NOTE: the handler's `metadata` needs `MyStateMetadata` to be defined.
print('metadata:', legacy_checkpointer.metadata(legacy_path2 / 'state'))
WARNING:absl:No registered CheckpointArgs found for handler type: <class '__main__.LegacyMyStateCheckpointHandler'>
Files in legacy checkpoint path: state
Files in legacy 'state' directory: _CHECKPOINT_METADATA my_state.json
restored state: MyState(a=array([1. , 1.5]), b=array([3, 4, 5]))
restored state as dict: {'a': array([1. , 1.5]), 'b': array([3, 4, 5])}
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
/tmp/ipykernel_896/3432597007.py in <module>
12 print('restored state: ', legacy_checkpointer.restore(legacy_path2 / 'state'))
13 print('restored state as dict: ', legacy_checkpointer.restore(legacy_path2 / 'state', restore_as_dict=True))
---> 14 print('metadata:', legacy_checkpointer.metadata(legacy_path2 / 'state'))
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpointer.py in metadata(self, directory)
175 """See superclass documentation."""
176 directory = epath.Path(directory)
--> 177 return self._handler.metadata(directory)
178
179 def close(self):
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/composite_checkpoint_handler.py in metadata(self, directory)
94
95 def metadata(self, directory: epath.Path) -> Optional[Any]:
---> 96 return self._handler.metadata(directory)
97
98 def structure(self, directory: epath.Path) -> Optional[Any]:
/tmp/ipykernel_896/560276776.py in metadata(self, directory)
45 # and reading the file in this method.
46 result = self.restore(directory)
---> 47 return MyStateMetadata(
48 a_shape=result.a.shape,
49 b_shape=result.b.shape,
NameError: name 'MyStateMetadata' is not defined
# Save and restore `state` with the new-style handler, starting from an empty
# directory. Options travel in the MyStateSave / MyStateRestore objects.
path2 = epath.Path('/tmp/checkpoint-handler-example/')
checkpointer = ocp.Checkpointer(MyStateCheckpointHandler())
if path2.exists():
    path2.rmtree()
path2.mkdir()
checkpointer.save(path2 / 'state', args=MyStateSave(item=state, use_npz=False))
# IPython shell escapes: list what was written.
!echo "Files in checkpoint path:" $(ls /tmp/checkpoint-handler-example/)
!echo "Files in 'state' directory:" $(ls /tmp/checkpoint-handler-example/state)
print('restored state: ', checkpointer.restore(path2 / 'state', args=MyStateRestore()))
print('restored state as dict: ', checkpointer.restore(path2 / 'state', args=MyStateRestore(restore_as_dict=True)))
# NOTE: the handler's `metadata` needs `MyStateMetadata` to be defined.
print('metadata:',checkpointer.metadata(path2 / 'state'))
Files in checkpoint path: state
Files in 'state' directory: _CHECKPOINT_METADATA my_state.json
restored state: MyState(a=array([1. , 1.5]), b=array([3, 4, 5]))
restored state as dict: {'a': array([1. , 1.5]), 'b': array([3, 4, 5])}
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
/tmp/ipykernel_896/2863598327.py in <module>
12 print('restored state: ', checkpointer.restore(path2 / 'state', args=MyStateRestore()))
13 print('restored state as dict: ', checkpointer.restore(path2 / 'state', args=MyStateRestore(restore_as_dict=True)))
---> 14 print('metadata:',checkpointer.metadata(path2 / 'state'))
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpointer.py in metadata(self, directory)
175 """See superclass documentation."""
176 directory = epath.Path(directory)
--> 177 return self._handler.metadata(directory)
178
179 def close(self):
/tmp/ipykernel_896/3120497307.py in metadata(self, directory)
41 # and reading the file in this method.
42 result = self.restore(directory, args=MyStateRestore())
---> 43 return MyStateMetadata(
44 a_shape=result.a.shape,
45 b_shape=result.b.shape,
NameError: name 'MyStateMetadata' is not defined
Async vs Non-Async#
Asynchronous checkpointing allows training to proceed during the I/O, which prevents expensive computational resources from stalling during the CPU writes. When possible, we highly recommend implementing async handlers.
Async saving can be implemented by copying data to the corresponding worker CPU (if necessary), then parallelizing the writing tasks (e.g. by using the await keyword).
TypeHandler
deserialization should be defined using async to allow multiple objects to be deserialized at a time.
AsyncCheckpointHandler#
The AsyncCheckpointHandler
interface adds a new async_save
abstract method, and should be used with AsyncCheckpointer
to write checkpoints asynchronously.
Note that in the new style, AsyncCheckpointHandler
’s save()
and async_save()
methods work on args
instead of the legacy item
etc arguments. Also, the args
type needs to be registered against the AsyncCheckpointHandler
concrete class.
Example
class MyStateAsyncCheckpointHandler(ocp.AsyncCheckpointHandler, MyStateCheckpointHandler):
    """Async variant of MyStateCheckpointHandler.

    `async_save` hands the blocking `save` to a single-worker thread pool and
    returns the resulting future.
    """

    def __init__(self):
        self._executor = futures.ThreadPoolExecutor(max_workers=1)

    def save(self, directory: epath.Path, args: MyStateSave):
        """Blocking save: delays for half a second, then delegates to the parent."""
        # Artificially inflate the time spent in this method.
        time.sleep(0.5)
        super().save(directory, args)

    async def async_save(self, directory: epath.Path, args: MyStateSave):
        """Schedules `save` on the executor and returns its future in a list."""
        pending_write = self._executor.submit(self.save, directory, args)
        return [pending_write]

    def close(self):
        """Shuts down the executor."""
        self._executor.shutdown()
# Register MyStateAsyncCheckpointHandler for MyStateSave and MyStateRestore.
# NOTE: This registration will overwrite the previous one with MyStateCheckpointHandler.
# It is just for illustrating this example and should be avoided in real world systems.
ocp.args.register_with_handler(MyStateAsyncCheckpointHandler, for_save=True)(MyStateSave)
ocp.args.register_with_handler(MyStateAsyncCheckpointHandler, for_restore=True)(MyStateRestore)
# Start from an empty target directory.
path3 = epath.Path('/tmp/checkpoint-handler-async/')
if path3.exists():
    path3.rmtree()
path3.mkdir()
async_checkpointer = ocp.AsyncCheckpointer(MyStateAsyncCheckpointHandler())
# `use_npz` is left at its default (True) here.
async_checkpointer.save(path3 / 'async-state', args=MyStateSave(item=state))
# IPython shell escape: list the directory right after `save` returns.
!echo "directory contents: "; ls /tmp/checkpoint-handler-async/
directory contents:
async-state
After the write is complete, the tmp folder is renamed to just async-state
.
# Wait for the pending save, then close the checkpointer.
async_checkpointer.wait_until_finished()
async_checkpointer.close()
# IPython shell escapes: list the final directory contents.
!ls /tmp/checkpoint-handler-async/
!ls /tmp/checkpoint-handler-async/async-state
async-state
_CHECKPOINT_METADATA my_state.npz