Orbax v0 to v1 Migration Guide

How to use the v1 API to load checkpoints saved with the v0 API

The v1 ocp.load_* APIs can load checkpoints saved with the v0 API. Before discussing the details, let us first look at how saved checkpoints are laid out on disk.

Checkpoint Layouts

Checkpointables in subdirectories

Most commonly, Orbax saves a checkpoint as a directory whose subdirectories each hold one checkpointable (item).

For example, the checkpoint in step_1234 below contains three checkpointables, stored in subdirectories named state, pytree, and my_json_data.

root_dir/
    step_1234/
        _CHECKPOINT_METADATA
        state/
            _METADATA
            manifest.ocdbt
            ocdbt.process_0/
        pytree/
            _METADATA
            manifest.ocdbt
            ocdbt.process_0/
        my_json_data/
            my_data.json

A v0 CheckpointManager pointed at root_dir/ saves one checkpoint per step in the above format.

Similarly, a v0 Checkpointer(CompositeCheckpointHandler) can save a checkpoint like step_1234/, except that the directory can be arbitrary; it need not correspond to a specific step.
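Because step directories are named after their steps (step_1234 in the layout above; the v0 run below produces root_dir/0), step numbers can be recovered from directory names alone. The following is a minimal sketch of that idea, roughly the information that all_steps() reports in v0 and checkpoints reports in v1. list_step_dirs is a hypothetical stdlib helper, not an Orbax API, and it assumes a step directory is named <prefix><integer>; Orbax itself may track more state than directory names.

```python
import re
import tempfile
from pathlib import Path


def list_step_dirs(root_dir, prefix=''):
  """Returns sorted step numbers of subdirectories named <prefix><int>.

  Hypothetical helper for illustration only; not part of Orbax.
  Files and directories whose names do not match are ignored.
  """
  pattern = re.compile(re.escape(prefix) + r'(\d+)')
  steps = []
  for child in Path(root_dir).iterdir():
    match = pattern.fullmatch(child.name)
    if child.is_dir() and match:
      steps.append(int(match.group(1)))
  return sorted(steps)


if __name__ == '__main__':
  with tempfile.TemporaryDirectory() as tmp:
    for name in ['0', '12', 'README']:
      (Path(tmp) / name).mkdir()
    print(list_step_dirs(tmp))  # [0, 12]
```

Passing prefix='step_' would match names like step_1234 instead of bare integers.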

Let’s save a checkpoint with the v0 API to demonstrate.

```python
# Save a checkpoint with checkpointables in the state and pytree subdirs.
from etils import epath
import numpy as np
from orbax import checkpoint as ocp_v0

root_dir = epath.Path('/tmp/migration/root_dir')
root_dir.rmtree(missing_ok=True)  # Clean up if it already exists.
data = {
  'params': np.ones(2),
}

# Save the same PyTree under two checkpointable names: 'state' and 'pytree'.
args = ocp_v0.args.Composite(**{
  checkpointable_name: ocp_v0.args.StandardSave(data)
  for checkpointable_name in ['state', 'pytree']
})
with ocp_v0.CheckpointManager(root_dir) as mngr:
  step = 0
  mngr.save(step, args=args)

# By default, the manager names each step directory after its step number.
step_dir = root_dir / f'{step}'
!ls /tmp/migration/root_dir/0
```

```
_CHECKPOINT_METADATA  pytree  state
```

A checkpoint stored in the above layout can be loaded with the v1 ocp.load_checkpointables(...) function, which returns a dict keyed by checkpointable name.

```python
# Load all checkpointables from a directory whose subdirs contain checkpointables.
import orbax.checkpoint.experimental.v1 as ocp

loaded = ocp.load_checkpointables(step_dir)
# Use the checkpointables.
state = loaded['state']
pytree = loaded['pytree']

print('state=', state)
print('pytree=', pytree)
```

```
state= {'params': array([1., 1.])}
pytree= {'params': array([1., 1.])}
```

Checkpoint in a directory with no subdirectories

Alternatively, users can save a checkpoint directly into a directory, without any checkpointable subdirectories.

For example, the following layout contains a PyTree checkpoint stored without a checkpointable name (such as state above).

my_checkpoint/
    _CHECKPOINT_METADATA
    _METADATA
    manifest.ocdbt
    ocdbt.process_0/

A v0 Checkpointer that does not use CompositeCheckpointHandler, such as StandardCheckpointer, saves checkpoints in this layout.

```python
# Save a PyTree checkpoint directly to a directory (no checkpointable subdirs).
my_checkpoint_dir = epath.Path('/tmp/migration/custom_checkpoint/my_checkpoint')
my_checkpoint_dir.rmtree(missing_ok=True)

with ocp_v0.StandardCheckpointer() as checkpointer:
  checkpointer.save(my_checkpoint_dir, data)
!ls /tmp/migration/custom_checkpoint/my_checkpoint
```

```
_CHECKPOINT_METADATA  _METADATA  d  manifest.ocdbt  ocdbt.process_0
```

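A PyTree checkpoint in this layout can be loaded with ocp.load(...) by passing checkpointable_name=None, which reads the PyTree directly from the given directory instead of from a named subdirectory.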

```python
# Load a PyTree from a directory with no checkpointables.
loaded = ocp.load(my_checkpoint_dir, checkpointable_name=None)
# Use the loaded PyTree.
print('loaded=', loaded)
```

```
WARNING:root:TensorStore data files not found in checkpoint path /tmp/migration/custom_checkpoint/my_checkpoint. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
loaded= {'params': array([1., 1.])}
```
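The two layouts can be told apart by the marker files in the listings above: both checkpoints have _CHECKPOINT_METADATA at the top level, but only a PyTree saved directly also has _METADATA there, while a composite checkpoint keeps one subdirectory per checkpointable. The hypothetical describe_layout helper below (plain Python, not an Orbax API) sketches that distinction, assuming the marker files are exactly those shown in the listings; Orbax's own checks may differ.

```python
from pathlib import Path


def describe_layout(checkpoint_dir):
  """Classifies a directory using the marker files shown in the layouts.

  Hypothetical helper for illustration only; not Orbax code.
  Returns a (kind, checkpointable_names) tuple.
  """
  path = Path(checkpoint_dir)
  if not (path / '_CHECKPOINT_METADATA').is_file():
    # e.g. root_dir/: the checkpoints live one level down, in step dirs.
    return 'not_a_checkpoint', []
  if (path / '_METADATA').is_file():
    # A PyTree saved directly into the directory, like my_checkpoint/.
    return 'direct_pytree', []
  # Otherwise each subdirectory holds one checkpointable, like step_1234/.
  names = sorted(child.name for child in path.iterdir() if child.is_dir())
  return 'checkpointables', names
```

For step_1234/ above this would report ('checkpointables', ['my_json_data', 'pytree', 'state']), and for my_checkpoint/ it would report ('direct_pytree', []).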

Compatibility Matrix

Loading a PyTree checkpoint with load(...)

| Restore API | Response |
| --- | --- |
| `ocp.load(step_1234)` | Loads the PyTree from subdirectory `pytree` |
| `ocp.load(step_1234, checkpointable_name='pytree')` | Loads the PyTree from subdirectory `pytree` |
| `ocp.load(step_1234, checkpointable_name='state')` | Loads the PyTree from subdirectory `state` |
| `ocp.load(my_checkpoint, checkpointable_name=None)` | Loads the PyTree directly from `my_checkpoint` |

The following calls raise an error:

| Restore API | Response |
| --- | --- |
| `ocp.load(root_dir)` | Error: expecting a subdirectory named `pytree` |
| `ocp.load(root_dir, checkpointable_name='pytree')` | Error: expecting a subdirectory named `pytree` |
| `ocp.load(root_dir, checkpointable_name=None)` | Error: expecting a PyTree metadata file |
| `ocp.load(step_1234, checkpointable_name=None)` | Error: expecting a PyTree metadata file |
| `ocp.load(my_checkpoint)` | Error: expecting a subdirectory named `pytree` |
| `ocp.load(my_checkpoint, checkpointable_name='pytree')` | Error: expecting a subdirectory named `pytree` |
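The matrix above can be read as a small decision procedure. The sketch below restates it in plain Python: resolve_pytree_dir is a hypothetical helper (not Orbax code) that returns the directory load(...) would read the PyTree from, or raises where the table lists an error. It assumes, as the table implies, that omitting checkpointable_name behaves like checkpointable_name='pytree', and it treats a top-level _METADATA file as the marker of a PyTree saved directly. The real error types and messages differ.

```python
from pathlib import Path


def resolve_pytree_dir(directory, checkpointable_name='pytree'):
  """Returns the directory a load(...) call would read the PyTree from.

  Hypothetical model of the compatibility matrix, not Orbax code.
  Raises FileNotFoundError where the table lists an error.
  """
  path = Path(directory)
  if checkpointable_name is None:
    # The PyTree must sit directly in `directory`, marked by _METADATA.
    if not (path / '_METADATA').is_file():
      raise FileNotFoundError(f'Expecting a PyTree metadata file in {path}')
    return path
  # Otherwise the PyTree must sit in a subdirectory with the given name.
  subdir = path / checkpointable_name
  if not subdir.is_dir():
    raise FileNotFoundError(
        f'Expecting a subdirectory named {checkpointable_name!r} in {path}')
  return subdir
```

Under this model, every row of the error table fails for one of two reasons: the named subdirectory is missing, or checkpointable_name=None points at a directory without a top-level _METADATA file.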

Loading checkpointables with load_checkpointables(...)

| Restore API | Response |
| --- | --- |
| `ocp.load_checkpointables(step_1234)` | Loads all checkpointables from their respective subdirectories |
| `ocp.load_checkpointables(step_1234, dict(state=abstract_tree, my_json_data=None))` | Loads the `state` and `my_json_data` checkpointables from their respective subdirectories |

The following calls raise an error:

| Restore API | Response |
| --- | --- |
| `ocp.load_checkpointables(root_dir)` | Error suggesting to try a subdirectory instead |
| `ocp.load_checkpointables(my_checkpoint)` | Error suggesting to use `load` instead |
| `ocp.load_checkpointables(root_dir, dict(state=abstract_tree, pytree=abstract_tree))` | Error suggesting to try a subdirectory instead |
| `ocp.load_checkpointables(my_checkpoint, dict(state=abstract_tree, pytree=abstract_tree))` | Error suggesting to use `load` instead |
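Likewise, the load_checkpointables(...) rows can be restated as a hypothetical helper, select_checkpointables (not Orbax code), which returns the names of the subdirectories that would be read rather than the loaded values. It assumes the marker files shown in the layouts above: _CHECKPOINT_METADATA at the top of every checkpoint, and _METADATA at the top of a PyTree saved directly. Raising KeyError for a requested name that has no subdirectory is this sketch's own choice, since the table does not cover that case.

```python
from pathlib import Path


def select_checkpointables(directory, abstract_checkpointables=None):
  """Returns the checkpointable subdirectory names a load would read.

  Hypothetical model of the load_checkpointables(...) rows, not Orbax
  code; it returns names rather than loaded values.
  """
  path = Path(directory)
  if (path / '_METADATA').is_file():
    # A PyTree saved directly, like my_checkpoint/.
    raise ValueError(f'{path} holds a PyTree saved directly; try load(...) instead.')
  if not (path / '_CHECKPOINT_METADATA').is_file():
    # Not a checkpoint itself, like root_dir/; its step subdirectories are.
    raise ValueError(
        f'{path} is not a checkpoint; try one of its subdirectories instead.')
  available = sorted(child.name for child in path.iterdir() if child.is_dir())
  if abstract_checkpointables is None:
    return available  # No selection given: read every checkpointable.
  missing = sorted(set(abstract_checkpointables) - set(available))
  if missing:
    # This sketch's own choice; the table does not cover this case.
    raise KeyError(f'No subdirectories for checkpointables: {missing}')
  return sorted(abstract_checkpointables)
```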

Migrating from v0 CheckpointManager to v1 Checkpointer

If you were using the v0 CheckpointManager in your training loop, switch to the v1 Checkpointer.

The following table lists each v0 CheckpointManager method alongside its v1 Checkpointer counterpart.

| v0 CheckpointManager | v1 Checkpointer |
| --- | --- |
| `directory` | `directory` |
| `all_steps(...)` | `checkpoints` |
| `latest_step()` | `latest` |
| `reload()` | `reload()` |
| `should_save(step)` | `should_save(step)` |
| `save(...)` | `save(...)`, `save_checkpointables(...)` and `save_*_async(...)` |
| `restore(...)` | `load(...)`, `load_checkpointables(...)` and `load_*_async(...)` |
| `item_metadata(step)` | `metadata(step)`, `checkpointables_metadata(step)` |
| `metrics(step)` | `metadata(step).metrics`, `checkpointables_metadata(step).metrics` |
| `metadata(step)` | `metadata(step)`, `checkpointables_metadata(step)` |
| `metadata(None)` or `metadata()` | `root_metadata()` |
| `wait_until_finished()` | Call `AsyncResponse.result()` on the response returned from `save_*_async(...)` and `load_*_async(...)` |
| `check_for_errors()` | Call `AsyncResponse.result()` on the response returned from `save_*_async(...)` and `load_*_async(...)` |
| `close()` | `close()` |
| `is_saving_in_progress()` | `is_saving_in_progress()` |
| `best_step()` | `checkpoints.metrics.best_step` |
| `reached_preemption(...)` | Unsupported |
| `delete(step)` | Coming soon… |
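In this mapping, the manager-wide wait_until_finished() and check_for_errors() give way to calling result() on the AsyncResponse that each async call returns, so completion and failures are tracked per operation; the table suggests errors surface through result(). The pattern can be illustrated with the standard library's concurrent.futures.Future, used here purely as an analogy for AsyncResponse. fake_save is a hypothetical stand-in, and no Orbax API is called.

```python
import concurrent.futures
import time


def fake_save(step):
  """Hypothetical stand-in for a slow background save; calls no Orbax API."""
  time.sleep(0.01)  # Simulate I/O.
  if step < 0:
    raise ValueError(f'invalid step {step}')
  return step


with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
  # Each async call returns its own handle (the role AsyncResponse plays).
  response = pool.submit(fake_save, 100)
  # ... the training loop could keep running here ...
  saved_step = response.result()  # Blocks until this save completes.

  failing = pool.submit(fake_save, -1)
  try:
    failing.result()  # Re-raises the exception raised inside fake_save.
  except ValueError as e:
    error_message = str(e)

print('saved step:', saved_step)  # saved step: 100
print('error:', error_message)  # error: invalid step -1
```

Holding one handle per operation means the caller decides exactly which save or load to wait on, instead of blocking on everything the manager has in flight.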