Orbax v0 to v1 Migration Guide#
How to use v1 API to load checkpoints saved with v0 API#
The v1 ocp.load_* functions can load checkpoints saved with the v0 API. Before discussing the details, let us first look at how saved checkpoints are laid out on disk.
Checkpoint Layouts#
Checkpointables in subdirectories#
Most commonly, Orbax saves a checkpoint in a directory whose subdirectories each hold one checkpointable (item).
For example, the checkpoint in step_1234 below contains checkpointables in subdirectories named state, pytree, and my_json_data.
root_dir/
  step_1234/
    _CHECKPOINT_METADATA
    state/
      _METADATA
      manifest.ocdbt
      ocdbt.process_0/
    pytree/
      _METADATA
      manifest.ocdbt
      ocdbt.process_0/
    my_json_data/
      my_data.json
A CheckpointManager pointing to root_dir/ saves checkpoints for each step in the above format.
Similarly, Checkpointer(CompositeCheckpointHandler) can save a checkpoint like step_1234/, though the directory can be arbitrary (not constrained to correspond to a specific step).
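To make the layout above concrete, the following sketch recreates an empty skeleton of step_1234/ in a temporary directory and lists its checkpointable subdirectories using only the standard library. The helper name list_checkpointables is hypothetical and is not part of Orbax; it only assumes that each checkpointable lives in its own subdirectory, as described above.

```python
import tempfile
from pathlib import Path


def list_checkpointables(step_dir: Path) -> list[str]:
    """Return the sorted names of the subdirectories directly under step_dir."""
    # Hypothetical helper: treats every subdirectory as one checkpointable.
    return sorted(p.name for p in step_dir.iterdir() if p.is_dir())


# Recreate an empty skeleton of the step_1234/ layout (no real checkpoint data).
with tempfile.TemporaryDirectory() as tmp:
    step_dir = Path(tmp) / "root_dir" / "step_1234"
    for name in ("state", "pytree", "my_json_data"):
        (step_dir / name).mkdir(parents=True)
    (step_dir / "_CHECKPOINT_METADATA").touch()

    names = list_checkpointables(step_dir)
    print(names)
```

The file _CHECKPOINT_METADATA is skipped because it is a regular file, so only the three subdirectories are reported.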
Let’s save a checkpoint with the v0 API to demonstrate.
# Save checkpoint with checkpointables in state and pytree subdirs.
from etils import epath
import numpy as np
from orbax import checkpoint as ocp_v0

root_dir = epath.Path('/tmp/migration/root_dir')
root_dir.rmtree(missing_ok=True)  # Clean up if it already exists.

data = {
    'params': np.ones(2),
}
args = ocp_v0.args.Composite(**{
    checkpointable_name: ocp_v0.args.StandardSave(data)
    for checkpointable_name in ['state', 'pytree']
})
with ocp_v0.CheckpointManager(root_dir) as mngr:
    step = 0
    mngr.save(step, args=args)

step_dir = root_dir / f'{step}'
!ls /tmp/migration/root_dir/0
_CHECKPOINT_METADATA pytree state
A checkpoint stored in the above layout can be loaded using the ocp.load_checkpointables(...) function.
# Load all checkpointables from a directory where subdirs contain checkpointables.
import orbax.checkpoint.experimental.v1 as ocp
loaded = ocp.load_checkpointables(step_dir)
# Use the checkpointables.
state = loaded['state']
pytree = loaded['pytree']
print('state=', state)
print('pytree=', pytree)
state= {'params': array([1., 1.])}
pytree= {'params': array([1., 1.])}
Checkpoint in directory with no subdirectory#
Alternatively, users can save a checkpoint directly to a directory, without any checkpointable subdirectories.
For example, the following layout contains a pytree checkpoint with no checkpointable name such as state in the layout above.
my_checkpoint/
  _CHECKPOINT_METADATA
  _METADATA
  manifest.ocdbt
  ocdbt.process_0/
A v0 Checkpointer (without CompositeCheckpointHandler) can be used to save in this layout.
# Save a checkpoint directly to a directory.
my_checkpoint_dir = epath.Path('/tmp/migration/custom_checkpoint/my_checkpoint')
my_checkpoint_dir.rmtree(missing_ok=True)  # Clean up if it already exists.

with ocp_v0.StandardCheckpointer() as checkpointer:
    checkpointer.save(my_checkpoint_dir, data)
!ls /tmp/migration/custom_checkpoint/my_checkpoint
_CHECKPOINT_METADATA _METADATA d manifest.ocdbt ocdbt.process_0
A pytree checkpoint in the above layout can be loaded using the ocp.load(...) function.
# Load a pytree from a directory with no checkpointables.
loaded = ocp.load(my_checkpoint_dir, checkpointable_name=None)
# Use the loaded pytree.
print('loaded=', loaded)
WARNING:root:TensorStore data files not found in checkpoint path /tmp/migration/custom_checkpoint/my_checkpoint. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
loaded= {'params': array([1., 1.])}
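Because the two layouts call for different load functions, a small helper can pick the matching call by inspecting the directory first. The sketch below is an illustration under stated assumptions, not part of Orbax: the names detect_layout and load_v0_checkpoint are hypothetical, and the rule "a top-level _METADATA file means a direct pytree checkpoint" is inferred from the two layouts shown in this guide. Only the two Orbax calls already used above appear in it.

```python
import tempfile
from pathlib import Path


def detect_layout(path: Path) -> str:
    """Guess the layout: 'pytree' (direct) or 'checkpointables' (subdirs)."""
    # Assumption: a direct pytree checkpoint has _METADATA at its top level.
    if (path / "_METADATA").exists():
        return "pytree"
    # Otherwise, assume that subdirectories hold the checkpointables.
    if any(child.is_dir() for child in path.iterdir()):
        return "checkpointables"
    raise ValueError(f"Unrecognized checkpoint layout: {path}")


def load_v0_checkpoint(path: Path):
    """Hypothetical helper: choose the v1 load call that matches the layout."""
    # Imported lazily so that detect_layout works without Orbax installed.
    import orbax.checkpoint.experimental.v1 as ocp

    if detect_layout(path) == "pytree":
        return ocp.load(path, checkpointable_name=None)
    return ocp.load_checkpointables(path)


# Check the detection logic against empty skeletons of both layouts.
with tempfile.TemporaryDirectory() as tmp:
    direct = Path(tmp) / "my_checkpoint"
    direct.mkdir()
    (direct / "_METADATA").touch()
    (direct / "ocdbt.process_0").mkdir()

    step = Path(tmp) / "0"
    (step / "state").mkdir(parents=True)
    (step / "pytree").mkdir()
    (step / "_CHECKPOINT_METADATA").touch()

    layouts = (detect_layout(direct), detect_layout(step))
    print(layouts)
```

The _METADATA check runs first because a direct pytree checkpoint also contains subdirectories such as ocdbt.process_0/, which would otherwise be mistaken for checkpointables.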
Compatibility Matrix#
Loading pytree checkpoint with load(...)#
| Restore API | Response |
|---|---|
| ocp.load( | Loads PyTree under subdirectory, |
| ocp.load( | Loads PyTree under subdirectory, |
| ocp.load( | Loads PyTree under subdirectory, |
| ocp.load( | Loads PyTree directly from |
The following calls raise an error.
| Restore API | Response |
|---|---|
| ocp.load( | Error: expecting a subdir named |
| ocp.load( | Error: expecting a subdir named |
| ocp.load( | Error: expecting pytree metadata file |
| ocp.load( | Error: expecting pytree metadata file |
| ocp.load( | Error: expecting a subdir named |
| ocp.load( | Error: expecting a subdir named |
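The two kinds of error in the table above (a missing subdirectory and a missing pytree metadata file) can be anticipated with a pre-flight check before calling ocp.load. The sketch below is only an approximation of that behavior: resolve_pytree_dir is a hypothetical helper, the metadata file name _METADATA is taken from the layouts shown earlier, and FileNotFoundError is an assumed exception type, since this guide does not state which exceptions Orbax raises.

```python
import tempfile
from pathlib import Path
from typing import Optional


def resolve_pytree_dir(path: Path, checkpointable_name: Optional[str]) -> Path:
    """Return the directory expected to hold the pytree, or raise if absent."""
    if checkpointable_name is None:
        # No name: the pytree is expected directly under path.
        target = path
    else:
        target = path / checkpointable_name
        if not target.is_dir():
            raise FileNotFoundError(
                f"expecting a subdir named {checkpointable_name!r} in {path}"
            )
    # Assumption: _METADATA is the pytree metadata file.
    if not (target / "_METADATA").is_file():
        raise FileNotFoundError(f"expecting pytree metadata file in {target}")
    return target


# Skeleton of a step directory with a single 'state' checkpointable.
with tempfile.TemporaryDirectory() as tmp:
    step = Path(tmp) / "0"
    (step / "state").mkdir(parents=True)
    (step / "state" / "_METADATA").touch()

    ok = resolve_pytree_dir(step, "state").name
    errors = []
    for name in (None, "missing"):
        try:
            resolve_pytree_dir(step, name)
        except FileNotFoundError as e:
            # Keep only the message prefix, dropping the temporary path.
            errors.append(str(e).split(" in ")[0])
    print(ok, errors)
```

Running the check before loading gives an early, readable failure when the directory and the requested checkpointable name do not match.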
Loading checkpointables with load_checkpointables(...)#
| Restore API | Response |
|---|---|
| ocp.load_checkpointables( | Loads all checkpointables from respective subdirs |
| ocp.load_checkpointables( | Loads |
The following calls raise an error.
| Restore API | Response |
|---|---|
| ocp.load_checkpointables( | Error: suggesting to try a subdir instead |
| ocp.load_checkpointables( | Error: suggesting to use load instead |
| ocp.load_checkpointables( | Error: suggesting to try a subdir instead |
| ocp.load_checkpointables( | Error: suggesting to use load instead |
Migrating from v0 CheckpointManager to v1 Checkpointer#
If you were using the v0 CheckpointManager in your training loop, switch to the v1 Checkpointer.
Consult the following table for the complete list of compatible methods.
v0 CheckpointManager |
v1 Checkpointer |
|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
and |
|
|
|
and |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Call |
returned from |
|
|
Call |
returned from |
|
|
|
|
|
|
|
|
Unsupported |
|
Coming soon… |