API Overview#
import orbax.checkpoint as ocp
from orbax.checkpoint.checkpoint_managers import preservation_policy as preservation_policy_lib
from orbax.checkpoint.checkpoint_managers import save_decision_policy as save_decision_policy_lib
import jax
import numpy as np
from jax import numpy as jnp
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
state = {'layer0': {'bias': np.ones((4,)), 'weight': jnp.arange(16)}}
abstract_state = jax.tree.map(ocp.tree.to_shape_dtype_struct, state)
metadata = {'version': 1.0}
extra_metadata = {'version': 1.0, 'step': 0}
dataset = {'my_data': 2}
CheckpointManager Layer#
The highest-level API layer provided by Orbax is the CheckpointManager. This is the API of choice for users dealing with a series of checkpoints, denoted as steps, in the context of a training run.
CheckpointManagerOptions allows customizing the behavior of the CheckpointManager along various dimensions. A partial list of important customization options is given below. See the API reference for a complete list.
save_decision_policy: A policy that determines when to save checkpoints.
preservation_policy: A policy that determines which checkpoints to keep.
step_format_fixed_length: Formats step numbers to a fixed length of n digits (zero-padded). This can make visually examining the checkpoints in sorted order easier.
cleanup_tmp_directories: Automatically cleans up existing temporary/incomplete directories when the CheckpointManager is created.
read_only: If True, checkpoint saves and deletes are skipped. Restore works as usual.
enable_async_checkpointing: True by default. Be wary of turning it off, as save performance may be significantly impacted.
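To make the interaction between the first three options concrete, here is a minimal pure-Python sketch. It does not call Orbax: should_save, keep_latest_n and format_step are hypothetical helpers that only mimic a fixed-interval save policy, a keep-latest-N preservation policy and fixed-length step formatting, under the assumption that these options behave as their names suggest.

```python
def should_save(step: int, interval: int) -> bool:
    """Hypothetical stand-in for a fixed-interval save policy."""
    return step % interval == 0


def keep_latest_n(steps: list, n: int) -> list:
    """Hypothetical stand-in for a keep-latest-N preservation policy."""
    return sorted(steps)[-n:] if n > 0 else []


def format_step(step: int, fixed_length: int) -> str:
    """Zero-pads a step number so that names sort lexicographically."""
    return str(step).zfill(fixed_length)


saved_steps = []
for step in range(5):
    if should_save(step, interval=2):
        saved_steps.append(step)
        # Apply the preservation rule after every save.
        saved_steps = keep_latest_n(saved_steps, n=2)

print(saved_steps)
print([format_step(s, fixed_length=4) for s in saved_steps])
```

With an interval of 2 and N = 2 over steps 0 to 4, the sketch retains steps 2 and 4, which matches the all_steps() output shown in the Basic Usage example.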
If dealing with a single checkpointable object, like a train state, CheckpointManager can be created as follows:
Note that CheckpointManager always saves asynchronously, unless you set enable_async_checkpointing=False in CheckpointManagerOptions. Make sure to use wait_until_finished() if you need to block until a save is complete.
Basic Usage#
import jax
directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-single/')
options = ocp.CheckpointManagerOptions(
    save_decision_policy=save_decision_policy_lib.FixedIntervalPolicy(2),
    preservation_policy=preservation_policy_lib.LatestN(2),
    # other options
)
mngr = ocp.CheckpointManager(
    directory,
    options=options,
)
num_steps = 5
def train_step(state):
  return jax.tree_util.tree_map(lambda x: x + 1, state)
for step in range(num_steps):
  state = train_step(state)
  mngr.save(step, args=ocp.args.StandardSave(state))
mngr.wait_until_finished()
mngr.latest_step()
4
mngr.all_steps()
[2, 4]
mngr.restore(mngr.latest_step())
{'layer0': {'bias': array([6., 6., 6., 6.]),
'weight': Array([ 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], dtype=int32)}}
# Restore with additional arguments, like dtype or sharding.
def set_dtype(abstract_arr):
  return abstract_arr.update(dtype=np.float32)
mngr.restore(mngr.latest_step(), args=ocp.args.StandardRestore(
    jax.tree.map(set_dtype, abstract_state)))
{'layer0': {'bias': array([6., 6., 6., 6.], dtype=float32),
'weight': Array([ 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20.], dtype=float32)}}
Managing Multiple Items#
Often, we need to deal with multiple items, representing the training state, dataset, and some custom metadata, for instance.
directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-multiple/')
options = ocp.CheckpointManagerOptions(
    save_decision_policy=save_decision_policy_lib.FixedIntervalPolicy(2),
    preservation_policy=preservation_policy_lib.LatestN(2),
    # other options
)
mngr = ocp.CheckpointManager(
    directory,
    options=options,
)
num_steps = 5
def train_step(step, _state, _extra_metadata):
  return jax.tree_util.tree_map(lambda x: x + 1, _state), {**_extra_metadata, **{'step': step}}
for step in range(num_steps):
  state, extra_metadata = train_step(step, state, extra_metadata)
  mngr.save(
      step,
      args=ocp.args.Composite(
          state=ocp.args.StandardSave(state),
          extra_metadata=ocp.args.JsonSave(extra_metadata),
      )
  )
mngr.wait_until_finished()
# Restore exactly as saved
result = mngr.restore(mngr.latest_step())
result
Composite({'extra_metadata': {'version': 1.0, 'step': 4}, 'state': {'layer0': {'bias': array([11., 11., 11., 11.]), 'weight': Array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], dtype=int32)}}})
result.state
{'layer0': {'bias': array([11., 11., 11., 11.]),
'weight': Array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], dtype=int32)}}
result.extra_metadata
{'version': 1.0, 'step': 4}
# Skip `state` when restoring.
# Note that it is possible to provide `extra_metadata=None` because we already
# saved using `JsonSave`. This is internally cached, so we know it uses JSON
# logic to save and restore. If you had called `restore` without first calling
# `save`, however, it would have been necessary to provide
# `ocp.args.JsonRestore`.
mngr.restore(mngr.latest_step(), args=ocp.args.Composite(extra_metadata=None))
Composite({'extra_metadata': {'version': 1.0, 'step': 4}})
# Restoration of the state can be customized by specifying an abstract state.
# For example, we can change the dtypes to automatically cast the restored
# arrays.
def set_dtype(abstract_arr):
  return abstract_arr.update(dtype=np.float32)
mngr.restore(
    mngr.latest_step(),
    args=ocp.args.Composite(
        state=ocp.args.StandardRestore(jax.tree.map(set_dtype, abstract_state)),
        extra_metadata=None
    )
)
Composite({'extra_metadata': {'version': 1.0, 'step': 4}, 'state': {'layer0': {'bias': array([11., 11., 11., 11.], dtype=float32), 'weight': Array([10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22.,
23., 24., 25.], dtype=float32)}}})
There are some scenarios in which the mapping between items and their respective CheckpointHandlers needs to be provided when creating a CheckpointManager instance.
The CheckpointManager constructor argument item_handlers resolves these scenarios. Please see Using the Refactored CheckpointManager API for details.
Checkpointer Layer#
Conceptually, the Checkpointer exists to work with a single checkpoint that exists at a single path. It is no frills (relative to CheckpointManager) but guarantees atomicity and allows for asynchronous saving via AsyncCheckpointer.
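One common way to obtain atomicity for a checkpoint directory is to write into a temporary directory and rename it into place once the write has completed. The sketch below illustrates that general technique with the standard library only. It is an assumption about the idea, not a description of Orbax internals, which may differ across storage backends; atomic_save_json and the ".tmp-in-progress" suffix are hypothetical.

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_save_json(final_dir: Path, payload: dict) -> None:
    """Writes `payload` into a temporary directory, then renames it into place."""
    final_dir = Path(final_dir)
    # Hypothetical naming scheme for the in-progress directory.
    tmp_dir = final_dir.with_name(final_dir.name + '.tmp-in-progress')
    tmp_dir.mkdir(parents=True)
    with open(tmp_dir / 'data.json', 'w') as f:
        json.dump(payload, f)
    # A rename within one filesystem is atomic on POSIX systems, so readers
    # see either no checkpoint at all or a complete one.
    os.rename(tmp_dir, final_dir)


root = Path(tempfile.mkdtemp())
atomic_save_json(root / 'ckpt-0', {'a': 'b'})
print(sorted(p.name for p in root.iterdir()))
```

If the process dies before the rename, only the temporary directory is left behind, which is the kind of leftover that an option such as cleanup_tmp_directories is meant to remove.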
Saving and Restoring a PyTree#
Typically, you may wish to save and restore a PyTree of arrays to a given path.
This is easily accomplished with StandardCheckpointer.
with ocp.StandardCheckpointer() as ckptr:
  ckptr.save(path / 'standard-ckpt-1', state)
  result = ckptr.restore(path / 'standard-ckpt-1', abstract_state)
print(result)
{'layer0': {'bias': array([11., 11., 11., 11.]), 'weight': Array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], dtype=int32)}}
Note that StandardCheckpointer always saves asynchronously! In order to block until a save completes, use ckptr.wait_until_finished().
Equivalently, this can be expressed as follows (see the following section):
with ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler()) as ckptr:
  ckptr.save(path / 'standard-ckpt-2', args=ocp.args.StandardSave(state))
Understanding Checkpointers#
When greater customization of save and restore behavior is desired, Orbax must be instructed which logic to use to save and restore a given object. This is achieved by combining a Checkpointer with a CheckpointHandler. You can think of the CheckpointHandler as providing a configuration that tells the Checkpointer what serialization logic to use to deal with a particular object, while the Checkpointer provides shared logic used by all CheckpointHandlers, like thread management and atomicity.
with ocp.Checkpointer(ocp.JsonCheckpointHandler()) as ckptr:
  ckptr.save(path / 'json-ckpt-1', args=ocp.args.JsonSave({'a': 'b'}))
Async checkpointing, provided via AsyncCheckpointer, can often yield significant resource savings and training speedups because writes to disk happen in a background thread. See here for more details.
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
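The background-thread idea can be pictured with a small standard-library sketch. TinyAsyncSaver is a hypothetical class, not an Orbax API. It assumes the simplest possible design: snapshot the data on the calling thread, write it in a background thread, and expose a wait_until_finished() method that blocks until the write is done.

```python
import json
import tempfile
import threading
from pathlib import Path
from typing import Optional


class TinyAsyncSaver:
    """Hypothetical helper: writes JSON to disk in a background thread."""

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None

    def save(self, path: Path, payload: dict) -> None:
        # Block on any earlier save so at most one write is in flight.
        self.wait_until_finished()
        # Snapshot the data now, so later mutations by the caller cannot
        # race with the background write.
        snapshot = json.dumps(payload)
        self._thread = threading.Thread(
            target=Path(path).write_text, args=(snapshot,)
        )
        self._thread.start()

    def wait_until_finished(self) -> None:
        if self._thread is not None:
            self._thread.join()
            self._thread = None


saver = TinyAsyncSaver()
target = Path(tempfile.mkdtemp()) / 'state.json'
train_state = {'step': 0}
saver.save(target, train_state)
train_state['step'] = 1  # "Training" continues while the write runs.
saver.wait_until_finished()
print(json.loads(target.read_text()))
```

The file contains the state as it was when save was called, even though the caller kept mutating it afterwards. Reading the file before wait_until_finished() returns would not be safe in this sketch.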
While most Checkpointer/CheckpointHandler pairs deal with a single object that is saved and restored, pairing a Checkpointer with CompositeCheckpointHandler allows dealing with multiple distinct objects at once.
with ocp.Checkpointer(ocp.CompositeCheckpointHandler()) as ckptr:
  ckptr.save(
      path / 'composite-ckpt-1',
      args=ocp.args.Composite(
          state=ocp.args.StandardSave(state),
          metadata=ocp.args.JsonSave(metadata),
      )
  )
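As a rough mental model of composition, the following toy sketch delegates each named item to its own handler and its own subdirectory. All classes here (JsonHandler, TextHandler, TinyCompositeHandler) are hypothetical and use only the standard library. The one-subdirectory-per-item layout is an assumption based on the item paths that appear in the metadata output elsewhere on this page.

```python
import json
import tempfile
from pathlib import Path


class JsonHandler:
    """Hypothetical handler: serializes a dict as JSON."""

    def save(self, directory: Path, item: dict) -> None:
        (directory / 'data.json').write_text(json.dumps(item))

    def restore(self, directory: Path) -> dict:
        return json.loads((directory / 'data.json').read_text())


class TextHandler:
    """Hypothetical handler: serializes a string as plain text."""

    def save(self, directory: Path, item: str) -> None:
        (directory / 'data.txt').write_text(item)

    def restore(self, directory: Path) -> str:
        return (directory / 'data.txt').read_text()


class TinyCompositeHandler:
    """Delegates each named item to its own handler and subdirectory."""

    def __init__(self, handlers: dict) -> None:
        self._handlers = handlers

    def save(self, directory: Path, **items) -> None:
        for name, item in items.items():
            item_dir = directory / name
            item_dir.mkdir(parents=True, exist_ok=True)
            self._handlers[name].save(item_dir, item)

    def restore(self, directory: Path, *names: str) -> dict:
        # Restore only the requested items; default to every known item.
        names = names or tuple(self._handlers)
        return {n: self._handlers[n].restore(directory / n) for n in names}


root = Path(tempfile.mkdtemp())
handler = TinyCompositeHandler({'metadata': JsonHandler(), 'notes': TextHandler()})
handler.save(root, metadata={'version': 1.0}, notes='first run')
print(handler.restore(root))
print(handler.restore(root, 'metadata'))
```

Because each item lives in its own subdirectory with its own serialization logic, a subset of items can be restored without touching the others.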
Understanding Items and Registration#
Let’s return to the subject of “items”. This is the term Orbax uses to refer to logically distinct checkpointable units. These units may be bundled together as part of the same state, but it is frequently convenient to maintain some separation between them, as they are often used for very different purposes.
Some common examples may include the training state, dataset, embeddings, custom metadata, etc.
Each of these items may require different logic in order to save, and it is neither possible nor desirable for Orbax to “just figure it out” automatically. It is important to have confidence that the item you’re saving is being saved as you expect it to be.
You can see a list of handlers available for checkpointing different objects in the API reference. If none of these meets your needs, you can implement your own CheckpointHandler.
Let’s return to our standard example. In this section we will always use CheckpointManager, but all the following principles apply in the same way when using Checkpointer(CompositeCheckpointHandler()).
directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-items-1/')
mngr = ocp.CheckpointManager(directory)
mngr.save(
    0,
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        extra_metadata=ocp.args.JsonSave(extra_metadata),
    )
)
restored = mngr.restore(0)
print(restored.state)
print(restored.extra_metadata)
{'layer0': {'bias': array([11., 11., 11., 11.]), 'weight': Array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], dtype=int32)}}
{'version': 1.0, 'step': 4}
For any given item, be it state or extra_metadata, the first args type used to save or restore that item is “locked in” and used for all subsequent saves and restores. This is what allows us to restore without specifying any arguments.
mngr.save(1, args=ocp.args.Composite(
    state=ocp.args.StandardSave(state), extra_metadata=None))
restored = mngr.restore(1)
print(restored.state)
print(restored.extra_metadata)
{'layer0': {'bias': array([11., 11., 11., 11.]), 'weight': Array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], dtype=int32)}}
None
We can also obtain metadata about our saved state, again without needing to specify any arguments.
meta = mngr.item_metadata(1)
print(meta.state)
mngr.close()
TreeMetadata(
custom_metadata=None
tree={'layer0': {'bias': ArrayMetadata : name=layer0.bias, directory=/tmp/checkpoint-manager-items-1/1/state, shape=(4,), sharding=None, dtype=float64, storage=StorageMetadata(chunk_shape=(4,), write_shape=None),, 'weight': ArrayMetadata : name=layer0.weight, directory=/tmp/checkpoint-manager-items-1/1/state, shape=(16,), sharding=SingleDeviceShardingMetadata(device_str=cpu:0), dtype=int32, storage=StorageMetadata(chunk_shape=(16,), write_shape=(16,)),}}
use_zarr3=False
)
However, if we create a new CheckpointManager and try to get metadata or restore, restore may raise an error or, as in the output below, log warnings, because the CheckpointHandler for state is not configured. item_metadata, in contrast, does not raise an error, but returns None for an item that could not be reconstructed, so we have some indication that the item exists.
with ocp.CheckpointManager(directory) as mngr:
  try:
    print(mngr.restore(0))
  except BaseException as e:
    print(e)
  print('')
  print(mngr.item_metadata(0))
WARNING:absl:Item "extra_metadata" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
WARNING:absl:Item "state" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
WARNING:absl:Item "extra_metadata" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
WARNING:absl:Item "state" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
Composite({'extra_metadata': {'version': 1.0, 'step': 4}, 'state': {'layer0': {'bias': array([11., 11., 11., 11.]), 'weight': Array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], dtype=int32)}}})
Composite({'extra_metadata': None, 'state': TreeMetadata(
custom_metadata=None
tree={'layer0': {'bias': ArrayMetadata : name=layer0.bias, directory=/tmp/checkpoint-manager-items-1/0/state, shape=(4,), sharding=None, dtype=float64, storage=StorageMetadata(chunk_shape=(4,), write_shape=None),, 'weight': ArrayMetadata : name=layer0.weight, directory=/tmp/checkpoint-manager-items-1/0/state, shape=(16,), sharding=SingleDeviceShardingMetadata(device_str=cpu:0), dtype=int32, storage=StorageMetadata(chunk_shape=(16,), write_shape=(16,)),}}
use_zarr3=False
)})
To fix this, we can pre-configure with a handler registry in order to specify the behavior that should be taken when restoring a particular item.
registry = ocp.handlers.DefaultCheckpointHandlerRegistry()
registry.add('state', ocp.args.StandardSave)
registry.add('state', ocp.args.StandardRestore)
with ocp.CheckpointManager(
    directory,
    handler_registry=registry,
) as mngr:
  print(mngr.restore(0, args=ocp.args.Composite(state=None)))
  print('')
  print(mngr.item_metadata(0))
WARNING:absl:Item "extra_metadata" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
WARNING:absl:Item "extra_metadata" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
Composite({'state': {'layer0': {'bias': array([11., 11., 11., 11.]), 'weight': Array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], dtype=int32)}}})
Composite({'extra_metadata': None, 'state': TreeMetadata(
custom_metadata=None
tree={'layer0': {'bias': ArrayMetadata : name=layer0.bias, directory=/tmp/checkpoint-manager-items-1/0/state, shape=(4,), sharding=None, dtype=float64, storage=StorageMetadata(chunk_shape=(4,), write_shape=None),, 'weight': ArrayMetadata : name=layer0.weight, directory=/tmp/checkpoint-manager-items-1/0/state, shape=(16,), sharding=SingleDeviceShardingMetadata(device_str=cpu:0), dtype=int32, storage=StorageMetadata(chunk_shape=(16,), write_shape=(16,)),}}
use_zarr3=False
)})
As previously mentioned, once we have “locked in” the type for an item, either through eager configuration with the registry, or lazy configuration by providing args, we cannot change the item type without reinitializing the CheckpointManager.
with ocp.CheckpointManager(
    directory,
    handler_registry=registry,
) as mngr:
  mngr.save(2, args=ocp.args.PyTreeSave({'a': 'b'}))
  try:
    print(mngr.save(3, args=ocp.args.JsonSave({'a': 'b'})))
  except BaseException as e:
    print(e)
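The "locking in" behavior can be pictured with a toy registry. The sketch below is purely illustrative and does not use Orbax: TinyRegistry, the Fake* argument classes and the HANDLER_FOR_ARGS table are all hypothetical. It assumes that save and restore args map to a handler kind, and that the first handler kind resolved for an item name is remembered for the lifetime of the registry.

```python
class FakeStandardSave:
    """Hypothetical stand-in for a save-args type."""


class FakeStandardRestore:
    """Hypothetical stand-in for the matching restore-args type."""


class FakeJsonSave:
    """Hypothetical stand-in for a different save-args type."""


# Hypothetical table: which handler kind each args type belongs to.
HANDLER_FOR_ARGS = {
    FakeStandardSave: 'standard',
    FakeStandardRestore: 'standard',
    FakeJsonSave: 'json',
}


class TinyRegistry:
    """Remembers the first handler kind resolved for each item name."""

    def __init__(self) -> None:
        self._locked = {}

    def resolve(self, item_name: str, args: object) -> str:
        handler = HANDLER_FOR_ARGS[type(args)]
        # setdefault records the handler only on first use ("lazy" lock-in).
        locked = self._locked.setdefault(item_name, handler)
        if locked != handler:
            raise ValueError(
                f'Item {item_name!r} is locked to {locked!r}, got {handler!r}.'
            )
        return locked


registry = TinyRegistry()
registry.resolve('state', FakeStandardSave())            # Locks 'state'.
print(registry.resolve('state', FakeStandardRestore()))  # Same handler kind.
try:
    registry.resolve('state', FakeJsonSave())
except ValueError as e:
    print(e)
```

Creating a fresh TinyRegistry forgets the lock, which mirrors the statement above that the item type can only be changed by reinitializing the CheckpointManager.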
CheckpointHandler Layer#
The lowest-level API that users typically interact with in Orbax is the CheckpointHandler. Every CheckpointHandler is also paired with one or two CheckpointArgs objects, which encapsulate all necessary and optional arguments that a user can provide when saving or restoring. At a high level, CheckpointHandler exists to provide the logic required to save or restore a particular object in a checkpoint.
CheckpointHandler allows for synchronous saving. Subclasses of AsyncCheckpointHandler allow for asynchronous saving. (Restoration is always synchronous.)
Crucially, a CheckpointHandler instance should not be used in isolation, but should always be used in conjunction with a Checkpointer. Otherwise, save operations will not be atomic and async operations cannot be waited upon. This means that in most cases, you will be working with Checkpointer APIs rather than CheckpointHandler APIs.
However, it is still essential to understand CheckpointHandler because you need to know how you want your object to be saved and restored, and what arguments are necessary to make that happen.
Let’s consider the example of StandardCheckpointHandler. This class is paired with StandardSave and StandardRestore.
StandardSave allows specifying the item argument, which is the PyTree to be saved using TensorStore. It also includes save_args, which is an optional PyTree with a structure matching item. Each leaf is an ocp.type_handlers.SaveArgs object, which can be used to customize things like the dtype of the saved array.
StandardRestore only has one possible argument, the item, which is a PyTree of concrete or abstract arrays matching the structure of the checkpoint. This is optional, and the checkpoint will be restored exactly as saved if no argument is provided.
In general, other CheckpointHandlers may have other arguments, and the contract can be discerned by looking at the corresponding CheckpointArgs. Additionally, you can create your own implementation of CheckpointHandler for your specific needs.
CompositeCheckpointHandler is a special case that allows composing multiple CheckpointHandlers at once. More details are provided throughout this page.