API Overview

CheckpointHandler Layer

The lowest-level API that users typically interact with in Orbax is the CheckpointHandler. Every CheckpointHandler is paired with one or two CheckpointArgs objects that encapsulate all required and optional arguments a user can provide when saving or restoring. At a high level, CheckpointHandler exists to provide the logic required to save or restore a particular object in a checkpoint.

CheckpointHandler allows for synchronous saving. Subclasses of AsyncCheckpointHandler allow for asynchronous saving. (Restoration is always synchronous.)

Crucially, a CheckpointHandler instance should not be used in isolation; it should always be used in conjunction with a Checkpointer (see below). Otherwise, save operations will not be atomic and async operations cannot be waited upon. This means that in most cases, you will be working with Checkpointer APIs rather than CheckpointHandler APIs.
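To make the atomicity point concrete, here is a minimal sketch of one common way a wrapper can make a handler's save atomic: write into a temporary directory, then rename it into place. ToyJsonHandler and ToyCheckpointer are hypothetical names invented for this illustration, and the rename strategy is an assumption rather than a description of what Orbax does on every storage backend.

```python
import json
import os
import tempfile
from pathlib import Path


class ToyJsonHandler:
    """Hypothetical handler: only knows how to write and read one object."""

    def save(self, directory: Path, item: dict) -> None:
        (directory / "data.json").write_text(json.dumps(item))

    def restore(self, directory: Path) -> dict:
        return json.loads((directory / "data.json").read_text())


class ToyCheckpointer:
    """Hypothetical wrapper that makes the handler's save atomic."""

    def __init__(self, handler: ToyJsonHandler):
        self._handler = handler

    def save(self, directory: Path, item: dict) -> None:
        if directory.exists():
            raise FileExistsError(f"Destination already exists: {directory}")
        # Write into a temporary sibling directory first.
        tmp_dir = Path(
            tempfile.mkdtemp(prefix=directory.name + ".tmp-", dir=directory.parent)
        )
        self._handler.save(tmp_dir, item)
        # Renaming within one filesystem is atomic on POSIX systems, so a
        # reader sees either no checkpoint or a complete one.
        os.rename(tmp_dir, directory)

    def restore(self, directory: Path) -> dict:
        return self._handler.restore(directory)


root = Path(tempfile.mkdtemp())
ckptr = ToyCheckpointer(ToyJsonHandler())
ckptr.save(root / "ckpt", {"version": 1})
restored = ckptr.restore(root / "ckpt")
```

If the handler were called directly, a crash halfway through the write would leave a partially written directory at the final path, which is the failure mode the wrapper avoids.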

However, it is still essential to understand CheckpointHandler because you need to know how you want your object to be saved and restored, and what arguments are necessary to make that happen.

Let’s consider the example of StandardCheckpointHandler. This class is paired with StandardSave and StandardRestore.

StandardSave allows specifying the item argument, which is the PyTree to be saved using Tensorstore. It also includes save_args, which is an optional PyTree with a structure matching item. Each leaf is an ocp.type_handlers.SaveArgs object, which can be used to customize things like the dtype of the saved array.

StandardRestore only has one possible argument, the item, which is a PyTree of concrete or abstract arrays matching the structure of the checkpoint. This is optional, and the checkpoint will be restored exactly as saved if no argument is provided.

In general, other CheckpointHandlers may have other arguments, and the contract can be discerned by looking at the corresponding CheckpointArgs. Additionally, CheckpointHandlers can be customized for specific needs.

CompositeCheckpointHandler

A special case of CheckpointHandler is the CompositeCheckpointHandler. While CheckpointHandlers are typically expected to deal with a single object, CompositeCheckpointHandler is explicitly designed for delegating save/restore logic for multiple distinct objects to separate CheckpointHandlers.

At minimum, CompositeCheckpointHandler must be initialized with a series of item names, which are used to differentiate distinct items. In many cases, you do not need to manually specify the delegated CheckpointHandler instance for a particular item up front. Here’s an example:

import orbax.checkpoint as ocp
from etils import epath

path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
handler = ocp.CompositeCheckpointHandler('state', 'metadata', 'dataset')

Now, it will be possible to use this handler (in conjunction with a Checkpointer!) to save and restore 3 distinct objects, named ‘state’, ‘metadata’, and ‘dataset’.

When we call save or restore, it is necessary to specify a CheckpointArgs subclass for each item. This is used to infer the desired CheckpointHandler. For example, if we specify StandardSave, the object will get saved using StandardCheckpointHandler. Per-item CheckpointArgs must be wrapped in the CheckpointArgs for CompositeCheckpointHandler, which is ocp.args.Composite.

state = {'layer0': {'bias': 0, 'weight': 1}}
metadata = {'version': 1.0}
dataset = {'my_data': 2}

checkpointer = ocp.Checkpointer(handler)
checkpointer.save(
    path / 'composite_checkpoint',
    ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        metadata=ocp.args.JsonSave(metadata),
        dataset=ocp.args.JsonSave(dataset),
    )
)
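The idea that the type of each per-item CheckpointArgs selects the handler can be pictured as a registry lookup. The sketch below is a toy model only: toy_register, the Toy* classes, and resolve_handlers are hypothetical names made up for illustration and are not Orbax APIs.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict

# Hypothetical registry mapping an args class to its handler class.
_HANDLER_FOR_ARGS: Dict[type, type] = {}


def toy_register(handler_cls: type) -> Callable[[type], type]:
    """Class decorator recording which handler serves an args class."""

    def decorator(args_cls: type) -> type:
        _HANDLER_FOR_ARGS[args_cls] = handler_cls
        return args_cls

    return decorator


class ToyStandardHandler:
    """Stand-in for a handler that saves PyTrees of arrays."""


class ToyJsonHandler:
    """Stand-in for a handler that saves JSON-serializable objects."""


@toy_register(ToyStandardHandler)
@dataclass
class ToyStandardSave:
    item: Any


@toy_register(ToyJsonHandler)
@dataclass
class ToyJsonSave:
    item: Any


def resolve_handlers(**per_item_args: Any) -> Dict[str, type]:
    """Picks a handler class for each named item from the type of its args."""
    return {
        name: _HANDLER_FOR_ARGS[type(args)]
        for name, args in per_item_args.items()
    }


resolved = resolve_handlers(
    state=ToyStandardSave({'layer0': {'bias': 0, 'weight': 1}}),
    metadata=ToyJsonSave({'version': 1.0}),
)
```

Under this model, the caller never names a handler explicitly; choosing ToyStandardSave or ToyJsonSave for an item is what decides how that item is serialized.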

When restoring, we can retrieve all of the items or only a subset:

# Restore all items:
checkpointer.restore(path / 'composite_checkpoint')
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
CompositeArgs({'dataset': {'my_data': 2}, 'metadata': {'version': 1.0}, 'state': {'layer0': {'bias': 0, 'weight': 1}}})
# Restore some items, but not all:
checkpointer.restore(
    path / 'composite_checkpoint',
    ocp.args.Composite(
        state=ocp.args.StandardRestore(),
    )
)
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
CompositeArgs({'state': {'layer0': {'bias': 0, 'weight': 1}}})
# Restore some items, and specify optional arguments for restoration:
checkpointer.restore(
    path / 'composite_checkpoint',
    ocp.args.Composite(
        state=ocp.args.StandardRestore(state),
    )
)
CompositeArgs({'state': {'layer0': {'bias': 0, 'weight': 1}}})

As noted above, every CheckpointHandler has one or two CheckpointArgs subclasses which represent save and restore arguments. CompositeCheckpointHandler is no exception. Both save and restore arguments are represented by args.Composite, which is a wrapper for other CheckpointArgs passed to the sub-handlers.

Similarly, the return value of restore is also args.Composite.

The args.Composite class is basically just a key-value store similar to a dictionary.
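As a rough mental model of such a dictionary-like container, consider the toy class below. ToyComposite is a hypothetical stand-in written for this illustration. The attribute access mirrors the result.state usage shown elsewhere on this page; the remaining methods are assumptions about what a dictionary-like wrapper would reasonably offer.

```python
from typing import Any


class ToyComposite:
    """Hypothetical dict-like container of named per-item values."""

    def __init__(self, **items: Any):
        self._items = dict(items)

    def __getattr__(self, name: str) -> Any:
        # Called only when normal attribute lookup fails, so `_items`
        # itself is found through the instance __dict__.
        try:
            return self.__dict__['_items'][name]
        except KeyError:
            raise AttributeError(name) from None

    def __getitem__(self, name: str) -> Any:
        return self._items[name]

    def __contains__(self, name: str) -> bool:
        return name in self._items

    def keys(self):
        return self._items.keys()

    def items(self):
        return self._items.items()

    def __repr__(self) -> str:
        return f'ToyComposite({self._items})'


result = ToyComposite(
    state={'layer0': {'bias': 5, 'weight': 6}},
    extra_metadata={'version': 1.0, 'step': 4},
)
state_by_attr = result.state
state_by_key = result['state']
```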

Checkpointer Layer

Conceptually, the Checkpointer exists to work with a single checkpoint that exists at a single path. It is no-frills (relative to CheckpointManager) but guarantees atomicity and allows for asynchronous saving via AsyncCheckpointer.

Async checkpointing provided via AsyncCheckpointer can often yield significant resource savings and training speedups because the write to disk happens in a background thread. See the asynchronous checkpointing documentation for more details.
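The following sketch illustrates the general pattern of writing in a background thread and why a blocking wait is still needed before reading the result. ToyAsyncSaver is a hypothetical class built on the standard library; it models the idea only and makes no claim about how AsyncCheckpointer is implemented.

```python
import json
import tempfile
import threading
import time
from pathlib import Path
from typing import Optional


class ToyAsyncSaver:
    """Hypothetical saver that writes to disk in a background thread."""

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None

    def save(self, path: Path, item: dict) -> None:
        # Finish any previous save before starting a new one.
        self.wait_until_finished()
        # Snapshot the data now so later mutations do not affect the write.
        payload = json.dumps(item)
        self._thread = threading.Thread(target=self._write, args=(path, payload))
        self._thread.start()

    @staticmethod
    def _write(path: Path, payload: str) -> None:
        time.sleep(0.05)  # Simulates slow storage.
        path.write_text(payload)

    def wait_until_finished(self) -> None:
        if self._thread is not None:
            self._thread.join()
            self._thread = None


saver = ToyAsyncSaver()
target = Path(tempfile.mkdtemp()) / 'state.json'
state = {'step': 1}
saver.save(target, state)    # Returns without waiting for the write.
state['step'] = 2            # Training continues and mutates the state.
saver.wait_until_finished()  # Block until the write has landed.
restored = json.loads(target.read_text())
```

The snapshot taken inside save is what lets training continue safely: the file reflects the state at the time save was called, not whatever the state has become by the time the thread finishes.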

As mentioned above, a Checkpointer is always combined with a CheckpointHandler. You can think of the CheckpointHandler as providing a configuration that tells the Checkpointer what serialization logic to use to deal with a particular object.

ckptr = ocp.Checkpointer(ocp.JsonCheckpointHandler())
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())

Orbax provides some shorthand checkpointers, such as StandardCheckpointer, which is just Checkpointer(StandardCheckpointHandler()).

While most Checkpointer/CheckpointHandler pairs deal with a single object that is saved and restored, pairing a Checkpointer with CompositeCheckpointHandler allows dealing with multiple distinct objects at once (see above).

CheckpointManager Layer

The highest-level API layer provided by Orbax is the CheckpointManager. This is the API of choice for users dealing with a series of checkpoints denoted as steps in the context of a training run.

CheckpointManagerOptions allows customizing the behavior of the CheckpointManager along various dimensions. A partial list of important customization options is given below. See the API reference for a complete list.

  • save_interval_steps: An interval at which to save checkpoints.

  • max_to_keep: Starts to delete checkpoints when more than this number are present. Depending on other settings, more checkpoints than this number may be present at any given time.

  • step_format_fixed_length: Formats each step number with n digits, padded with leading zeros. This can make visually examining the checkpoints in sorted order easier.

  • cleanup_tmp_directories: Automatically cleans up existing temporary/incomplete directories when the CheckpointManager is created.

  • read_only: If True, checkpoint save and delete operations are skipped. Restore works as usual.

  • enable_async_checkpointing: True by default. Be wary of turning off, as save performance may be significantly impacted.
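To see how save_interval_steps, max_to_keep, and step_format_fixed_length interact, the simplified simulation below may help. The helper functions are hypothetical and capture only the basic rules described in the list above; the real CheckpointManager takes other settings into account, so treat this as a sketch under those assumptions.

```python
from typing import List, Optional


def should_save(step: int, save_interval_steps: int) -> bool:
    """Simplified rule: save only on multiples of the interval."""
    return step % save_interval_steps == 0


def apply_max_to_keep(steps: List[int], max_to_keep: Optional[int]) -> List[int]:
    """Keeps only the most recent `max_to_keep` steps (None keeps all)."""
    if max_to_keep is None:
        return sorted(steps)
    ordered = sorted(steps)
    excess = max(0, len(ordered) - max_to_keep)
    return ordered[excess:]


def format_step(step: int, fixed_length: Optional[int] = None) -> str:
    """Pads the step number with leading zeros when a length is given."""
    return str(step) if fixed_length is None else f'{step:0{fixed_length}d}'


kept: List[int] = []
for step in range(5):
    if should_save(step, save_interval_steps=2):
        kept = apply_max_to_keep(kept + [step], max_to_keep=2)

names = [format_step(s, fixed_length=4) for s in kept]
```

With five steps, an interval of 2, and max_to_keep of 2, steps 0, 2, and 4 are saved and step 0 is then dropped, leaving steps 2 and 4.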

Note that CheckpointManager always saves asynchronously, unless you set enable_async_checkpointing=False in CheckpointManagerOptions. Make sure to use wait_until_finished() if you need to block until a save is complete.

If dealing with a single checkpointable object, like a train state, CheckpointManager can be created as follows:

import jax

directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-single/')
options = ocp.CheckpointManagerOptions(
    save_interval_steps=2,
    max_to_keep=2,
    # other options
)
mngr = ocp.CheckpointManager(
    directory,
    options=options,
)
num_steps = 5
state = {'layer0': {'bias': 0, 'weight': 1}}

def train_step(state):
  return jax.tree_util.tree_map(lambda x: x + 1, state)

for step in range(num_steps):
  state = train_step(state)
  mngr.save(step, args=ocp.args.StandardSave(state))
mngr.wait_until_finished()
mngr.latest_step()
4
mngr.all_steps()
[2, 4]
mngr.restore(mngr.latest_step())
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
{'layer0': {'bias': 5, 'weight': 6}}
# Restore with additional arguments, like dtype or sharding.
target_state = {'layer0': {'bias': 0.0, 'weight': 0.0}}
mngr.restore(mngr.latest_step(), args=ocp.args.StandardRestore(target_state))
{'layer0': {'bias': 5, 'weight': 6}}

If we’re dealing with multiple items, we need to provide item_names when configuring the CheckpointManager. Internally, CheckpointManager is using CompositeCheckpointHandler, so the information above also applies here.

directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-multiple/')
options = ocp.CheckpointManagerOptions(
    save_interval_steps=2,
    max_to_keep=2,
    # other options
)
mngr = ocp.CheckpointManager(
    directory,
    options=options,
    item_names=('state', 'extra_metadata'),
)
num_steps = 5
state = {'layer0': {'bias': 0, 'weight': 1}}
extra_metadata = {'version': 1.0, 'step': 0}

def train_step(step, _state, _extra_metadata):
  return jax.tree_util.tree_map(lambda x: x + 1, _state), {**_extra_metadata, 'step': step}

for step in range(num_steps):
  state, extra_metadata = train_step(step, state, extra_metadata)
  mngr.save(
      step,
      args=ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        extra_metadata=ocp.args.JsonSave(extra_metadata),
      )
  )
mngr.wait_until_finished()
# Restore exactly as saved
result = mngr.restore(mngr.latest_step())
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
result
CompositeArgs({'extra_metadata': {'version': 1.0, 'step': 4}, 'state': {'layer0': {'bias': 5, 'weight': 6}}})
result.state
{'layer0': {'bias': 5, 'weight': 6}}
result.extra_metadata
{'version': 1.0, 'step': 4}
# Skip `state` when restoring.
# Note that it is possible to provide `extra_metadata=None` because we already
# saved using `JsonSave`. This is internally cached, so we know it uses JSON
# logic to save and restore. If you had called `restore` without first calling
# `save`, however, it would have been necessary to provide
# `ocp.args.JsonRestore`.
mngr.restore(mngr.latest_step(), args=ocp.args.Composite(extra_metadata=None))
CompositeArgs({'extra_metadata': {'version': 1.0, 'step': 4}})
# Restore with additional arguments, like dtype or sharding.
target_state = {'layer0': {'bias': 0.0, 'weight': 0.0}}
mngr.restore(mngr.latest_step(), args=ocp.args.Composite(
    state=ocp.args.StandardRestore(target_state), extra_metadata=None)
)
CompositeArgs({'extra_metadata': {'version': 1.0, 'step': 4}, 'state': {'layer0': {'bias': 5, 'weight': 6}}})

There are some scenarios in which the mapping between items and their respective CheckpointHandlers needs to be provided when creating a CheckpointManager instance.

The CheckpointManager constructor argument item_handlers addresses these scenarios. See Using the Refactored CheckpointManager API for details.