Using the Refactored CheckpointManager API#
As of orbax-checkpoint-0.5.0, several new APIs have been introduced at multiple levels. The most significant change is to how users interact with CheckpointManager. This page shows a side-by-side comparison of the old and new APIs.
The legacy API is deprecated and will stop working soon. Please migrate to the new style as soon as possible.
CheckpointManager.save(...) is now async by default. Call wait_until_finished if your code depends on a previous save having completed. Alternatively, async behavior can be disabled via the CheckpointManagerOptions.enable_async_checkpointing option.
For further information on how to use the new API, see the introductory tutorial and the API Overview.
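The async behavior described above can be pictured with a small toy model. The sketch below uses only the Python standard library; `ToyAsyncSaver` is a hypothetical stand-in invented for illustration, not Orbax code, and it only assumes that a save returns before the write finishes and that `wait_until_finished` blocks until it has.

```python
import threading
import time


class ToyAsyncSaver:
  """Hypothetical stand-in for an async checkpoint writer (not Orbax code)."""

  def __init__(self):
    self._thread = None
    self.saved_steps = []

  def save(self, step):
    # Block on any in-flight save so that writes never overlap.
    self.wait_until_finished()
    self._thread = threading.Thread(target=self._write, args=(step,))
    self._thread.start()  # Returns immediately; the write runs in background.

  def _write(self, step):
    time.sleep(0.05)  # Simulate slow I/O.
    self.saved_steps.append(step)

  def wait_until_finished(self):
    # Blocks until the most recent save, if any, has completed.
    if self._thread is not None:
      self._thread.join()
      self._thread = None


saver = ToyAsyncSaver()
saver.save(0)
saver.save(1)  # Waits for step 0 before starting step 1.
# Without this call, step 1 might still be in flight when we read the result.
saver.wait_until_finished()
print(saver.saved_steps)
```

The point of the sketch is only the ordering: anything that reads a checkpoint right after `save` should first wait for the save to finish.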
import orbax.checkpoint as ocp
from etils import epath
# Dummy PyTrees for simplicity.
# In reality, this would be a tree of np.ndarray or jax.Array.
pytree = {'a': 0}
# In reality, this would be a tree of jax.ShapeDtypeStruct (metadata
# for restoration).
abstract_pytree = {'a': 0}
extra_metadata = {'version': 1.0}
Single-Item Checkpointing#
Before#
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
  ocp.test_utils.erase_and_create_empty('/tmp/ckpt1/'),
  ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
  options=options,
)
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)
mngr.save(0, pytree)
mngr.wait_until_finished()
mngr.restore(
  0,
  items=abstract_pytree,
  restore_kwargs={'restore_args': restore_args}
)
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate.
{'a': 0}
After#
options = ocp.CheckpointManagerOptions()
with ocp.CheckpointManager(
  ocp.test_utils.erase_and_create_empty('/tmp/ckpt2/'),
  options=options,
) as mngr:
  mngr.save(0, args=ocp.args.StandardSave(pytree))
  # `save` is async by default, so wait for it to complete before restoring.
  mngr.wait_until_finished()
  # The `CheckpointManager` already knows that the object is saved and restored
  # using "standard" pytree logic. In many cases, you can restore exactly as
  # saved without specifying additional arguments.
  mngr.restore(0)
  # If customization of properties like sharding or dtype is desired, just provide
  # the abstract target PyTree, the properties of which will be used to set
  # the properties of the restored arrays.
  mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
Important notes:
Don’t forget to use the keyword args=... for save and restore! Otherwise you will get the legacy API. This will not be necessary forever, but only until the legacy API is removed.
The value of args is an instance of a CheckpointArgs subclass, present in the ocp.args module. These classes are used to communicate the logic that you wish to use to save and restore your object. For a typical PyTree consisting of arrays, use StandardSave/StandardRestore.
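To illustrate how the type of an args object can communicate which save logic to use, here is a minimal toy sketch. `ToyStandardSave`, `ToyJsonSave`, and `toy_save` are hypothetical names made up for this example; they model the idea of type-based dispatch only and do not reflect how Orbax implements it.

```python
import dataclasses
import json
from typing import Any


@dataclasses.dataclass
class ToyStandardSave:
  """Hypothetical args type: save `item` with 'standard' tree logic."""
  item: Any


@dataclasses.dataclass
class ToyJsonSave:
  """Hypothetical args type: save `item` as JSON."""
  item: Any


def _save_standard(item):
  return 'standard:' + repr(sorted(item.items()))


def _save_json(item):
  return 'json:' + json.dumps(item, sort_keys=True)


# Maps each args type to the logic it stands for.
_SAVE_LOGIC = {ToyStandardSave: _save_standard, ToyJsonSave: _save_json}


def toy_save(args):
  """Selects the save logic from the type of `args`."""
  try:
    save_fn = _SAVE_LOGIC[type(args)]
  except KeyError:
    raise TypeError(f'No save logic for {type(args).__name__}') from None
  return save_fn(args.item)


state_blob = toy_save(ToyStandardSave({'a': 0}))
meta_blob = toy_save(ToyJsonSave({'version': 1.0}))
print(state_blob)
print(meta_blob)
```

The caller never names a handler directly; choosing the args class is what selects the behavior.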
Let’s explore scenarios in which restore() and item_metadata() calls raise errors because no CheckpointHandler has been specified for an item name.
# Unmapped CheckpointHandlers on a new CheckpointManager instance.
new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
try:
  new_mngr.restore(0)  # Raises error due to unmapped CheckpointHandler
except Exception as e:
  print(e)
WARNING:absl:Item "default" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
WARNING:absl:Item "default" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
To fix this, use one of the following options:
new_mngr.restore(0, args=ocp.args.StandardRestore(abstract_pytree))
new_mngr.close()
We can also configure the CheckpointManager to know how to restore the object in advance.
# The item name is "default".
list(epath.Path('/tmp/ckpt2/0').iterdir())
[PosixGPath('/tmp/ckpt2/0/default'),
PosixGPath('/tmp/ckpt2/0/_CHECKPOINT_METADATA')]
registry = ocp.handlers.DefaultCheckpointHandlerRegistry()
registry.add('default', ocp.args.StandardRestore, ocp.StandardCheckpointHandler)
# handler_registry can be used as an alternative to restore(..., args=...).
with ocp.CheckpointManager(
  '/tmp/ckpt2/',
  options=options,
  handler_registry=registry,
) as new_mngr:
  print(new_mngr.restore(0))
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
{'a': 0}
NOTE: CheckpointManager.item_metadata(step) does not accept an argument like the args in restore(..., args=...). So, handler_registry is currently required when calling item_metadata(step) before calling restore or save.
# handler_registry becomes even more critical with item_metadata() calls.
new_mngr = ocp.CheckpointManager('/tmp/ckpt2/', options=options)
try:
  new_mngr.item_metadata(0)  # Raises error due to unmapped CheckpointHandler
except Exception as e:
  print(e)
WARNING:absl:Item "default" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
WARNING:absl:Item "default" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
with ocp.CheckpointManager(
  '/tmp/ckpt2/',
  options=options,
  handler_registry=registry,
) as new_mngr:
  new_mngr.item_metadata(0)
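The reason item_metadata() depends on a registry while restore() does not can be sketched with a toy model. `ToyManager` below is hypothetical and written only for illustration; it assumes, as the note above states, that a handler becomes known either up front through a registry or after a restore call that passes args.

```python
class ToyManager:
  """Hypothetical model of handler resolution (not the Orbax implementation)."""

  def __init__(self, handler_registry=None):
    # Maps item name -> handler name, known before any save/restore call.
    self._handlers = dict(handler_registry or {})

  def restore(self, item_name, args=None):
    if args is not None:
      # The args object tells the manager which logic to use, and the
      # manager remembers that choice for later calls.
      self._handlers[item_name] = args['handler']
    return self._resolve(item_name)

  def item_metadata(self, item_name):
    # There is no `args` parameter here, so the handler must already be known.
    return self._resolve(item_name)

  def _resolve(self, item_name):
    if item_name not in self._handlers:
      raise ValueError(f'No handler known for item "{item_name}"')
    return self._handlers[item_name]


fresh = ToyManager()
try:
  fresh.item_metadata('default')
  error = None
except ValueError as e:
  error = str(e)

# Option 1: configure the handler up front.
configured = ToyManager({'default': 'standard'})
meta_handler = configured.item_metadata('default')

# Option 2: a prior restore with args makes the handler known.
fresh.restore('default', args={'handler': 'standard'})
after_restore = fresh.item_metadata('default')
print(error, meta_handler, after_restore)
```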
Multiple-Item Checkpointing#
Before#
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
  ocp.test_utils.erase_and_create_empty('/tmp/ckpt3/'),
  {
    'state': ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
    'extra_metadata': ocp.Checkpointer(ocp.JsonCheckpointHandler())
  },
  options=options,
)
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_pytree)
mngr.save(0, {'state': pytree, 'extra_metadata': extra_metadata})
mngr.wait_until_finished()
mngr.restore(
  0,
  items={'state': abstract_pytree, 'extra_metadata': None},
  restore_kwargs={
    'state': {'restore_args': restore_args},
    'extra_metadata': None
  },
)
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate.
Composite({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})
After#
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
  ocp.test_utils.erase_and_create_empty('/tmp/ckpt4/'),
  # Optionally, `item_names=('state', 'extra_metadata')` can be passed here to
  # define an up-front contract about what items the CheckpointManager will
  # be dealing with.
  options=options,
)
mngr.save(0, args=ocp.args.Composite(
  state=ocp.args.StandardSave(pytree),
  extra_metadata=ocp.args.JsonSave(extra_metadata))
)
mngr.wait_until_finished()
# Restore as saved
mngr.restore(0)
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
Composite({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})
# Restore with customization. Restore a subset of items.
mngr.restore(0, args=ocp.args.Composite(
  state=ocp.args.StandardRestore(abstract_pytree)))
Composite({'state': {'a': 0}})
mngr.close()
Just as in the single-item use case described above, let’s explore scenarios in which restore() and item_metadata() calls raise errors because no CheckpointHandler has been specified for an item name.
# Unmapped CheckpointHandlers on a new CheckpointManager instance.
new_mngr = ocp.CheckpointManager(
  '/tmp/ckpt4/',
  options=options,
  item_names=('state', 'extra_metadata'),
)
try:
  new_mngr.restore(0)  # Raises error due to unmapped CheckpointHandlers
except Exception as e:
  print(e)
Item with name: "state" had an undetermined `CheckpointHandler` when restoring. Please ensure the handler was specified during initialization, or use the appropriate `CheckpointArgs` subclass to indicate the item type.
new_mngr.restore(
  0,
  args=ocp.args.Composite(
    state=ocp.args.StandardRestore(abstract_pytree),
    extra_metadata=ocp.args.JsonRestore(),
  ),
)
new_mngr.close()
registry = ocp.handlers.DefaultCheckpointHandlerRegistry()
registry.add('state', ocp.args.StandardRestore, ocp.StandardCheckpointHandler)
registry.add('extra_metadata', ocp.args.JsonRestore, ocp.JsonCheckpointHandler)
# handler_registry can be used as an alternative to restore(..., args=...).
with ocp.CheckpointManager(
  '/tmp/ckpt4/',
  options=options,
  handler_registry=registry,
) as new_mngr:
  print(new_mngr.restore(0))
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
Composite({'extra_metadata': {'version': 1.0}, 'state': {'a': 0}})
NOTE: CheckpointManager.item_metadata(step) does not accept an argument like the args in restore(..., args=...). So, handler_registry is currently required with item_metadata(step) calls.
# handler_registry becomes even more critical with item_metadata() calls.
with ocp.CheckpointManager(
  '/tmp/ckpt4/',
  options=options,
  item_names=('state', 'extra_metadata'),
) as new_mngr:
  try:
    new_mngr.item_metadata(0)  # Raises error due to unmapped CheckpointHandlers
  except Exception as e:
    print(e)
WARNING:absl:Item "state" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
WARNING:absl:Item "extra_metadata" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
WARNING:absl:Item "state" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
WARNING:absl:Item "extra_metadata" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
with ocp.CheckpointManager(
  '/tmp/ckpt4/',
  options=options,
  handler_registry=registry,
) as new_mngr:
  print(new_mngr.item_metadata(0))
Composite({'state': TreeMetadata(
custom_metadata=None
tree={'a': ScalarMetadata(name='a', directory=PosixGPath('/tmp/ckpt4/0/state'), shape=(), sharding=None, dtype=dtype('int64'), storage=None)}
), 'extra_metadata': None})