CheckpointArgs
Defines exported CheckpointArgs classes.
CheckpointHandler subclasses define the logic used to save an object to a checkpoint and to restore it from one. Each CheckpointHandler has corresponding SaveArgs and RestoreArgs classes that define the arguments used to call the handler.
The ocp.args module provides the complete definition of these classes. Refer to ocp.handlers for more information on the handlers themselves.
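A minimal sketch of the save/restore pairing in plain Python, without Orbax. `DictHandler`, `DictSaveArgs`, and `DictRestoreArgs` are hypothetical names, and the two methods shown do not reproduce the real CheckpointHandler interface. The point is only that the save arguments carry the item to write, while the restore arguments describe how to read it back.

```python
import dataclasses
import json
import pathlib
import tempfile
from typing import Any, Dict, List, Optional


@dataclasses.dataclass
class DictSaveArgs:
  """Hypothetical save arguments: the JSON-serializable item to write."""
  item: Dict[str, Any]


@dataclasses.dataclass
class DictRestoreArgs:
  """Hypothetical restore arguments: which keys to read back (None means all)."""
  keys: Optional[List[str]] = None


class DictHandler:
  """Hypothetical handler that stores one JSON file per checkpoint directory."""

  def save(self, directory: pathlib.Path, args: DictSaveArgs) -> None:
    directory.mkdir(parents=True, exist_ok=True)
    (directory / "item.json").write_text(json.dumps(args.item))

  def restore(self, directory: pathlib.Path, args: DictRestoreArgs) -> Dict[str, Any]:
    item = json.loads((directory / "item.json").read_text())
    if args.keys is None:
      return item
    # Only the requested keys are returned.
    return {key: item[key] for key in args.keys}


with tempfile.TemporaryDirectory() as tmp:
  ckpt_dir = pathlib.Path(tmp) / "step_0"
  handler = DictHandler()
  handler.save(ckpt_dir, DictSaveArgs(item={"step": 0, "lr": 0.1}))
  partial = handler.restore(ckpt_dir, DictRestoreArgs(keys=["step"]))
  full = handler.restore(ckpt_dir, DictRestoreArgs())
```

Because each handler accepts only its own args types, the args class alone says which handler should process it.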
CheckpointArgs
- class orbax.checkpoint.args.CheckpointArgs
Base class for all checkpoint argument dataclasses.
All CheckpointHandler implementations should have corresponding CheckpointArgs dataclasses, typically one for save and one for restore. Use one of the subclasses of CheckpointArgs for your use case to specify how an object should be saved or restored.
Typical usage:
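import orbax.checkpoint as ocp

# path and train_state are placeholders for your checkpoint directory and state.
with ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler()) as ckptr:
  ckptr.save(
      path,
      args=ocp.args.StandardSave(train_state)
  )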
Example subclass:
import dataclasses
from typing import Any

import orbax.checkpoint as ocp

# MyCheckpointHandler is your own CheckpointHandler subclass.
@ocp.args.register_with_handler(MyCheckpointHandler, for_save=True)
@dataclasses.dataclass
class MyCheckpointSave(ocp.args.CheckpointArgs):
  item: Any
  options: Any


@ocp.args.register_with_handler(MyCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class MyCheckpointRestore(ocp.args.CheckpointArgs):
  options: Any
Example usage:
ckptr.save(
    path,
    custom_state=MyCheckpointSave(item=..., options=...)
)
ckptr.restore(
    path,
    custom_state=MyCheckpointRestore(options=...)
)
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__()
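The `__eq__`, `__hash__ = None`, and `__init__` members listed above match what Python's `dataclasses.dataclass` generates for a class with the default `eq=True, frozen=False`. The sketch below demonstrates that standard-library behavior on a hypothetical stand-in base class rather than on Orbax itself: two args objects with equal fields compare equal, but neither can be hashed, so they cannot serve as dict keys or set members.

```python
import dataclasses
from typing import Any


class CheckpointArgsStandIn:
  """Hypothetical stand-in for ocp.args.CheckpointArgs (no Orbax import)."""


@dataclasses.dataclass
class MySaveArgs(CheckpointArgsStandIn):
  item: Any
  options: Any = None


a = MySaveArgs(item=[1, 2])
b = MySaveArgs(item=[1, 2])

# eq=True (the default) generates a field-by-field __eq__.
same_value = a == b  # True: fields are equal
same_object = a is b  # False: two distinct instances

# With eq=True and frozen=False, dataclasses sets __hash__ to None.
is_unhashable = MySaveArgs.__hash__ is None
try:
  hash(a)
  hashable = True
except TypeError:
  hashable = False
```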
Composite
- orbax.checkpoint.args.Composite
alias of CompositeArgs
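Composite groups several named CheckpointArgs objects for use with CompositeCheckpointHandler; per the register_with_handler description under Utilities, the handler for each item is found automatically from its args class. The sketch below models that dispatch in plain Python. Everything in it (`CompositeSketch`, `HANDLER_FOR_ARGS`, and the two handler and args classes) is hypothetical, and returning strings instead of writing files is a simplification.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class JsonSaveSketch:
  """Hypothetical save args for a dict item."""
  item: Dict[str, Any]


@dataclasses.dataclass
class TextSaveSketch:
  """Hypothetical save args for a plain string."""
  text: str


class JsonHandlerSketch:
  def save(self, args: JsonSaveSketch) -> str:
    return f"json:{sorted(args.item)}"


class TextHandlerSketch:
  def save(self, args: TextSaveSketch) -> str:
    return f"text:{args.text}"


# Hypothetical registry: args class -> handler class.
HANDLER_FOR_ARGS: Dict[type, type] = {
    JsonSaveSketch: JsonHandlerSketch,
    TextSaveSketch: TextHandlerSketch,
}


class CompositeSketch:
  """Hypothetical container of named items, each with its own args object."""

  def __init__(self, **items: Any):
    self.items = items

  def save_all(self) -> Dict[str, str]:
    results = {}
    for name, args in self.items.items():
      # The handler is chosen from the type of each args object.
      handler = HANDLER_FOR_ARGS[type(args)]()
      results[name] = handler.save(args)
    return results


composite = CompositeSketch(
    metadata=JsonSaveSketch(item={"step": 3, "lr": 0.1}),
    notes=TextSaveSketch(text="warmup done"),
)
outputs = composite.save_all()
```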
Standard PyTree
- orbax.checkpoint.args.StandardSave
alias of StandardSaveArgs
- orbax.checkpoint.args.StandardRestore
alias of StandardRestoreArgs
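Sphinx renders "alias of" when a module attribute is bound to a class defined under a different name, typically through a plain assignment. The short sketch below uses hypothetical class names rather than Orbax's to show what that implies: both names refer to the same class object, so instances, isinstance checks, and any class-keyed registration behave identically whichever name you use.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class StandardSaveArgsSketch:
  """Hypothetical args class defined under its full name."""
  item: Any


# A module-level assignment creates the alias; no new class is created.
StandardSaveSketch = StandardSaveArgsSketch

args = StandardSaveSketch(item={"w": 1})
```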
Generic PyTree
- orbax.checkpoint.args.PyTreeSave
alias of PyTreeSaveArgs
- orbax.checkpoint.args.PyTreeRestore
alias of PyTreeRestoreArgs
Array
- orbax.checkpoint.args.ArraySave
alias of ArraySaveArgs
- orbax.checkpoint.args.ArrayRestore
alias of ArrayRestoreArgs
JSON
- orbax.checkpoint.args.JsonSave
alias of JsonSaveArgs
- orbax.checkpoint.args.JsonRestore
alias of JsonRestoreArgs
Proto
- orbax.checkpoint.args.ProtoSave
alias of ProtoSaveArgs
- orbax.checkpoint.args.ProtoRestore
alias of ProtoRestoreArgs
JaxRandomKey
- orbax.checkpoint.args.JaxRandomKeySave
alias of JaxRandomKeySaveArgs
- orbax.checkpoint.args.JaxRandomKeyRestore
alias of JaxRandomKeyRestoreArgs
NumpyRandomKey
- orbax.checkpoint.args.NumpyRandomKeySave
alias of NumpyRandomKeySaveArgs
- orbax.checkpoint.args.NumpyRandomKeyRestore
alias of NumpyRandomKeyRestoreArgs
Utilities
- orbax.checkpoint.args.get_registered_handler_cls(arg)
Returns the CheckpointHandler class registered for arg.
- Return type:
Type[CheckpointHandler]
- orbax.checkpoint.args.register_with_handler(handler_cls, for_save=False, for_restore=False)
Registers a CheckpointArgs subclass with a specific handler.
This registration is necessary so that when the user passes this CheckpointArgs class to CompositeCheckpointHandler, the correct handler for this class can be found automatically.
Note: for_save and for_restore may both be True, but cannot both be False.
- Parameters:
handler_cls (Type[CheckpointHandler]) – CheckpointHandler to be associated with this CheckpointArgs class.
for_save (bool) – indicates whether the CheckpointArgs class is registered as a save argument.
for_restore (bool) – indicates whether the CheckpointArgs class is registered as a restore argument.
- Returns:
Decorator.
- orbax.checkpoint.args.get_registered_args_cls(handler)
Returns the registered CheckpointArgs classes corresponding to the handler.
- Parameters:
handler (Union[Type[CheckpointHandler], CheckpointHandler]) – CheckpointHandler instance or class.
- Return type:
Tuple[Type[CheckpointArgs], Type[CheckpointArgs]]
- Returns:
Tuple of (save, restore) CheckpointArgs classes.
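Together, these three utilities describe a two-way registry between args classes and handler classes. A minimal sketch of such a registry in plain Python follows. All names ending in `_sketch`, plus `HandlerBase`, `ArgsBase`, `MyHandler`, `MySave`, and `MyRestore`, are hypothetical, and Orbax's actual implementation may store and validate registrations differently. The sketch does enforce the documented rule that for_save and for_restore cannot both be False.

```python
import dataclasses
from typing import Callable, Dict, List, Optional, Tuple


class HandlerBase:
  """Hypothetical stand-in for CheckpointHandler."""


class ArgsBase:
  """Hypothetical stand-in for CheckpointArgs."""


# args class -> handler class
_ARGS_TO_HANDLER: Dict[type, type] = {}
# handler class -> [save args class, restore args class]
_HANDLER_TO_ARGS: Dict[type, List[Optional[type]]] = {}


def register_with_handler_sketch(
    handler_cls: type, for_save: bool = False, for_restore: bool = False
) -> Callable[[type], type]:
  """Hypothetical decorator mirroring the documented contract."""
  if not (for_save or for_restore):
    raise ValueError("for_save and for_restore cannot both be False.")

  def decorator(args_cls: type) -> type:
    _ARGS_TO_HANDLER[args_cls] = handler_cls
    slots = _HANDLER_TO_ARGS.setdefault(handler_cls, [None, None])
    if for_save:
      slots[0] = args_cls
    if for_restore:
      slots[1] = args_cls
    return args_cls

  return decorator


def get_registered_handler_cls_sketch(arg) -> type:
  """Returns the handler class registered for an args instance or class."""
  args_cls = arg if isinstance(arg, type) else type(arg)
  return _ARGS_TO_HANDLER[args_cls]


def get_registered_args_cls_sketch(
    handler,
) -> Tuple[Optional[type], Optional[type]]:
  """Returns (save, restore) args classes for a handler instance or class."""
  handler_cls = handler if isinstance(handler, type) else type(handler)
  save_cls, restore_cls = _HANDLER_TO_ARGS[handler_cls]
  return save_cls, restore_cls


# Usage: one handler, one save args class, one restore args class.
class MyHandler(HandlerBase):
  pass


@register_with_handler_sketch(MyHandler, for_save=True)
@dataclasses.dataclass
class MySave(ArgsBase):
  item: object


@register_with_handler_sketch(MyHandler, for_restore=True)
@dataclasses.dataclass
class MyRestore(ArgsBase):
  options: object = None
```

Keying the lookup by class rather than by instance is what lets a composite handler choose the right handler from nothing more than the type of each args object it receives.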