Using the Refactored CheckpointManager API

Using the Refactored CheckpointManager API#

As of orbax-checkpoint-0.5.0, several new APIs have been introduced at multiple different levels. The most significant change is to how users interact with CheckpointManager. This page shows a side-by-side comparison of the old and new APIs.

The legacy API is deprecated and will stop working after May 1st, 2024. Please ensure you are using the new style by then.

CheckpointManager.save(...) is now async by default. Make sure you call wait_until_finished if you depend on a previous save having completed. Alternatively, the async behavior can be disabled via the CheckpointManagerOptions.enable_async_checkpointing option.

For further information on how to use the new API, see the introductory tutorial and the API Overview.

import orbax.checkpoint as ocp
# Dummy PyTrees for simplicity.

# In reality, this would be a tree of np.ndarray or jax.Array.
pytree = {'a': 0}
# In reality, this would be a tree of jax.ShapeDtypeStruct (metadata
# for restoration).
abstract_pytree = {'a': 0}

extra_metadata = {'version': 1.0}

Single-Item Checkpointing#

Before#

options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
  ocp.test_utils.erase_and_create_empty('/tmp/ckpt1/'),
  ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
  options=options,
)

restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)
mngr.save(0, pytree)
mngr.wait_until_finished()

mngr.restore(
    0,
    items=abstract_pytree,
    restore_kwargs={'restore_args': restore_args}
)
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.
{'a': 0}

After#

options = ocp.CheckpointManagerOptions()
with ocp.CheckpointManager(
  ocp.test_utils.erase_and_create_empty('/tmp/ckpt2/'),
  options=options,
) as mngr:

  mngr.save(0, args=ocp.args.StandardSave(pytree))
  mngr.wait_until_finished()

  # After providing `args` during an initial `save` or `restore` call, the
  # `CheckpointManager` instance records the type so that you do not need to
  # specify it again. If the `CheckpointManager` instance has not previously
  # been given an `ocp.args.CheckpointArgs` instance for a particular item,
  # that item cannot be restored without specifying the argument at restore
  # time.

  # In many cases, you can restore exactly as saved without specifying additional
  # arguments.
  mngr.restore(0)
  # If customization of properties like sharding or dtype is desired, just provide
  # the abstract target PyTree, the properties of which will be used to set
  # the properties of the restored arrays.
  mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.

Important notes:

  • Don’t forget to use the keyword args=... for save and restore! Otherwise you will get the legacy API. This will not be necessary forever, but only until the legacy API is removed.

  • The value of args is a subclass of CheckpointArgs, present in the ocp.args module. These classes are used to communicate the logic that you wish to use to save and restore your object. For a typical PyTree consisting of arrays, use StandardSave/StandardRestore.

Let’s explore scenarios in which restore() and item_metadata() calls raise errors because no CheckpointHandler has been specified for an item name.

CheckpointManager(..., item_handlers=...) can be used to resolve these scenarios.

# Unmapped CheckpointHandlers on a new CheckpointManager instance.
new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
new_mngr.restore(0)  # Raises error due to unmapped CheckpointHandler
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_712/2652070573.py in <module>
      1 # Unmapped CheckpointHandlers on a new CheckpointManager instance.
      2 new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
----> 3 new_mngr.restore(0)  # Raises error due to unmapped CheckpointHandler

~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py in restore(self, step, items, restore_kwargs, directory, args)
   1087 
   1088     restore_directory = self._get_read_step_directory(step, directory)
-> 1089     restored = self._checkpointer.restore(restore_directory, args=args)
   1090     if self._single_item:
   1091       return restored[DEFAULT_ITEM_NAME]

~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/async_checkpointer.py in restore(self, directory, *args, **kwargs)
    340     """See superclass documentation."""
    341     self.wait_until_finished()
--> 342     return super().restore(directory, *args, **kwargs)
    343 
    344   def check_for_errors(self):

~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpointer.py in restore(self, directory, *args, **kwargs)
    167     logging.info('Restoring item from %s.', directory)
    168     ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 169     restored = self._handler.restore(directory, args=ckpt_args)
    170     logging.info('Finished restoring checkpoint from %s.', directory)
    171     utils.sync_global_processes('Checkpointer:restore', self._active_processes)

~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/composite_checkpoint_handler.py in restore(self, directory, args)
    437           continue
    438         if handler is None:
--> 439           raise ValueError(
    440               f'Item with name: "{item_name}" had an undetermined'
    441               ' `CheckpointHandler` when restoring. Please ensure the handler'

ValueError: Item with name: "default" had an undetermined `CheckpointHandler` when restoring. Please ensure the handler was specified during initialization, or use the appropriate `CheckpointArgs` subclass to indicate the item type.
new_mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
{'a': 0}
new_mngr.close()
# item_handlers can be used as an alternative to restore(..., args=...).
with ocp.CheckpointManager(
    '/tmp/ckpt2/',
    options=options,
    item_handlers=ocp.StandardCheckpointHandler()
) as new_mngr:
  print(new_mngr.restore(0))
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
{'a': 0}

NOTE: Unlike restore(..., args=...), CheckpointManager.item_metadata(step) does not accept an args-like input.

So, item_handlers is the only option available with item_metadata(step) calls.

# item_handlers becomes even more critical with item_metadata() calls.
new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
new_mngr.item_metadata(0)  # Raises error due to unmapped CheckpointHandler
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_712/484512995.py in <module>
      1 # item_handlers becomes even more critical with item_metadata() calls.
      2 new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
----> 3 new_mngr.item_metadata(0)  # Raises error due to unmapped CheckpointHandler

~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py in item_metadata(self, step)
   1105           items_missing_handlers.append(item_name)
   1106       if items_missing_handlers:
-> 1107         raise ValueError(
   1108             'No mapped CheckpointHandler found for items:'
   1109             f' {items_missing_handlers}. Please see documentation of'

ValueError: No mapped CheckpointHandler found for items: ['default']. Please see documentation of `item_handlers` in CheckpointManager.
new_mngr = ocp.CheckpointManager(
    '/tmp/ckpt2/',
    options=options,
    item_handlers=ocp.StandardCheckpointHandler(),
)
new_mngr.item_metadata(0)
{'a': ScalarMetadata(name='a', directory=PosixGPath('/tmp/ckpt2/0/default'), shape=(), sharding=None, dtype=dtype('int64'))}
new_mngr.close()

Multiple-Item Checkpointing#

Before#

options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
  ocp.test_utils.erase_and_create_empty('/tmp/ckpt3/'),
  {
      'state': ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
      'extra_metadata': ocp.Checkpointer(ocp.JsonCheckpointHandler())
  },
  options=options,
)

restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)
mngr.save(0, {'state': pytree, 'extra_metadata': extra_metadata})
mngr.wait_until_finished()

mngr.restore(
    0,
    items={'state': abstract_pytree, 'extra_metadata': None},
    restore_kwargs={
        'state': {'restore_args': restore_args},
        'extra_metadata': None
    },
)
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.
CompositeArgs({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})

After#

options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
  ocp.test_utils.erase_and_create_empty('/tmp/ckpt4/'),
  # `item_names` defines an up-front contract about what items the
  # CheckpointManager will be dealing with.
  item_names=('state', 'extra_metadata'),
  options=options,
)

mngr.save(0, args=ocp.args.Composite(
    state=ocp.args.StandardSave(pytree),
    extra_metadata=ocp.args.JsonSave(extra_metadata))
)
mngr.wait_until_finished()
# Restore as saved
mngr.restore(0)
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
CompositeArgs({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})
# Restore with customization. Restore a subset of items.
mngr.restore(0, args=ocp.args.Composite(
    state=ocp.args.StandardRestore(abstract_pytree)))
CompositeArgs({'state': {'a': 0}})
mngr.close()

Just like the single-item use case described above, let’s explore scenarios in which restore() and item_metadata() calls raise errors because no CheckpointHandler has been specified for an item name.

CheckpointManager(..., item_handlers=...) can be used to resolve these scenarios.

# Unmapped CheckpointHandlers on a new CheckpointManager instance.
new_mngr = ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_names=('state', 'extra_metadata'),
)
new_mngr.restore(0)  # Raises error due to unmapped CheckpointHandlers
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_712/2106957155.py in <module>
      5     item_names=('state', 'extra_metadata'),
      6 )
----> 7 new_mngr.restore(0)  # Raises error due to unmapped CheckpointHandlers

~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py in restore(self, step, items, restore_kwargs, directory, args)
   1087 
   1088     restore_directory = self._get_read_step_directory(step, directory)
-> 1089     restored = self._checkpointer.restore(restore_directory, args=args)
   1090     if self._single_item:
   1091       return restored[DEFAULT_ITEM_NAME]

~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/async_checkpointer.py in restore(self, directory, *args, **kwargs)
    340     """See superclass documentation."""
    341     self.wait_until_finished()
--> 342     return super().restore(directory, *args, **kwargs)
    343 
    344   def check_for_errors(self):

~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpointer.py in restore(self, directory, *args, **kwargs)
    167     logging.info('Restoring item from %s.', directory)
    168     ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 169     restored = self._handler.restore(directory, args=ckpt_args)
    170     logging.info('Finished restoring checkpoint from %s.', directory)
    171     utils.sync_global_processes('Checkpointer:restore', self._active_processes)

~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/composite_checkpoint_handler.py in restore(self, directory, args)
    437           continue
    438         if handler is None:
--> 439           raise ValueError(
    440               f'Item with name: "{item_name}" had an undetermined'
    441               ' `CheckpointHandler` when restoring. Please ensure the handler'

ValueError: Item with name: "state" had an undetermined `CheckpointHandler` when restoring. Please ensure the handler was specified during initialization, or use the appropriate `CheckpointArgs` subclass to indicate the item type.
new_mngr.restore(
    0,
    args=ocp.args.Composite(
        state=ocp.args.StandardRestore(abstract_pytree),
        extra_metadata=ocp.args.JsonRestore(),
    ),
)
CompositeArgs({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})
new_mngr.close()
# item_handlers can be used as an alternative to restore(..., args=...).
with  ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_handlers={
        'state': ocp.StandardCheckpointHandler(),
        'extra_metadata': ocp.JsonCheckpointHandler(),
    },
) as new_mngr:
  print(new_mngr.restore(0))
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
CompositeArgs({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})

NOTE: Unlike restore(..., args=...), CheckpointManager.item_metadata(step) does not accept an args-like input.

So, item_handlers is the only option available with item_metadata(step) calls.

# item_handlers becomes even more critical with item_metadata() calls.
with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_names=('state', 'extra_metadata'),
) as new_mngr:
  new_mngr.item_metadata(0)  # Raises error due to unmapped CheckpointHandlers
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_712/2842306685.py in <module>
      5     item_names=('state', 'extra_metadata'),
      6 ) as new_mngr:
----> 7   new_mngr.item_metadata(0)  # Raises error due to unmapped CheckpointHandlers

~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py in item_metadata(self, step)
   1105           items_missing_handlers.append(item_name)
   1106       if items_missing_handlers:
-> 1107         raise ValueError(
   1108             'No mapped CheckpointHandler found for items:'
   1109             f' {items_missing_handlers}. Please see documentation of'

ValueError: No mapped CheckpointHandler found for items: ['state', 'extra_metadata']. Please see documentation of `item_handlers` in CheckpointManager.
with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_handlers={
        'state': ocp.StandardCheckpointHandler(),
        'extra_metadata': ocp.JsonCheckpointHandler(),
    },
) as new_mngr:
  print(new_mngr.item_metadata(0))
CompositeArgs({'state': {'a': ScalarMetadata(name='a', directory=PosixGPath('/tmp/ckpt4/0/state'), shape=(), sharding=None, dtype=dtype('int64'))}, 'extra_metadata': None})