Checkpointing in a Training Loop#
This guide covers the usage of the training module, designed around the basic
concept of a training loop.
Note: We use the --xla_force_host_platform_device_count=8 flag to emulate multiple devices in our single-CPU environment.
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from orbax.checkpoint import v1 as ocp
from etils import epath
training = ocp.training
Getting Started#
Let’s dive in with a simple training loop example.
We will use the Checkpointer API provided by the training module. The
Checkpointer must be configured with a root directory, which represents a
working directory where all checkpoints will be saved throughout the course of
an experiment.
The root directory is not itself a checkpoint; rather, it is a container of checkpoints.
root_directory = epath.Path('/tmp/training/my-checkpoints-1')
root_directory.rmtree(missing_ok=True)
We will assume the existence of a training state containing the keys params
and opt_state, which are trees of jax.Array. The state also contains a key
step, which is represented as an integer.
Note that the arrays in the state will be sharded using a fully-replicated sharding, but the example would work equally well with any other sharding.
import jax
import numpy as np
pytree = {
    'params': {
        'layer0': np.arange(16).reshape((8, 2)),
    },
    'opt_state': [np.arange(16)],
}
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)), jax.sharding.PartitionSpec()
)
pytree = jax.tree.map(lambda x: jax.device_put(x, sharding), pytree)
pytree['step'] = 0
Let’s set up our fake training loop. We will add a “training step function” that just increments the step. In reality, this would also compute gradients and update model parameters.
def train_step(state):
    state['step'] += 1
    return state
Now, we can create a Checkpointer to begin saving a sequence of checkpoints.
with training.Checkpointer(root_directory) as ckptr:
    num_steps = 10
    for step in range(num_steps):
        saved = ckptr.save(step, pytree)
        assert saved
        pytree = train_step(pytree)
Calling load with no arguments will automatically restore the latest saved
checkpoint.
with training.Checkpointer(root_directory) as ckptr:
    print(ckptr.load())
{'opt_state': [Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=int32)], 'params': {'layer0': Array([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]], dtype=int32)}, 'step': 9}
Checkpointer APIs#
Now, let’s get into a bit more detail about how to interact with the
Checkpointer.
In general, we recommend using Checkpointer as a context manager, as shown in
the examples below.
with Checkpointer(...) as ckptr:
    ...
You can use it without the context manager, but make sure to call close()
before the program exits to ensure the completion of any outstanding operations
and to ensure resource cleanup.
ckptr = Checkpointer(...)
...
ckptr.close()
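The difference between the two patterns can be sketched in plain Python. `FakeCheckpointer` below is a hypothetical stand-in written only for this illustration; it is not the Orbax class and merely records whether `close()` ran. The sketch suggests why the context-manager form is the safer default: cleanup happens even if the loop body raises, whereas the manual form needs an explicit `try`/`finally` to give the same guarantee.

```python
class FakeCheckpointer:
    """Hypothetical stand-in for illustration; not the Orbax Checkpointer."""

    def __init__(self):
        self.closed = False

    def close(self):
        # A real checkpointer would wait for outstanding operations here.
        self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # Do not swallow exceptions from the body.


# Context-manager form: close() runs even though the body raises.
with_form = FakeCheckpointer()
try:
    with with_form:
        raise RuntimeError('simulated training failure')
except RuntimeError:
    pass

# Manual form: close() must be guarded by try/finally for the same effect.
manual_form = FakeCheckpointer()
try:
    try:
        raise RuntimeError('simulated training failure')
    finally:
        manual_form.close()
except RuntimeError:
    pass

print(with_form.closed, manual_form.closed)  # True True
```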
Saving#
Calling save in the training loop automatically calls should_save, which
determines whether or not a checkpoint should be saved at the given step, based
on the configured saving frequency. If a save is performed, save returns
True; otherwise, it returns False.
Whether or not a save should be performed can be controlled via
SaveDecisionPolicy.
By default, ContinuousCheckpointingPolicy is configured, which always saves
unless a save is already ongoing.
Other pre-configured policies include:
- FixedIntervalPolicy: Saves every n steps.
- InitialSavePolicy: Saves on the first step.
- PreemptionCheckpointingPolicy: Saves on a step where a preemption signal is received by the JAX distributed system. This is useful for saving whenever a job is automatically restarted by the system.
- SpecificStepsPolicy: Saves on the specific set of configured steps.
The policies can be used in conjunction via AnySavePolicy, which performs a
save if any of the sub-policies would perform a save at the given step.
You may always implement your own policy. See SaveDecisionPolicy for details.
root_directory = epath.Path('/tmp/training/my-checkpoints-2')
root_directory.rmtree(missing_ok=True)
with training.Checkpointer(
    root_directory,
    save_decision_policy=training.save_decision_policies.FixedIntervalPolicy(3),
) as ckptr:
    for step in range(10):
        ckptr.save(step, pytree)
!ls {root_directory}
0 3 6 9
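The listing above is consistent with a rule that saves whenever the step is a multiple of the interval. The following plain-Python sketch models policies as simple step-to-bool functions to show how a fixed interval and an "any of" combination would select steps. This is a simplified model, not Orbax's implementation: the helper names `fixed_interval`, `specific_steps`, and `any_of` are hypothetical, and real policies are objects that may consider more than the step number.

```python
from typing import Callable

# Assumption for this sketch: a policy is just a function from step to bool.
Policy = Callable[[int], bool]


def fixed_interval(n: int) -> Policy:
    """Saves when the step is a multiple of n."""
    return lambda step: step % n == 0


def specific_steps(steps: set[int]) -> Policy:
    """Saves only on the listed steps."""
    return lambda step: step in steps


def any_of(*policies: Policy) -> Policy:
    """Saves if at least one sub-policy would save."""
    return lambda step: any(policy(step) for policy in policies)


every_3 = fixed_interval(3)
saved = [step for step in range(10) if every_3(step)]
print(saved)  # [0, 3, 6, 9]

combined = any_of(fixed_interval(3), specific_steps({7}))
saved_combined = [step for step in range(10) if combined(step)]
print(saved_combined)  # [0, 3, 6, 7, 9]
```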
Now let’s exercise some additional save features. These include:
- custom_metadata: A JSON-formatted object intended for storing any user-specified properties. Custom metadata can be specified at both the root directory level and the individual checkpoint level. At the root level, the metadata should pertain to all checkpoints. For example, the experiment name is shared by all checkpoints within the root directory, while a property like is_final has different values for different checkpoints.
- override: Deletes and overwrites any existing checkpoint at the provided step.
- force: Performs a save at the current step regardless of what would ordinarily be dictated by the SaveDecisionPolicy.
- metrics: A JSON-formatted object storing evaluation metrics for the current step. This can be useful for ordering and garbage collecting checkpoints; more on that below.
root_directory = epath.Path('/tmp/training/my-checkpoints-3')
root_directory.rmtree(missing_ok=True)
with training.Checkpointer(
    root_directory,
    save_decision_policy=training.save_decision_policies.FixedIntervalPolicy(3),
    custom_metadata={'experiment_name': 'my-experiment'},
) as ckptr:
    num_steps = 10
    for step in range(num_steps):
        is_final = step == num_steps - 1
        ckptr.save(
            step,
            pytree,
            metrics={'accuracy': 0.85},
            custom_metadata={'is_final': is_final},
            force=is_final,
        )
!ls {root_directory}
0 3 6 9 metadata
We will learn more about how to access some of the attributes that we saved in the sections below.
Querying Available Checkpoints#
We can learn about which checkpoints are available by using latest and
checkpoints.
ckptr = training.Checkpointer(root_directory)
Each of these APIs returns CheckpointMetadata objects, which store a number of
properties describing each checkpoint. Some metadata properties are more
expensive to retrieve than others though. The latest and checkpoints APIs
just store a limited set of cheaply-retrievable properties, like the step.
These APIs also make use of caching as much as possible, to avoid repeated disk
reads.
# Returns CheckpointMetadata or None, if no checkpoints are found.
latest = ckptr.latest
assert latest is not None
print(latest.step)
print(latest)
9
CheckpointMetadata({'commit_timestamp_nsecs': '1779312267707646976',
'custom_metadata': 'None',
'init_timestamp_nsecs': 'None',
'metadata': 'None',
'metrics': "{'accuracy': 0.85}",
'step': '9'})
Inspecting Checkpoint Metadata#
In many cases, we wish to cheaply gain information about checkpoint properties
without loading the entire model. Using the metadata API, we can learn
about the tree structure of our PyTree, as well as information about each array
in the tree.
Like loading methods, metadata methods accept either no argument, or an argument representing the step to retrieve metadata for.
For example:
# Loads metadata from the latest checkpoint.
ckptr.metadata()
# Loads metadata corresponding to the first step.
ckptr.metadata(ckptr.checkpoints[0])
# Loads metadata from a specific integer step.
ckptr.metadata(3)
print()
Let’s examine the output.
ckptr.metadata()
CheckpointMetadata({'commit_timestamp_nsecs': '1779312267707647245',
'custom_metadata': "{'is_final': True}",
'init_timestamp_nsecs': '1779312267666798405',
'metadata': "{'opt_state': [ArrayMetadata(shape=(16,), dtype=dtype('int32'), "
'sharding_metadata=NamedShardingMetadata(shape=[8], '
"axis_names=['x'], axis_types=(Auto,), partition_spec=()) "
'device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0), '
'DeviceMetadata(id=1), DeviceMetadata(id=2), '
'DeviceMetadata(id=3), DeviceMetadata(id=4), '
'DeviceMetadata(id=5), DeviceMetadata(id=6), '
'DeviceMetadata(id=7)]), '
'storage_metadata=StorageMetadata(chunk_shape=(2,), '
"write_shape=(2,)))], 'params': {'layer0': "
"ArrayMetadata(shape=(8, 2), dtype=dtype('int32'), "
'sharding_metadata=NamedShardingMetadata(shape=[8], '
"axis_names=['x'], axis_types=(Auto,), partition_spec=()) "
'device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0), '
'DeviceMetadata(id=1), DeviceMetadata(id=2), '
'DeviceMetadata(id=3), DeviceMetadata(id=4), '
'DeviceMetadata(id=5), DeviceMetadata(id=6), '
'DeviceMetadata(id=7)]), '
'storage_metadata=StorageMetadata(chunk_shape=(1, 2), '
"write_shape=(1, 2)))}, 'step': 0}",
'metrics': "{'accuracy': 0.85}",
'step': '9'})
Let’s dig into a few specific fields. In particular, we can access
custom_metadata and metrics that were saved previously.
print(ckptr.metadata().metrics)
print(ckptr.metadata().custom_metadata)
{'accuracy': 0.85}
{'is_final': True}
Within the metadata object, there is another field called metadata. This
stores information specific to the structure of the object we saved. In this
case, it describes the structure of the PyTree and array properties.
import pprint
pprint.pprint(ckptr.metadata().metadata)
{'opt_state': [ArrayMetadata(shape=(16,),
dtype=dtype('int32'),
sharding_metadata=NamedShardingMetadata(shape=[8], axis_names=['x'], axis_types=(Auto,), partition_spec=()) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0), DeviceMetadata(id=1), DeviceMetadata(id=2), DeviceMetadata(id=3), DeviceMetadata(id=4), DeviceMetadata(id=5), DeviceMetadata(id=6), DeviceMetadata(id=7)]),
storage_metadata=StorageMetadata(chunk_shape=(2,),
write_shape=(2,)))],
'params': {'layer0': ArrayMetadata(shape=(8, 2),
dtype=dtype('int32'),
sharding_metadata=NamedShardingMetadata(shape=[8], axis_names=['x'], axis_types=(Auto,), partition_spec=()) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0), DeviceMetadata(id=1), DeviceMetadata(id=2), DeviceMetadata(id=3), DeviceMetadata(id=4), DeviceMetadata(id=5), DeviceMetadata(id=6), DeviceMetadata(id=7)]),
storage_metadata=StorageMetadata(chunk_shape=(1,
2),
write_shape=(1,
2)))},
'step': 0}
Finally, we can also retrieve the root-level metadata. Recall that this metadata is intended to describe the entire sequence of checkpoints, rather than just a single checkpoint.
ckptr.root_metadata()
RootMetadata({'custom_metadata': "{'experiment_name': 'my-experiment'}",
'directory': '/tmp/training/my-checkpoints-3'})
Garbage Collection#
Garbage collection is important to avoid accumulating too many old checkpoints and running out of disk space.
To control this behavior, we use PreservationPolicy, an object fairly similar to
the SaveDecisionPolicy described above. This class tells the
Checkpointer which checkpoints should be protected from garbage collection.
The default PreservationPolicy is PreserveAll (no garbage
collection), because we do not want users to lose any valuable data. However,
for anything other than toy use cases,
you should make sure to configure a more restrictive PreservationPolicy.
Our Checkpointer below is implicitly configured with PreserveAll, so all 10
steps should be present at first.
root_directory = epath.Path('/tmp/training/my-checkpoints-gc')
root_directory.rmtree(missing_ok=True)
with training.Checkpointer(root_directory) as ckptr:
    for step in range(10):
        ckptr.save(step, pytree)
    print([c.step for c in ckptr.checkpoints])
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
If we create a new Checkpointer with a new PreservationPolicy configured,
the same 10 checkpoints should still be present. Once we
save a new step, any steps indicated for cleanup by the policy will be removed.
with training.Checkpointer(
    root_directory,
    preservation_policy=training.preservation_policies.AnyPreservationPolicy([
        training.preservation_policies.LatestN(2),
        training.preservation_policies.EveryNSteps(4),
    ]),
) as ckptr:
    print([c.step for c in ckptr.checkpoints])
    assert ckptr.latest.step == 9
    ckptr.save(10, pytree)
    print([c.step for c in ckptr.checkpoints])
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[0, 4, 8, 9, 10]
Typically, the latest n checkpoints are preserved (LatestN) along with
checkpoints at some regular, but longer interval (EveryNSteps or
EveryNSeconds). The latter can be useful for performing evals and maintaining
a record of the experiment’s progress.
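To see why steps 0, 4, 8, 9, and 10 survive in the output above, the selection can be modeled in plain Python as the union of "the two most recent steps" and "every step divisible by 4". This is a simplified model that reproduces the printed result; the helpers `latest_n` and `every_n_steps` are hypothetical, and the library's actual rules (for example, how the interval is anchored) may differ in detail.

```python
def latest_n(steps, n):
    """Keeps the n most recent steps."""
    return set(sorted(steps)[-n:])


def every_n_steps(steps, n):
    """Keeps steps that are multiples of n (an assumption of this sketch)."""
    return {step for step in steps if step % n == 0}


# Steps 0..9 already existed; step 10 has just been saved.
existing = list(range(11))

# A step is preserved if any sub-policy wants to keep it.
preserved = sorted(latest_n(existing, 2) | every_n_steps(existing, 4))
print(preserved)  # [0, 4, 8, 9, 10]
```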
Loading#
As we saw above with the metadata methods, we can load in a variety of ways.
# Loads from the latest checkpoint.
ckptr.load()
# Loads the first available checkpoint in the root directory.
ckptr.load(ckptr.checkpoints[0])
# Loads from a specific integer step.
ckptr.load(4)
print()
When dealing with PyTrees, particularly PyTrees with sharded jax.Array leaves,
it is important for any non-toy use cases to specify an “abstract PyTree” that
is used to guide restoration. Checkpoints are complicated objects. The abstract
PyTree acts as an assertion to verify that the checkpoint has the structure you
expect and that arrays have the correct shapes.
The abstract PyTree can also be used to instruct Orbax how to load the PyTree.
The dtype property may be used to cast arrays, while the sharding property
is used to correctly place array shards on devices.
We should define an abstract tree with the same structure as the tree we originally saved. For the leaves, we specify different shardings than we originally saved with, and different dtypes as well, causing the loaded arrays to be cast and resharded when loading.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)), jax.sharding.PartitionSpec('x')
)
abstract_state = {
    'params': {
        'layer0': jax.ShapeDtypeStruct((8, 2), np.float32, sharding=sharding),
    },
    'opt_state': [jax.ShapeDtypeStruct((16,), np.float32, sharding=sharding)],
    'step': 0,
}
ckptr.load(None, abstract_state)
{'opt_state': [Array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.,
13., 14., 15.], dtype=float32)],
'params': {'layer0': Array([[ 0., 1.],
[ 2., 3.],
[ 4., 5.],
[ 6., 7.],
[ 8., 9.],
[10., 11.],
[12., 13.],
[14., 15.]], dtype=float32)},
'step': 10}
More details on working with PyTrees in such a manner can be found at Checkpointing PyTrees.
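The "assertion" role of the abstract PyTree can be illustrated with a plain-Python sketch that compares an expected structure against what was found on disk. This is only a conceptual model under simplifying assumptions: leaves are represented as shape tuples rather than arrays, and `check_structure` is a hypothetical helper, not an Orbax function.

```python
def check_structure(abstract, restored, path='root'):
    """Recursively compares container types, keys, lengths, and leaf shapes."""
    if isinstance(abstract, dict):
        if not isinstance(restored, dict) or abstract.keys() != restored.keys():
            raise ValueError(f'Key mismatch at {path}')
        for key in abstract:
            check_structure(abstract[key], restored[key], f'{path}/{key}')
    elif isinstance(abstract, list):
        if not isinstance(restored, list) or len(abstract) != len(restored):
            raise ValueError(f'Length mismatch at {path}')
        for index, (a, r) in enumerate(zip(abstract, restored)):
            check_structure(a, r, f'{path}/{index}')
    elif abstract != restored:
        # Leaves are shape tuples in this sketch.
        raise ValueError(
            f'Shape mismatch at {path}: expected {abstract}, got {restored}'
        )


expected = {'params': {'layer0': (8, 2)}, 'opt_state': [(16,)]}
matching = {'params': {'layer0': (8, 2)}, 'opt_state': [(16,)]}
check_structure(expected, matching)  # Passes silently.

mismatched = {'params': {'layer0': (4, 4)}, 'opt_state': [(16,)]}
error_message = None
try:
    check_structure(expected, mismatched)
except ValueError as e:
    error_message = str(e)
print(error_message)
# Shape mismatch at root/params/layer0: expected (8, 2), got (4, 4)
```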
Checkpointables and Dataset Checkpointing#
Checkpointer supports the concept of checkpointables. See the documentation
on “Working with Checkpointables” for more information.
In simplified terms, a “checkpointable” refers to a distinct piece of the
overall checkpoint, which can be thought of as a bundle. The PyTree training
state is one such checkpointable. The dataset iterator is another. Checkpointing
the position of the dataset iterator can be useful to ensure training resumes
where we were interrupted not just for the model parameters, but for the data as
well.
We can see this concept in concrete terms using a Grain dataset iterator. See Grain documentation for more information. For our purposes, we can construct a toy dataset iterator.
import grain
dataset = iter(grain.MapDataset.range(30).batch(3).map(lambda x: x.tolist()))
pytree = {
    'params': {
        'layer0': np.arange(16).reshape((8, 2)),
    },
    'opt_state': [np.arange(16)],
}
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)), jax.sharding.PartitionSpec()
)
pytree = jax.tree.map(lambda x: jax.device_put(x, sharding), pytree)
pytree['step'] = 0
def train_step(state, ds):
    next(ds)  # Advances the dataset iterator
    state['step'] += 1
    return state
We can save ten checkpoints in sequence, including the dataset iterator,
advancing the iterator once per step. At each step, the dataset iterator points
to [step*3, step*3+1, step*3+2].
root_directory = epath.Path('/tmp/training/my-checkpoints-4')
root_directory.rmtree(missing_ok=True)
num_steps = 10
with training.Checkpointer(root_directory) as ckptr:
    for step in range(num_steps):
        ckptr.save_checkpointables(step, dict(pytree=pytree, dataset=dataset))
        pytree = train_step(pytree, dataset)
After loading at step 5, new_dataset points to position 5 of the iterator.
new_dataset = iter(
    grain.MapDataset.range(30).batch(3).map(lambda x: x.tolist())
)
print(f'Initial position: {next(new_dataset)}')
with training.Checkpointer(root_directory) as ckptr:
    ckptr.load_checkpointables(5, dict(pytree=None, dataset=new_dataset))
print(f'Loaded from checkpoint: {next(new_dataset)}')
Initial position: [0, 1, 2]
Loaded from checkpoint: [15, 16, 17]
It’s important to note that dataset loading is stateful. You need to instantiate
an iterator object, pass it to load_checkpointables, and the checkpoint state
will be restored into the iterator state of the dataset object.
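The stateful nature of this restore can be mimicked with a small plain-Python iterator. `ToyBatchIterator` and its `get_state`/`set_state` methods are hypothetical, written only for this sketch (they are not Grain or Orbax classes). The point it illustrates is that restoring mutates an existing iterator object in place rather than returning a new one.

```python
class ToyBatchIterator:
    """Hypothetical batching iterator over range(total); not a Grain class."""

    def __init__(self, total=30, batch_size=3):
        self._total = total
        self._batch_size = batch_size
        self._position = 0  # Index of the next batch to yield.

    def __iter__(self):
        return self

    def __next__(self):
        start = self._position * self._batch_size
        if start >= self._total:
            raise StopIteration
        self._position += 1
        return list(range(start, min(start + self._batch_size, self._total)))

    def get_state(self):
        return {'position': self._position}

    def set_state(self, state):
        self._position = state['position']


# Consume five batches, then capture the state, as a save at step 5 would.
original = ToyBatchIterator()
for _ in range(5):
    next(original)
saved_state = original.get_state()

# A fresh iterator starts at the beginning until the state is restored into it.
fresh = ToyBatchIterator()
first_batch = next(fresh)
fresh.set_state(saved_state)
resumed_batch = next(fresh)
print(first_batch, resumed_batch)  # [0, 1, 2] [15, 16, 17]
```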
Training with MNIST Data#
In the previous examples, we used toy data to demonstrate Orbax functionality. Here, we demonstrate checkpointing in a training loop that uses real-world training data (though the example is still simplified for demonstration purposes).
We’ll define our own loss function and simulate an MNIST model training loop. These functions are pulled from the Flax docs.
import flax
from flax import nnx
import jax
import optax
from orbax.checkpoint.experimental.v1._src.training.model_helpers import DotReluDot # Don't rely on this module!
flax.config.update('flax_always_shard_variable', False)
def loss_fn(model: DotReluDot, batch: dict) -> tuple[jax.Array, jax.Array]:
    logits = model(batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']
    ).mean()
    return loss, logits
@nnx.jit
def train_step(
    model: DotReluDot,
    optimizer: nnx.ModelAndOptimizer,
    batch: dict,
):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    optimizer.update(grads)
It’s worth mentioning that DotReluDot is an example layer that subclasses nnx.Module and can incorporate sharding information to accelerate the training process. Let’s recreate a sharded environment for the sake of demonstration.
from jax.sharding import Mesh
import numpy as np
print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
mesh = Mesh(
    devices=np.array(jax.devices()).reshape(4, 2),  # Can be customized to run across multiple devices
    axis_names=('data', 'model'),
)
print(mesh)
You have 8 “fake” JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
Mesh('data': 4, 'model': 2, axis_types=(Auto, Auto))
Let’s go ahead and import the MNIST dataset using the Grain library:
%%capture
from orbax.checkpoint.experimental.v1._src.training.model_helpers import create_dataset
batch_size = 32
train_ds = create_dataset('train', batch_size)
test_ds = create_dataset('test', batch_size)
Before we start the training loop, let’s initialize an abstract version of our checkpoint state, without initializing any real array values. We do this for both the model and optimizer.
# Create an abstract model
abs_model = nnx.eval_shape(lambda: DotReluDot(1024, rngs=nnx.Rngs(0)))
abs_model_state = nnx.state(abs_model)
abs_model_state = jax.tree.map(
    lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
    abs_model_state,
    nnx.get_named_sharding(abs_model_state, mesh),
)
# Create an abstract optimizer
abs_model_tmp = DotReluDot(1024, rngs=nnx.Rngs(0))
abs_optimizer = nnx.eval_shape(
    lambda: nnx.ModelAndOptimizer(abs_model_tmp, optax.adamw(0.005, 0.9))
)
# Store the abstract model and optimizer as one object
abs_state = {
    'params': abs_model_state,
    'optimizer': nnx.state(abs_optimizer, nnx.optimizer.OptState),
}
Now, we can define our main train() function, throughout which we will demonstrate checkpointing. A couple notes:
- We use FixedIntervalPolicy so that our checkpoint is saved every 10 training steps.
- We use nnx.state() to convert the model object (DotReluDot) and optimizer to a checkpointable PyTree, which can then be checkpointed with ckptr.save().
When actually loading a checkpoint, we do the following:
If a checkpoint exists in our current checkpoints directory, we restore the latest one.
If no checkpoints exist in our directory and the user provides a path to another directory, we load the checkpoint saved at that path.
If no checkpoints have been saved, this indicates we’re entering the training loop for the first time, so we don’t restore a model or optimizer.
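The order of precedence in the three cases above can be written down as a small decision function. The sketch below is plain Python with a hypothetical helper, `choose_restore_source`, that takes the latest step found in the root directory (or `None`) and an optional external path. It only models the branching, not the actual loading.

```python
def choose_restore_source(latest_step, ckpt_path):
    """Returns which source would be used to initialize training state."""
    if latest_step is not None:
        # A checkpoint in the working directory always wins.
        return 'latest'
    if ckpt_path is not None:
        # Fall back to the externally provided checkpoint.
        return 'path'
    # Nothing to restore: first entry into the training loop.
    return 'fresh'


print(choose_restore_source(50, None))             # latest
print(choose_restore_source(80, '/some/path'))     # latest
print(choose_restore_source(None, '/some/path'))   # path
print(choose_restore_source(None, None))           # fresh
```

Note that when both sources are available, the externally provided path is ignored in favor of the latest checkpoint in the working directory.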
from typing import Callable
from orbax.checkpoint import v1 as ocp
training = ocp.training
root_directory = epath.Path('/tmp/training/my-checkpoints-5')
root_directory.rmtree(missing_ok=True)
train_steps = 100 # Limited number of training steps for speed.
learning_rate = 0.005
momentum = 0.9
model_depth = 1024
save_interval = 10
def init_or_restore(
    ckptr: training.Checkpointer, abs_state: dict, ckpt_path: str | None
) -> tuple[DotReluDot, nnx.ModelAndOptimizer, int]:
    model = DotReluDot(model_depth, rngs=nnx.Rngs(0))
    optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate, momentum))
    if ckpt_path or ckptr.latest:
        # If a checkpoint already exists, we restore it.
        if ckptr.latest:
            loaded_state = ckptr.load(abstract_state=abs_state)
        else:
            loaded_state = ocp.load(path=ckpt_path, abstract_state=abs_state)
        # Update model and optimizer separately
        nnx.update(model, loaded_state['params'])
        nnx.update(optimizer, loaded_state['optimizer'])
        last_step = loaded_state['optimizer']['step'].value
    else:
        last_step = 0
    return model, optimizer, last_step
def train(
    stop_fn: Callable[[int], bool] | None = None, ckpt_path: str | None = None
):
    # If a checkpoint exists in root_directory, restore the latest one.
    # Otherwise, load from ckpt_path if provided; if neither exists, start from
    # scratch.
    with training.Checkpointer(
        root_directory,
        save_decision_policy=training.save_decision_policies.FixedIntervalPolicy(
            save_interval
        ),
    ) as ckptr:
        model, optimizer, last_step = init_or_restore(ckptr, abs_state, ckpt_path)
        # Main training loop
        with mesh:
            for step, batch in enumerate(train_ds, start=last_step + 1):
                if step >= train_steps or (stop_fn and stop_fn(step)):
                    break
                train_step(model, optimizer, batch)
                # Save the combined state
                state = {
                    'params': nnx.state(model),
                    'optimizer': nnx.state(optimizer, nnx.optimizer.OptState),
                }
                ckptr.save(step, state)
We invoke the train() function. In doing so, we demonstrate a “failure” causing our main training loop to stop at step 52.
def simulated_failure(step):
    return step == 52
train(stop_fn=simulated_failure)
latest = training.Checkpointer(root_directory).latest
assert latest is not None and latest.step == 50
!ls {root_directory}
10 20 30 40 50
We want to recover from our last saved checkpoint (step 50), which is restored automatically. We now continue training until step 80, which we can consider a “planned” exit rather than a failure.
%%capture
train(stop_fn = lambda x: x > 80)
From our latest checkpoint (step 80), let’s assume we want to fine-tune the model. We restore from our latest checkpoint, and continue training until we reach the maximum number of steps.
latest = training.Checkpointer(root_directory).latest
assert latest is not None and latest.step == 80
assert (root_directory / '80').exists()
train(ckpt_path=root_directory / '80')
!ls {root_directory}
10 20 30 40 50 60 70 80 90
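The directory listings from the three runs can be reproduced with a plain-Python simulation of the loop's control flow. This sketch makes several assumptions: saves happen exactly when the step is a multiple of the interval, the dataset never runs out before the stop condition, and each run resumes from the most recent saved step. `simulate_run` is a hypothetical helper and performs no real training or checkpointing.

```python
def simulate_run(last_step, stop_fn=None, train_steps=100, save_interval=10):
    """Returns the steps at which one run of the loop would save."""
    saved = []
    step = last_step + 1
    while True:
        if step >= train_steps or (stop_fn and stop_fn(step)):
            break
        # The real loop would run a training step here.
        if step % save_interval == 0:
            saved.append(step)
        step += 1
    return saved


on_disk = []

# Run 1: simulated failure at step 52, starting from scratch.
on_disk += simulate_run(0, stop_fn=lambda s: s == 52)
after_failure = list(on_disk)

# Run 2: resume from the latest save and stop once the step exceeds 80.
on_disk += simulate_run(max(on_disk), stop_fn=lambda s: s > 80)
after_planned_exit = list(on_disk)

# Run 3: resume again and continue until train_steps is reached.
on_disk += simulate_run(max(on_disk))

print(after_failure)       # [10, 20, 30, 40, 50]
print(after_planned_exit)  # [10, 20, 30, 40, 50, 60, 70, 80]
print(on_disk)             # [10, 20, 30, 40, 50, 60, 70, 80, 90]
```

Step 100 is never saved in this model because the loop exits as soon as the step reaches `train_steps`, which matches the final listing ending at 90.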