API Overview#
CheckpointHandler Layer#
The lowest-level API that users typically interact with in Orbax is the CheckpointHandler. Every CheckpointHandler is also paired with one or two CheckpointArgs objects, which encapsulate all necessary and optional arguments that a user can provide when saving or restoring. At a high level, CheckpointHandler exists to provide the logic required to save or restore a particular object in a checkpoint.
CheckpointHandler allows for synchronous saving. Subclasses of AsyncCheckpointHandler allow for asynchronous saving. (Restoration is always synchronous.)
Crucially, a CheckpointHandler instance should not be used in isolation, but should always be used in conjunction with a Checkpointer (see below). Otherwise, save operations will not be atomic and async operations cannot be waited upon. This means that in most cases, you will be working with Checkpointer APIs rather than CheckpointHandler APIs.
However, it is still essential to understand CheckpointHandler, because you need to know how you want your object to be saved and restored, and what arguments are necessary to make that happen.
Let’s consider the example of StandardCheckpointHandler. This class is paired with StandardSave and StandardRestore.
StandardSave allows specifying the item argument, which is the PyTree to be saved using TensorStore. It also includes save_args, which is an optional PyTree with a structure matching item. Each leaf is an ocp.type_handlers.SaveArgs object, which can be used to customize things like the dtype of the saved array.
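To make the pairing of item and save_args more concrete, here is a minimal pure-Python sketch that does not call Orbax at all. ToySaveArgs and apply_save_args are hypothetical names invented for this illustration. The sketch only shows how a tree of per-leaf options that mirrors the structure of item can be walked alongside it. As a simplification, it treats a missing entry as "use the defaults" and uses plain Python types in place of array dtypes.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ToySaveArgs:
    """Hypothetical stand-in for per-leaf save options (not an Orbax class)."""
    dtype: Optional[Callable[[Any], Any]] = None  # e.g. int or float


def apply_save_args(item: Any, save_args: Any = None) -> Any:
    """Walks `item` and `save_args` together and casts leaves where requested."""
    if isinstance(item, dict):
        save_args = save_args or {}
        return {k: apply_save_args(v, save_args.get(k)) for k, v in item.items()}
    if save_args is not None and save_args.dtype is not None:
        return save_args.dtype(item)  # leaf with an explicit option
    return item  # leaf without options is left unchanged


item = {'layer0': {'bias': 0, 'weight': 1}}
# Only `weight` gets a custom option; `bias` keeps its original type.
save_args = {'layer0': {'weight': ToySaveArgs(dtype=float)}}
prepared = apply_save_args(item, save_args)
print(prepared)
```

In this toy version only the weight leaf is converted to a float, while bias is passed through untouched.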
StandardRestore only has one possible argument, the item, which is a PyTree of concrete or abstract arrays matching the structure of the checkpoint. This is optional, and the checkpoint will be restored exactly as saved if no argument is provided.
In general, other CheckpointHandlers may have other arguments, and the contract can be discerned by looking at the corresponding CheckpointArgs. Additionally, CheckpointHandlers can be customized for specific needs.
CompositeCheckpointHandler#
A special case of CheckpointHandler is the CompositeCheckpointHandler. While CheckpointHandlers are typically expected to deal with a single object, CompositeCheckpointHandler is explicitly designed for delegating save/restore logic for multiple distinct objects to separate CheckpointHandlers.
At minimum, CompositeCheckpointHandler must be initialized with a series of item names, which are used to differentiate distinct items. In many cases, you do not need to manually specify the delegated CheckpointHandler instance for a particular item up front. Here’s an example:
import orbax.checkpoint as ocp
from etils import epath
path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
handler = ocp.CompositeCheckpointHandler('state', 'metadata', 'dataset')
Now, it will be possible to use this handler (in conjunction with a Checkpointer!) to save and restore 3 distinct objects, named ‘state’, ‘metadata’, and ‘dataset’.
When we call save or restore, it is necessary to specify a CheckpointArgs subclass for each item. This is used to infer the desired CheckpointHandler. For example, if we specify StandardSave, the object will get saved using StandardCheckpointHandler. Per-item CheckpointArgs must be wrapped in the CheckpointArgs for CompositeCheckpointHandler, which is ocp.args.Composite.
state = {'layer0': {'bias': 0, 'weight': 1}}
metadata = {'version': 1.0}
dataset = {'my_data': 2}
checkpointer = ocp.Checkpointer(handler)
checkpointer.save(
    path / 'composite_checkpoint',
    ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        metadata=ocp.args.JsonSave(metadata),
        dataset=ocp.args.JsonSave(dataset),
    )
)
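The idea that the type of the per-item arguments determines which handler is used can be pictured as a small registry lookup. The following pure-Python sketch is not Orbax code and makes no claim about how Orbax implements this internally. Every class name in it (ToyStandardSave, ToyJsonSave, ToyStandardHandler, ToyJsonHandler, ToyCompositeHandler) and the HANDLER_FOR_ARGS table are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Dict, Type


@dataclass
class ToyStandardSave:
    item: Any


@dataclass
class ToyJsonSave:
    item: Any


class ToyStandardHandler:
    def save(self, args: ToyStandardSave) -> str:
        return f'standard:{args.item}'


class ToyJsonHandler:
    def save(self, args: ToyJsonSave) -> str:
        return f'json:{args.item}'


# Hypothetical registry: the args type decides which handler type is used.
HANDLER_FOR_ARGS: Dict[Type, Type] = {
    ToyStandardSave: ToyStandardHandler,
    ToyJsonSave: ToyJsonHandler,
}


class ToyCompositeHandler:
    """Delegates each named item to a handler inferred from its args type."""

    def __init__(self, *item_names: str):
        self._item_names = set(item_names)
        self._handlers: Dict[str, Any] = {}  # resolved lazily, then cached

    def save(self, **per_item_args: Any) -> Dict[str, str]:
        results = {}
        for name, args in per_item_args.items():
            if name not in self._item_names:
                raise ValueError(f'Unknown item name: {name}')
            if name not in self._handlers:
                # First use of this item: look up the handler by args type.
                self._handlers[name] = HANDLER_FOR_ARGS[type(args)]()
            results[name] = self._handlers[name].save(args)
        return results


handler = ToyCompositeHandler('state', 'metadata')
written = handler.save(
    state=ToyStandardSave({'bias': 0}),
    metadata=ToyJsonSave({'version': 1.0}),
)
print(written)
```

Caching the resolved handler per item name is one plausible design. It means a later call for the same item would not need to name the handler again.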
When restoring, we can retrieve all of the items or only a subset:
# Restore all items:
checkpointer.restore(path / 'composite_checkpoint')
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
CompositeArgs({'dataset': {'my_data': 2}, 'metadata': {'version': 1.0}, 'state': {'layer0': {'bias': 0, 'weight': 1}}})
# Restore some items, but not all:
checkpointer.restore(
    path / 'composite_checkpoint',
    ocp.args.Composite(
        state=ocp.args.StandardRestore(),
    )
)
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
CompositeArgs({'state': {'layer0': {'bias': 0, 'weight': 1}}})
# Restore some items, and specify optional arguments for restoration:
checkpointer.restore(
    path / 'composite_checkpoint',
    ocp.args.Composite(
        state=ocp.args.StandardRestore(state),
    )
)
CompositeArgs({'state': {'layer0': {'bias': 0, 'weight': 1}}})
As noted above, every CheckpointHandler has one or two CheckpointArgs subclasses which represent save and restore arguments. CompositeCheckpointHandler is no exception. Both save and restore arguments are represented by args.Composite, which is basically just a wrapper for other CheckpointArgs passed to the sub-handlers.
Similarly, the return value of restore is also args.Composite.
The args.Composite class is basically just a key-value store similar to a dictionary.
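As a rough mental model of such a key-value container, the pure-Python sketch below defines a hypothetical ToyComposite class. It is not ocp.args.Composite and does not reproduce its full interface. It only illustrates a dictionary-like object that also exposes its entries as attributes, in the style of the result.state access used in the CheckpointManager examples in this document. The key lookup and keys() method shown here are assumptions of the sketch.

```python
from typing import Any


class ToyComposite:
    """Hypothetical dict-like container with attribute access."""

    def __init__(self, **items: Any):
        self._items = dict(items)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails.
        items = self.__dict__.get('_items', {})
        if name in items:
            return items[name]
        raise AttributeError(name)

    def __getitem__(self, name: str) -> Any:
        return self._items[name]

    def keys(self):
        return self._items.keys()

    def items(self):
        return self._items.items()

    def __repr__(self) -> str:
        return f'ToyComposite({self._items})'


result = ToyComposite(
    state={'layer0': {'bias': 5, 'weight': 6}},
    extra_metadata={'version': 1.0, 'step': 4},
)
print(result.state)               # attribute-style access
print(result['extra_metadata'])   # key-style access
print(sorted(result.keys()))
```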
Checkpointer Layer#
Conceptually, the Checkpointer exists to work with a single checkpoint that exists at a single path. It is no-frills (relative to CheckpointManager) but guarantees atomicity and allows for asynchronous saving via AsyncCheckpointer.
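To illustrate what atomicity means for a checkpoint directory, the sketch below shows one common general pattern: write everything into a temporary directory, then rename it into place as the final step. This is a pure-Python illustration under stated assumptions, not Orbax's implementation. The atomic_save function, the '.tmp' suffix, and the data.json file name are all hypothetical, and the actual commit mechanism in Orbax may differ, for example depending on the filesystem.

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_save(final_dir: Path, payload: dict) -> None:
    """Writes `payload` into a temporary directory, then renames it into place."""
    final_dir = Path(final_dir)
    tmp_dir = final_dir.with_name(final_dir.name + '.tmp')
    tmp_dir.mkdir(parents=True)
    with open(tmp_dir / 'data.json', 'w') as f:
        json.dump(payload, f)
    # The rename is the commit point: a reader either sees no `final_dir`
    # at all or a fully written one, never a partially written one.
    os.rename(tmp_dir, final_dir)


root = Path(tempfile.mkdtemp())
atomic_save(root / 'ckpt', {'bias': 0, 'weight': 1})
restored = json.loads((root / 'ckpt' / 'data.json').read_text())
print(restored)
```

If the process dies before the rename, only the temporary directory is left behind, which can be recognized and cleaned up later.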
Async checkpointing provided via AsyncCheckpointer can often help to realize significant resource savings and training speedups because writes to disk happen in a background thread. See here for more details.
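The background-thread idea can be sketched in a few lines of standard-library Python. The ToyAsyncSaver class below is hypothetical and is not how AsyncCheckpointer is implemented. It assumes the data is copied before save returns, so that the caller can keep mutating its state while the slow write proceeds, and it borrows the wait_until_finished name only to mirror the blocking call described in this document.

```python
import json
import tempfile
import threading
import time
from pathlib import Path
from typing import Optional


class ToyAsyncSaver:
    """Hypothetical saver that writes to disk in a background thread."""

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None

    def save(self, path: Path, payload: dict) -> None:
        self.wait_until_finished()  # allow only one in-flight save
        snapshot = json.dumps(payload)  # copy the data before returning

        def _write() -> None:
            time.sleep(0.05)  # stand-in for slow storage
            path.write_text(snapshot)

        self._thread = threading.Thread(target=_write)
        self._thread.start()

    def wait_until_finished(self) -> None:
        if self._thread is not None:
            self._thread.join()
            self._thread = None


saver = ToyAsyncSaver()
target = Path(tempfile.mkdtemp()) / 'state.json'
state = {'bias': 0}
saver.save(target, state)
state['bias'] = 1  # training continues and mutates the state
saver.wait_until_finished()  # block until the write has completed
restored = json.loads(target.read_text())
print(restored)  # {'bias': 0}
```

The saved file reflects the state at the time save was called, even though the state was modified before the write finished.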
As mentioned above, a Checkpointer is always combined with a CheckpointHandler. You can think of the CheckpointHandler as providing a configuration that tells the Checkpointer what serialization logic to use to deal with a particular object.
ckptr = ocp.Checkpointer(ocp.JsonCheckpointHandler())
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
Orbax provides some shorthand checkpointers, such as StandardCheckpointer, which is just Checkpointer(StandardCheckpointHandler()).
While most Checkpointer/CheckpointHandler pairs deal with a single object that is saved and restored, pairing a Checkpointer with CompositeCheckpointHandler allows dealing with multiple distinct objects at once (see above).
CheckpointManager Layer#
The highest-level API layer provided by Orbax is the CheckpointManager. This is the API of choice for users dealing with a series of checkpoints denoted as steps in the context of a training run.
CheckpointManagerOptions allows customizing the behavior of the CheckpointManager along various dimensions. A partial list of important customization options is given below. See the API reference for a complete list.
save_interval_steps: An interval at which to save checkpoints.
max_to_keep: Starts to delete checkpoints when more than this number are present. Depending on other settings, more checkpoints than this number may be present at any given time.
step_format_fixed_length: Formats the step number with n digits, padding with leading zeros. This can make visually examining the checkpoints in sorted order easier.
cleanup_tmp_directories: Automatically cleans up existing temporary/incomplete directories when the CheckpointManager is created.
read_only: If True, then checkpoint saves and deletes are skipped. Restore works as usual.
enable_async_checkpointing: True by default. Be wary of turning this off, as save performance may be significantly impacted.
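The interaction between save_interval_steps, max_to_keep, and step_format_fixed_length can be simulated without Orbax. The sketch below is a simplified model, not the CheckpointManager logic. It assumes a step is saved when it is a multiple of save_interval_steps and that the oldest checkpoints are dropped first, which is consistent with the [2, 4] result of the CheckpointManager example in this document. The function names simulate_steps and step_dir_name are hypothetical.

```python
from typing import List, Optional


def simulate_steps(
    num_steps: int, save_interval_steps: int, max_to_keep: Optional[int]
) -> List[int]:
    """Returns the steps that remain on disk under the simplified policy."""
    kept: List[int] = []
    for step in range(num_steps):
        if step % save_interval_steps != 0:
            continue  # not a save step under the assumed rule
        kept.append(step)
        if max_to_keep is not None and len(kept) > max_to_keep:
            del kept[:len(kept) - max_to_keep]  # drop the oldest checkpoints
    return kept


def step_dir_name(step: int, fixed_length: Optional[int] = None) -> str:
    """Formats a step, optionally zero-padded to `fixed_length` digits."""
    return str(step) if fixed_length is None else f'{step:0{fixed_length}d}'


kept = simulate_steps(num_steps=5, save_interval_steps=2, max_to_keep=2)
print(kept)  # [2, 4]

padded = [step_dir_name(s, fixed_length=4) for s in (2, 10, 100)]
unpadded = [step_dir_name(s) for s in (2, 10, 100)]
print(padded, sorted(padded) == padded)        # padded names sort numerically
print(unpadded, sorted(unpadded) == unpadded)  # unpadded names do not
```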
Note that CheckpointManager always saves asynchronously, unless you set enable_async_checkpointing=False in CheckpointManagerOptions. Make sure to use wait_until_finished() if you need to block until a save is complete.
If dealing with a single checkpointable object, like a train state, CheckpointManager can be created as follows:
import jax
directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-single/')
options = ocp.CheckpointManagerOptions(
    save_interval_steps=2,
    max_to_keep=2,
    # other options
)
mngr = ocp.CheckpointManager(
    directory,
    options=options,
)
num_steps = 5
state = {'layer0': {'bias': 0, 'weight': 1}}
def train_step(state):
    return jax.tree_util.tree_map(lambda x: x + 1, state)
for step in range(num_steps):
    state = train_step(state)
    mngr.save(step, args=ocp.args.StandardSave(state))
mngr.wait_until_finished()
mngr.latest_step()
4
mngr.all_steps()
[2, 4]
mngr.restore(mngr.latest_step())
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
{'layer0': {'bias': 5, 'weight': 6}}
# Restore with additional arguments, like dtype or sharding.
target_state = {'layer0': {'bias': 0.0, 'weight': 0.0}}
mngr.restore(mngr.latest_step(), args=ocp.args.StandardRestore(target_state))
{'layer0': {'bias': 5, 'weight': 6}}
If we’re dealing with multiple items, we need to provide item_names when configuring the CheckpointManager. Internally, CheckpointManager is using CompositeCheckpointHandler, so the information above also applies here.
directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-multiple/')
options = ocp.CheckpointManagerOptions(
    save_interval_steps=2,
    max_to_keep=2,
    # other options
)
mngr = ocp.CheckpointManager(
    directory,
    options=options,
    item_names=('state', 'extra_metadata'),
)
num_steps = 5
state = {'layer0': {'bias': 0, 'weight': 1}}
extra_metadata = {'version': 1.0, 'step': 0}
def train_step(step, _state, _extra_metadata):
    return jax.tree_util.tree_map(lambda x: x + 1, _state), {**_extra_metadata, 'step': step}
for step in range(num_steps):
    state, extra_metadata = train_step(step, state, extra_metadata)
    mngr.save(
        step,
        args=ocp.args.Composite(
            state=ocp.args.StandardSave(state),
            extra_metadata=ocp.args.JsonSave(extra_metadata),
        )
    )
mngr.wait_until_finished()
# Restore exactly as saved
result = mngr.restore(mngr.latest_step())
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
result
CompositeArgs({'extra_metadata': {'version': 1.0, 'step': 4}, 'state': {'layer0': {'bias': 5, 'weight': 6}}})
result.state
{'layer0': {'bias': 5, 'weight': 6}}
result.extra_metadata
{'version': 1.0, 'step': 4}
# Skip `state` when restoring.
# Note that it is possible to provide `extra_metadata=None` because we already
# saved using `JsonSave`. This is internally cached, so we know it uses JSON
# logic to save and restore. If you had called `restore` without first calling
# `save`, however, it would have been necessary to provide
# `ocp.args.JsonRestore`.
mngr.restore(mngr.latest_step(), args=ocp.args.Composite(extra_metadata=None))
CompositeArgs({'extra_metadata': {'version': 1.0, 'step': 4}})
# Restore with additional arguments, like dtype or sharding.
target_state = {'layer0': {'bias': 0.0, 'weight': 0.0}}
mngr.restore(
    mngr.latest_step(),
    args=ocp.args.Composite(
        state=ocp.args.StandardRestore(target_state),
        extra_metadata=None,
    ),
)
CompositeArgs({'extra_metadata': {'version': 1.0, 'step': 4}, 'state': {'layer0': {'bias': 5, 'weight': 6}}})
There are some scenarios in which the mapping between items and their respective CheckpointHandlers needs to be provided when the CheckpointManager instance is created.
The CheckpointManager constructor argument item_handlers covers those scenarios. Please see Using the Refactored CheckpointManager API for details.