Using the Refactored CheckpointManager API
As of `orbax-checkpoint-0.5.0`, several new APIs have been introduced at multiple different levels. The most significant change is to how users interact with `CheckpointManager`. This page shows a side-by-side comparison of the old and new APIs.
The legacy APIs are deprecated and will stop working after May 1st, 2024. Please ensure you are using the new style by then. (Note: the deprecation warning emitted by the library, shown in the outputs below, states a migration deadline of August 1st, 2024.)
`CheckpointManager.save(...)` is now async by default. Make sure you call `wait_until_finished` if you depend on a previous save having completed. Alternatively, asynchronous saving can be disabled via the `CheckpointManagerOptions.enable_async_checkpointing` option.
For further information on how to use the new API, see the introductory tutorial and the API Overview.
import orbax.checkpoint as ocp
# Dummy PyTrees for simplicity.
# In reality, this would be a tree of np.ndarray or jax.Array.
pytree = {'a': 0}
# In reality, this would be a tree of jax.ShapeDtypeStruct (metadata
# for restoration).
abstract_pytree = {'a': 0}
# Small JSON-serializable dict; the multiple-item examples save it alongside
# `pytree` via JsonCheckpointHandler / ocp.args.JsonSave.
extra_metadata = {'version': 1.0}
Single-Item Checkpointing
Before
# Legacy (deprecated) API: a Checkpointer wrapping the handler is passed to the
# CheckpointManager, and restore options go through `items` and
# `restore_kwargs`. Configuring the manager this way logs the deprecation
# warning shown in the output.
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt1/'),
    ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
    options=options,
)
# Per-leaf restore arguments derived from the abstract target tree.
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)
mngr.save(0, pytree)
# Block until the save has completed before restoring it.
mngr.wait_until_finished()
mngr.restore(
    0,
    items=abstract_pytree,
    restore_kwargs={'restore_args': restore_args}
)
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.
{'a': 0}
After
options = ocp.CheckpointManagerOptions()
# New API: no Checkpointer is passed in. The item type is communicated through
# the `args` given to save/restore. The `with` block takes the place of the
# explicit `close()` calls used elsewhere on this page.
with ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt2/'),
    options=options,
) as mngr:
    mngr.save(0, args=ocp.args.StandardSave(pytree))
    # save() is async by default; wait before relying on the saved step.
    mngr.wait_until_finished()

    # After providing `args` during an initial `save` or `restore` call, the
    # `CheckpointManager` instance records the type so that you do not need
    # to specify it again. If the `CheckpointManager` instance has not been
    # given an `ocp.args.CheckpointArgs` instance for a particular item on a
    # previous occasion, that item cannot be restored without specifying the
    # argument at restore time.

    # In many cases, you can restore exactly as saved without specifying
    # additional arguments.
    mngr.restore(0)
    # If customization of properties like sharding or dtype is desired, just
    # provide the abstract target PyTree, the properties of which will be used
    # to set the properties of the restored arrays.
    mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
Important notes:
Don’t forget to use the keyword argument `args=...` for save and restore! Otherwise you will get the legacy API. This will not be necessary forever, but only until the legacy API is removed.
The value of `args` is an instance of a subclass of `CheckpointArgs`, found in the `ocp.args` module. These classes are used to communicate the logic that you wish to use to save and restore your object. For a typical PyTree consisting of arrays, use `StandardSave`/`StandardRestore`.
Let’s explore scenarios in which `restore()` and `item_metadata()` calls raise errors because no `CheckpointHandler` has been specified for an item name. `CheckpointManager(..., item_handlers=...)` can be used to resolve these scenarios.
# Unmapped CheckpointHandlers on a new CheckpointManager instance.
new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
new_mngr.restore(0) # Raises error due to unmapped CheckpointHandler
# `new_mngr` is a fresh instance: it was given neither `item_handlers` nor any
# prior `args`, so it cannot determine a CheckpointHandler for the "default"
# item and raises ValueError (see the traceback).
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_712/2652070573.py in <module>
1 # Unmapped CheckpointHandlers on a new CheckpointManager instance.
2 new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
----> 3 new_mngr.restore(0) # Raises error due to unmapped CheckpointHandler
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py in restore(self, step, items, restore_kwargs, directory, args)
1087
1088 restore_directory = self._get_read_step_directory(step, directory)
-> 1089 restored = self._checkpointer.restore(restore_directory, args=args)
1090 if self._single_item:
1091 return restored[DEFAULT_ITEM_NAME]
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/async_checkpointer.py in restore(self, directory, *args, **kwargs)
340 """See superclass documentation."""
341 self.wait_until_finished()
--> 342 return super().restore(directory, *args, **kwargs)
343
344 def check_for_errors(self):
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpointer.py in restore(self, directory, *args, **kwargs)
167 logging.info('Restoring item from %s.', directory)
168 ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 169 restored = self._handler.restore(directory, args=ckpt_args)
170 logging.info('Finished restoring checkpoint from %s.', directory)
171 utils.sync_global_processes('Checkpointer:restore', self._active_processes)
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/composite_checkpoint_handler.py in restore(self, directory, args)
437 continue
438 if handler is None:
--> 439 raise ValueError(
440 f'Item with name: "{item_name}" had an undetermined'
441 ' `CheckpointHandler` when restoring. Please ensure the handler'
ValueError: Item with name: "default" had an undetermined `CheckpointHandler` when restoring. Please ensure the handler was specified during initialization, or use the appropriate `CheckpointArgs` subclass to indicate the item type.
# Passing a CheckpointArgs subclass indicates the item type, which is what the
# error above asks for, so this restore succeeds.
new_mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
{'a': 0}
# Release the manager now that this example is done with it.
new_mngr.close()
# item_handlers can be used as an alternative to restore(..., args=...).
# Because the handler is mapped at construction time, restore(0) needs no
# `args`. Without a target tree it still logs the "UNSAFE" warning shown below.
with ocp.CheckpointManager(
    '/tmp/ckpt2/',
    options=options,
    item_handlers=ocp.StandardCheckpointHandler()
) as new_mngr:
    print(new_mngr.restore(0))
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
{'a': 0}
NOTE: `CheckpointManager.item_metadata(step)` doesn’t accept an `args` input the way `restore(..., args=...)` does. So `item_handlers` is the only option available for `item_metadata(step)` calls.
# item_handlers becomes even more critical with item_metadata() calls.
new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandler
# item_metadata() takes no `args`, so without `item_handlers` this instance has
# no handler for the "default" item and raises ValueError.
# NOTE(review): because this call raises, `new_mngr` is never closed before it
# is rebound below. The multiple-item example wraps the same call in a `with`
# block, which avoids leaving the manager open.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_712/484512995.py in <module>
1 # item_handlers becomes even more critical with item_metadata() calls.
2 new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
----> 3 new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandler
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py in item_metadata(self, step)
1105 items_missing_handlers.append(item_name)
1106 if items_missing_handlers:
-> 1107 raise ValueError(
1108 'No mapped CheckpointHandler found for items:'
1109 f' {items_missing_handlers}. Please see documentation of'
ValueError: No mapped CheckpointHandler found for items: ['default']. Please see documentation of `item_handlers` in CheckpointManager.
# Mapping the handler up front via `item_handlers` gives item_metadata() a
# handler to read the "default" item's metadata with.
new_mngr = ocp.CheckpointManager(
    '/tmp/ckpt2/',
    options=options,
    item_handlers=ocp.StandardCheckpointHandler(),
)
new_mngr.item_metadata(0)
{'a': ScalarMetadata(name='a', directory=PosixGPath('/tmp/ckpt2/0/default'), shape=(), sharding=None, dtype=dtype('int64'))}
# Release the manager now that this example is done with it.
new_mngr.close()
Multiple-Item Checkpointing
Before
# Legacy (deprecated) API: one Checkpointer per item, keyed by item name, with
# per-item restore options passed through `items` / `restore_kwargs`.
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt3/'),
    {
        'state': ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
        'extra_metadata': ocp.Checkpointer(ocp.JsonCheckpointHandler())
    },
    options=options,
)
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)
# The dict keys select which Checkpointer saves each value.
mngr.save(0, {'state': pytree, 'extra_metadata': extra_metadata})
# Block until the save has completed before restoring it.
mngr.wait_until_finished()
mngr.restore(
    0,
    items={'state': abstract_pytree, 'extra_metadata': None},
    restore_kwargs={
        'state': {'restore_args': restore_args},
        'extra_metadata': None
    },
)
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.
CompositeArgs({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})
After
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty('/tmp/ckpt4/'),
    # `item_names` defines an up-front contract about what items the
    # CheckpointManager will be dealing with.
    item_names=('state', 'extra_metadata'),
    options=options,
)
# Each item gets its own CheckpointArgs, bundled with ocp.args.Composite; the
# keyword names here correspond to the entries of `item_names` above.
mngr.save(0, args=ocp.args.Composite(
    state=ocp.args.StandardSave(pytree),
    extra_metadata=ocp.args.JsonSave(extra_metadata))
)
mngr.wait_until_finished()
# Restore as saved: this instance recorded each item's type from the save.
mngr.restore(0)
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
CompositeArgs({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})
# Restore with customization. Restore a subset of items.
# Only `state` is requested, so the result contains only that item.
mngr.restore(0, args=ocp.args.Composite(
    state=ocp.args.StandardRestore(abstract_pytree)))
CompositeArgs({'state': {'a': 0}})
# This manager was not created in a `with` block, so close it explicitly.
mngr.close()
Just like in the single-item use case described above, let’s explore scenarios in which `restore()` and `item_metadata()` calls raise errors because no `CheckpointHandler` has been specified for an item name. `CheckpointManager(..., item_handlers=...)` can be used to resolve these scenarios.
# Unmapped CheckpointHandlers on a new CheckpointManager instance.
new_mngr = ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_names=('state', 'extra_metadata'),
)
new_mngr.restore(0) # Raises error due to unmapped CheckpointHandlers
# `item_names` only declares which items exist. This fresh instance has no
# handler (and no prior `args`) for them, so restore raises ValueError.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_712/2106957155.py in <module>
5 item_names=('state', 'extra_metadata'),
6 )
----> 7 new_mngr.restore(0) # Raises error due to unmapped CheckpointHandlers
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py in restore(self, step, items, restore_kwargs, directory, args)
1087
1088 restore_directory = self._get_read_step_directory(step, directory)
-> 1089 restored = self._checkpointer.restore(restore_directory, args=args)
1090 if self._single_item:
1091 return restored[DEFAULT_ITEM_NAME]
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/async_checkpointer.py in restore(self, directory, *args, **kwargs)
340 """See superclass documentation."""
341 self.wait_until_finished()
--> 342 return super().restore(directory, *args, **kwargs)
343
344 def check_for_errors(self):
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpointer.py in restore(self, directory, *args, **kwargs)
167 logging.info('Restoring item from %s.', directory)
168 ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 169 restored = self._handler.restore(directory, args=ckpt_args)
170 logging.info('Finished restoring checkpoint from %s.', directory)
171 utils.sync_global_processes('Checkpointer:restore', self._active_processes)
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/composite_checkpoint_handler.py in restore(self, directory, args)
437 continue
438 if handler is None:
--> 439 raise ValueError(
440 f'Item with name: "{item_name}" had an undetermined'
441 ' `CheckpointHandler` when restoring. Please ensure the handler'
ValueError: Item with name: "state" had an undetermined `CheckpointHandler` when restoring. Please ensure the handler was specified during initialization, or use the appropriate `CheckpointArgs` subclass to indicate the item type.
# Supplying per-item CheckpointArgs indicates each item's type, so this
# restore succeeds where the call above failed.
new_mngr.restore(
    0,
    args=ocp.args.Composite(
        state=ocp.args.StandardRestore(abstract_pytree),
        extra_metadata=ocp.args.JsonRestore(),
    ),
)
CompositeArgs({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})
# Release the manager now that this example is done with it.
new_mngr.close()
# item_handlers can be used as an alternative to restore(..., args=...).
# For multiple items, pass a dict mapping each item name to its handler.
with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_handlers={
        'state': ocp.StandardCheckpointHandler(),
        'extra_metadata': ocp.JsonCheckpointHandler(),
    },
) as new_mngr:
    print(new_mngr.restore(0))
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
CompositeArgs({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})
NOTE: `CheckpointManager.item_metadata(step)` doesn’t accept an `args` input the way `restore(..., args=...)` does. So `item_handlers` is the only option available for `item_metadata(step)` calls.
# item_handlers becomes even more critical with item_metadata() calls.
with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_names=('state', 'extra_metadata'),
) as new_mngr:
    new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandlers
    # `item_names` alone maps no handlers, and item_metadata() accepts no
    # `args`, so this raises ValueError. The `with` block still closes the
    # manager.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_712/2842306685.py in <module>
5 item_names=('state', 'extra_metadata'),
6 ) as new_mngr:
----> 7 new_mngr.item_metadata(0) # Raises error due to unmapped CheckpointHandlers
~/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.9/site-packages/orbax/checkpoint/checkpoint_manager.py in item_metadata(self, step)
1105 items_missing_handlers.append(item_name)
1106 if items_missing_handlers:
-> 1107 raise ValueError(
1108 'No mapped CheckpointHandler found for items:'
1109 f' {items_missing_handlers}. Please see documentation of'
ValueError: No mapped CheckpointHandler found for items: ['state', 'extra_metadata']. Please see documentation of `item_handlers` in CheckpointManager.
# With a handler mapped for every item, item_metadata() succeeds. As the output
# shows, the JSON item reports None for its metadata.
with ocp.CheckpointManager(
    '/tmp/ckpt4/',
    options=options,
    item_handlers={
        'state': ocp.StandardCheckpointHandler(),
        'extra_metadata': ocp.JsonCheckpointHandler(),
    },
) as new_mngr:
    print(new_mngr.item_metadata(0))
CompositeArgs({'state': {'a': ScalarMetadata(name='a', directory=PosixGPath('/tmp/ckpt4/0/state'), shape=(), sharding=None, dtype=dtype('int64'))}, 'extra_metadata': None})