Asynchronous Checkpointing

What is this?

Orbax supports async checkpointing: checkpoints are saved in a background thread while training proceeds concurrently, so training only has to wait for a short blocking portion of the save.

Why should I care?

Training jobs that would otherwise block while arrays are written to storage, often over slow network connections, can instead proceed without waiting. This typically results in faster training progress. Furthermore, expensive accelerators such as TPUs or GPUs, which would previously have sat idle for the entire duration of the save, spend a larger fraction of the training run doing productive work.

Because training only needs to wait for the blocking portion of the save, the time spent checkpointing drops significantly. Consider some concrete numbers:

  • On a 300M parameter model, saving time decreased by ~40%

  • On an 8B parameter model, saving time decreased by ~85%

  • On a 340B parameter model, saving time decreased by ~97%
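A percentage reduction can be easier to compare as a speedup factor: a reduction of r leaves (1 - r) of the original saving time, which corresponds to a 1 / (1 - r) speedup. The short calculation below applies this to the approximate figures listed above; since those figures are rounded, the resulting factors are rough.

```python
# Approximate reductions in saving time, taken from the list above.
reductions = {'300M': 0.40, '8B': 0.85, '340B': 0.97}

def speedup(reduction):
    """A time reduction r leaves (1 - r) of the original time: a 1 / (1 - r) speedup."""
    return 1.0 / (1.0 - reduction)

for model, r in reductions.items():
    print(f'{model}: ~{speedup(r):.1f}x faster')
# 300M: ~1.7x faster
# 8B: ~6.7x faster
# 340B: ~33.3x faster
```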

In short, async checkpointing adoption is highly encouraged. It can result in improved training throughput and substantial resource savings.

Usage

Some setup first:

from etils import epath
import numpy as np
import orbax.checkpoint.experimental.v1 as ocp

root_dir = epath.Path('/tmp/async_checkpointing')

train_state = {
    'layer0': {
        'kernel': np.ones((8, 8), dtype=np.float32),
        'bias': np.ones((8,), dtype=np.float32),
    }
}

Using async checkpointing in Orbax is simple. For comparison, a blocking save looks like this:

### PREFER NOT TO USE THIS. ###
### PREFER TO USE ASYNC CHECKPOINTING INSTEAD (SEE BELOW). ###

path = root_dir / 'sync'
path.rmtree(missing_ok=True)

ocp.save(path, train_state)
!ls /tmp/async_checkpointing/sync
_CHECKPOINT_METADATA  _METADATA  d  manifest.ocdbt  ocdbt.process_0

For an async save, use save_async(...) instead of save(...). Calling it starts the checkpoint save in a background thread and returns a response object without waiting for the save to complete. The main thread is then free to do other work, and response.result() can be called to block until the save completes.

path = root_dir / 'async'
path.rmtree(missing_ok=True)

response = ocp.save_async(path, train_state)
### Do some other work...
response.result()
!ls /tmp/async_checkpointing/async
_CHECKPOINT_METADATA  _METADATA  d  manifest.ocdbt  ocdbt.process_0
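In a training loop, the response returned by an async save lets training continue between checkpoints. The sketch below is illustrative only: to stay runnable without Orbax it uses a hypothetical stand-in, fake_save_async, built on concurrent.futures, whose Future plays the role of the response object; with Orbax you would call ocp.save_async(path, train_state) and call result() on the returned response at the same points. Waiting on the previous save before starting the next one, so that at most one save is in flight, is a design choice assumed here for simplicity, not documented Orbax behavior.

```python
import concurrent.futures
import copy
import time

_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
saved = {}  # hypothetical "storage": step -> saved snapshot of the state

def fake_save_async(step, state):
    """Hypothetical stand-in for ocp.save_async (not the Orbax implementation)."""
    snapshot = copy.deepcopy(state)  # blocking part: snapshot state before returning

    def _write():
        time.sleep(0.01)  # simulate a slow write to storage
        saved[step] = snapshot

    return _executor.submit(_write)  # Future.result() blocks, like response.result()

state = {'w': 0.0}
pending = None
for step in range(1, 6):
    state['w'] += 1.0  # one "training step" updates the state
    if step % 2 == 0:  # checkpoint every 2 steps
        if pending is not None:
            pending.result()  # wait for the previous save before starting a new one
        pending = fake_save_async(step, state)

if pending is not None:
    pending.result()  # make sure the final save finishes before exiting

print(saved)  # {2: {'w': 2.0}, 4: {'w': 4.0}}
```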

To save multiple checkpointables together, Orbax provides free functions in both blocking and async flavors: save_checkpointables(...) and save_checkpointables_async(...).

The same applies to the training.Checkpointer class:

  • training.Checkpointer.save(...)

  • training.Checkpointer.save_async(...)

  • training.Checkpointer.save_checkpointables(...)

  • training.Checkpointer.save_checkpointables_async(...)

Additional Details

From start to finish, async checkpointing for a train state of arrays works as follows. First, a blocking copy of the arrays is made from device to host. (If an array is already in host memory, a copy is still created.) This step is necessary because the values cannot be written directly from device to storage. It also has to be blocking: if training were allowed to proceed on the main thread before the copy finished, updates to the train state could corrupt the checkpoint. Once the copy is complete, the host copies are written to storage in a background thread while training proceeds.
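The reason for the blocking copy can be demonstrated with a toy two-phase save. This is a minimal sketch in plain Python, not Orbax's implementation: copy.deepcopy stands in for the device-to-host copy, a dictionary stands in for storage, and a threading.Event forces the background write to land after training has already mutated the state. The helper names save_two_phase and run are hypothetical. With copy_first=False the background thread only holds a reference to the live state, so the "checkpoint" picks up the later update.

```python
import copy
import threading
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def save_two_phase(state, storage, release, copy_first=True):
    """Hypothetical two-phase save: blocking copy, then background write."""
    # Phase 1 (blocking): take a private copy of the state, standing in for the
    # device-to-host copy. With copy_first=False we keep a reference instead.
    snapshot = copy.deepcopy(state) if copy_first else state

    def _write():
        # Phase 2 (background): wait for `release` to simulate slow storage,
        # then "write" whatever `snapshot` contains at that moment.
        release.wait()
        storage['ckpt'] = copy.deepcopy(snapshot)

    return executor.submit(_write)

def run(copy_first):
    state = {'kernel': [1.0, 1.0]}
    storage = {}
    release = threading.Event()
    future = save_two_phase(state, storage, release, copy_first)
    state['kernel'][0] = 99.0  # training updates the state before the write lands
    release.set()              # let the background write proceed
    future.result()            # block until the background write finishes
    return storage['ckpt']

print(run(copy_first=True))   # {'kernel': [1.0, 1.0]}   matches the state at save time
print(run(copy_first=False))  # {'kernel': [99.0, 1.0]}  corrupted by the later update
```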

The examples above work well for PyTrees of jax.Arrays on TPUs or GPUs. However, Orbax also provides a more general API that lets you save any object asynchronously: custom async checkpointing logic can be implemented with CheckpointableHandler.
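The general idea behind custom async save logic, doing the blocking work first and handing back an awaitable for the slow writes, can be sketched with asyncio. The class below, HypotheticalJsonHandler, is an illustrative assumption and is NOT the CheckpointableHandler protocol; consult the Orbax API reference for the actual method names and signatures. Here serializing to a JSON string is the blocking snapshot, and the file write runs afterwards in a worker thread via asyncio.to_thread.

```python
import asyncio
import json
import pathlib
import tempfile

class HypotheticalJsonHandler:
    """Hypothetical handler for illustration; NOT Orbax's CheckpointableHandler."""

    async def save(self, directory: pathlib.Path, obj: dict):
        # Blocking phase: snapshot the object so later mutations cannot leak in.
        payload = json.dumps(obj)

        async def _background_write():
            # Background phase: run the file write in a worker thread.
            await asyncio.to_thread((directory / 'data.json').write_text, payload)

        # Return a task for the background work instead of awaiting it here.
        return asyncio.create_task(_background_write())

    async def load(self, directory: pathlib.Path) -> dict:
        return json.loads((directory / 'data.json').read_text())

async def main(directory: pathlib.Path) -> dict:
    handler = HypotheticalJsonHandler()
    obj = {'step': 1}
    pending = await handler.save(directory, obj)  # returns after the blocking phase
    obj['step'] = 2                               # safe: the snapshot was already taken
    await pending                                 # wait for the background write
    return await handler.load(directory)

with tempfile.TemporaryDirectory() as d:
    restored = asyncio.run(main(pathlib.Path(d)))
print(restored)  # {'step': 1}
```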