Preemption Tolerance

Orbax provides features that allow users to recover quickly from an interruption or preemption of their main training job.

The first of these features is known variously as preemption checkpointing, on-demand checkpointing, or auto-checkpointing. When the training job receives a preemption signal, a checkpoint can automatically be saved.

The main advantage of this feature is that it shortens training time and wastes fewer resources when preemptions occur, since training can resume from the step at which the preemption happened rather than from the last periodic checkpoint.

Orbax takes advantage of JAX's multihost_utils to detect preemptions.
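For intuition, here is a minimal sketch of what such a check can look like at the JAX level, assuming a multi-host runtime has been initialized (e.g. via jax.distributed.initialize()); the helper name should_save_now is hypothetical, not part of Orbax:

from jax.experimental import multihost_utils

# Hypothetical helper, not an Orbax API. reached_preemption_sync_point
# acts as a synchronization point: every host receives the same answer,
# so all hosts save (or skip) the checkpoint together.
def should_save_now(step: int) -> bool:
  return multihost_utils.reached_preemption_sync_point(step)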

The feature is enabled by default for users of CheckpointManager. Here is an example:

import orbax.checkpoint as ocp

# Use the current CheckpointManager API: options are passed by keyword,
# and items are described via ocp.args at save/restore time.
mngr = ocp.CheckpointManager(
    '/tmp/mydir/',
    options=ocp.CheckpointManagerOptions(save_interval_steps=4),
)

def train_step(s):
  return s

state = {'a': 1, 'b': 2}
start_step = 0
num_steps = 12
# Resume from the most recent checkpoint, if one exists.
if mngr.latest_step() is not None:
  start_step = mngr.latest_step()
  state = mngr.restore(start_step, args=ocp.args.StandardRestore(state))

for step in range(start_step, num_steps):
  state = train_step(state)
  mngr.save(step, args=ocp.args.StandardSave(state))

Checkpoints are saved at steps 0, 4, and 8. If, for example, a preemption had occurred at step 6, a checkpoint would be saved even though this step does not align with the normal save interval. When restarting the program, latest_step would be 6, and training could resume from that point without needing to go all the way back to step 4.
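To make this concrete, here is what the manager would report after such a restart (the return values shown assume the preemption-at-step-6 scenario described above):

mngr.all_steps()    # [0, 4, 6] -- step 6 is the on-demand checkpoint
mngr.latest_step()  # 6 -- training resumes here instead of at step 4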

To further save resources, we can exit immediately after the checkpoint finishes saving. This can avoid several minutes of wasted work when there is a substantial grace period between receipt of the preemption signal and forced termination of the program.

This can be accomplished with small modifications, depicted below. Importantly, if we are at a preemption step, we must wait for the checkpoint to finish writing before exiting. The specific details of the exit function depend on the system used to run the training job.

for step in range(start_step, num_steps):
  state = train_step(state)
  mngr.save(step, args=ocp.args.StandardSave(state))
  if mngr.reached_preemption(step):
    # Block until the (possibly asynchronous) save completes before exiting.
    mngr.wait_until_finished()
    exit()
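As one possibility, here is a sketch of what the exit function might look like for a multi-host JAX job; graceful_exit is a hypothetical helper, and the synchronization barrier is an assumption about what a particular cluster setup might need, not something Orbax requires:

import sys

from jax.experimental import multihost_utils

# Hypothetical exit helper for a multi-host job: the barrier keeps any
# host from exiting while others are still finishing the step, then each
# process terminates cleanly.
def graceful_exit():
  multihost_utils.sync_global_devices('exit_after_preemption_checkpoint')
  sys.exit(0)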