Asynchronous Checkpointing

What is this?

Orbax supports async checkpointing. Checkpoints can be saved in a background thread while training proceeds concurrently, so the main thread blocks only for the minimal portion of the save that must happen synchronously.

Why should I care?

Training jobs that would ordinarily spend time blocking while arrays are written to disk, often over slow network connections, can proceed without waiting. This typically results in faster training progress. Furthermore, expensive devices like TPUs or GPUs, which would previously have sat idle for the entire duration of the save, are put to productive use for a higher proportion of the training run.

Because we only need to wait for the blocking portion of the save, checkpointing becomes significantly faster. Consider some concrete numbers:

  • On a 300M parameter model, saving time decreased by ~40%

  • On an 8B parameter model, saving time decreased by ~85%

  • On a 340B parameter model, saving time decreased by ~97%

In short, async checkpointing adoption is highly encouraged. It can result in improved training throughput and substantial resource savings.

Usage

Some setup first:

import numpy as np
import orbax.checkpoint as ocp
from etils import epath

train_state = {
    'layer0': {
        'kernel': np.ones((8, 8), dtype=np.float32),
        'bias': np.ones((8,), dtype=np.float32),
    }
}

Using async checkpointing is quite simple in Orbax. Before, we would do something like this:

### PREFER NOT TO USE THIS. ###
### PREFER TO USE ASYNC CHECKPOINTING INSTEAD (SEE BELOW). ###

path = epath.Path('/tmp/sync_checkpoint')
ckptr = ocp.Checkpointer(ocp.StandardCheckpointHandler())
ckptr.save(path, args=ocp.args.StandardSave(train_state))

Now we can simply use AsyncCheckpointer instead of Checkpointer. Calling save will kick off the checkpoint save in a background thread, and return without waiting for completion. At this point, other work can be performed in the main thread, and wait_until_finished can be called to block until completion. Importantly, the AsyncCheckpointer must remain alive for the duration of the save.

path = epath.Path('/tmp/async_checkpoint')
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(path, args=ocp.args.StandardSave(train_state))
### Do some other work...
ckptr.wait_until_finished()

We can do something similar if we’re using CheckpointManager:

path = epath.Path('/tmp/async_checkpoint_manager')
ckpt_mngr = ocp.CheckpointManager(path)

def train_step(step, state):
  # update state values accordingly
  return step + 1, state

step = 0
num_steps = 5
while step < num_steps:
  ckpt_mngr.save(step, args=ocp.args.StandardSave(train_state))
  step, train_state = train_step(step, train_state)

ckpt_mngr.wait_until_finished()
ckpt_mngr.all_steps()  # returns [0, 1, 2, 3, 4]

Note that calling save when using an AsyncCheckpointer will automatically call wait_until_finished before starting a new save, so that any writes that are still in progress will be completed first.

Async save behavior in CheckpointManager can be switched off by using the following:

ocp.CheckpointManagerOptions(enable_async_checkpointing=False)
CheckpointManagerOptions(save_interval_steps=1, max_to_keep=None, keep_time_interval=None, keep_period=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=False, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None))

Additional Details

From start to finish, async checkpointing for a train state of arrays works by first performing a blocking copy of the arrays from device to host (an array already resident in host memory is still copied). This step is necessary because values cannot be written directly from device to storage. It must also be blocking because if training proceeded on the main thread, updates to the train state would corrupt the checkpoint being written.
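Conceptually, the two phases can be sketched in plain Python (a hypothetical helper for illustration, not Orbax's actual implementation):

```python
import threading
import numpy as np

def sketch_async_save(train_state, write_fn):
  # Phase 1 (blocking): snapshot arrays into host memory so that
  # subsequent training updates cannot corrupt the checkpoint.
  snapshot = {k: np.array(v, copy=True) for k, v in train_state.items()}
  # Phase 2 (background): persist the snapshot concurrently with training.
  thread = threading.Thread(target=write_fn, args=(snapshot,))
  thread.start()
  return thread  # joining this thread corresponds to wait_until_finished

# Usage: training may mutate `state` freely once the save has started.
written = {}
state = {'w': np.zeros(2)}
t = sketch_async_save(state, written.update)
state['w'] += 1.0  # simulated training step on the main thread
t.join()
# `written` holds the pre-update snapshot; `state` has moved on.
```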

Once the copy completes (and any other less significant blocking operations), a series of futures are returned to AsyncCheckpointer by the CheckpointHandler. AsyncCheckpointer then starts a background thread to wait on these futures (which are already running).

The examples shown above work well for PyTrees of jax.Arrays present on TPUs or GPUs. However, Orbax provides a more general API that allows you to save any object asynchronously. In practice, custom async checkpointing logic can be implemented with AsyncCheckpointHandler. Also check out our guide on custom CheckpointHandlers for further details.