CheckpointArgs
Exported symbols under orbax.checkpoint.args.
CheckpointArgs
- class orbax.checkpoint.args.CheckpointArgs
Base class for all checkpoint argument dataclasses.
Subclass this dataclass to define the arguments for your custom CheckpointHandler. Users of that CheckpointHandler then pass instances of these CheckpointArgs subclasses to specify how an item is saved or restored.
Example subclass:
```
import dataclasses
from typing import Any

import orbax.checkpoint as ocp


# MyCheckpointHandler is your custom CheckpointHandler subclass, defined elsewhere.
@ocp.args.register_with_handler(MyCheckpointHandler, for_save=True)
@dataclasses.dataclass
class MyCheckpointSave(ocp.args.CheckpointArgs):
    item: Any
    options: Any


@ocp.args.register_with_handler(MyCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class MyCheckpointRestore(ocp.args.CheckpointArgs):
    options: Any
```
Example usage:
```
# Save with the class registered for_save=True ...
ckptr.save(
    path, custom_state=MyCheckpointSave(item=..., options=...)
)
# ... and restore with the class registered for_restore=True.
ckptr.restore(
    path, custom_state=MyCheckpointRestore(options=...)
)
```
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__()
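The member listing above is standard dataclass behavior: with the default eq=True and without frozen=True, the generated __eq__ compares field values and __hash__ is set to None, so instances are unhashable (this also holds for dataclass subclasses such as MyCheckpointSave). A minimal standard-library sketch of that behavior; CheckpointArgsStandIn and MySave are hypothetical stand-ins, not Orbax classes.

```python
import dataclasses
from typing import Any


# Hypothetical stand-in for orbax.checkpoint.args.CheckpointArgs:
# a plain (non-frozen) dataclass with no fields.
@dataclasses.dataclass
class CheckpointArgsStandIn:
    pass


# Hypothetical subclass shaped like MyCheckpointSave in the example above.
@dataclasses.dataclass
class MySave(CheckpointArgsStandIn):
    item: Any
    options: Any = None


a = MySave(item={"w": 1}, options="fast")
b = MySave(item={"w": 1}, options="fast")

# The generated __eq__ compares field values of same-class instances.
print(a == b)  # True

# eq=True (the default) without frozen=True sets __hash__ to None,
# on the base class and on each dataclass subclass.
print(CheckpointArgsStandIn.__hash__ is None, MySave.__hash__ is None)  # True True

try:
    hash(a)
except TypeError as err:
    print("unhashable:", err)
```

In practice this means CheckpointArgs instances cannot be used as dict keys or set members; lookups keyed on the args type (rather than the instance) avoid the issue.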
Array
Standard PyTree
Generic PyTree
Composite
JSON
Proto
JaxRandomKey
NumpyRandomKey
Utilities
- orbax.checkpoint.args.get_registered_handler_cls(arg)
Returns the CheckpointHandler class registered for the CheckpointArgs given as arg.
- Return type:
Type[CheckpointHandler]
- orbax.checkpoint.args.register_with_handler(handler_cls, for_save=False, for_restore=False)
Registers a CheckpointArgs subclass with a specific handler.
This registration is necessary so that, when the user uses this CheckpointArgs class with CompositeCheckpointHandler, the correct handler for it can be found automatically.
Note: for_save and for_restore may both be True, but cannot both be False.
- Parameters:
handler_cls (Type[CheckpointHandler]) – CheckpointHandler to be associated with this CheckpointArgs.
for_save (bool) – Indicates whether the CheckpointArgs is registered as a save argument.
for_restore (bool) – Indicates whether the CheckpointArgs is registered as a restore argument.
- Returns:
Decorator.
- orbax.checkpoint.args.get_registered_args_cls(handler)
Returns the registered CheckpointArgs corresponding to the handler.
- Parameters:
handler (Union[Type[CheckpointHandler], CheckpointHandler]) – CheckpointHandler instance or class.
- Return type:
Tuple[Type[CheckpointArgs], Type[CheckpointArgs]]
- Returns:
Tuple of (save, restore) CheckpointArgs classes.
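get_registered_args_cls performs the reverse lookup: from a handler (class or instance) to its (save, restore) pair of args classes. A hypothetical standard-library sketch of that lookup over a list of registration records; all class and function names and the record layout are invented for illustration, and returning None for a role with no registered class is an assumption the documentation above does not specify. The BothArgs record shows the case where one class is registered with both for_save and for_restore set.

```python
import dataclasses
from typing import List, Optional, Tuple, Type, Union


class HandlerStandIn:
    """Hypothetical stand-in for a CheckpointHandler base class."""


@dataclasses.dataclass
class ArgsStandIn:
    """Hypothetical stand-in for the CheckpointArgs base class."""


class MyHandler(HandlerStandIn):
    pass


class OtherHandler(HandlerStandIn):
    pass


@dataclasses.dataclass
class MySave(ArgsStandIn):
    item: object = None


@dataclasses.dataclass
class MyRestore(ArgsStandIn):
    pass


@dataclasses.dataclass
class BothArgs(ArgsStandIn):
    pass


# Records a registration decorator might have collected:
# (args class, handler class, for_save, for_restore).
_REGISTRATIONS: List[Tuple[type, type, bool, bool]] = [
    (MySave, MyHandler, True, False),
    (MyRestore, MyHandler, False, True),
    (BothArgs, OtherHandler, True, True),  # one class for both roles
]


def args_for(
    handler: Union[Type[HandlerStandIn], HandlerStandIn],
) -> Tuple[Optional[type], Optional[type]]:
    """Hypothetical analogue of get_registered_args_cls: (save, restore) classes."""
    # Accept either a handler class or a handler instance.
    handler_cls = handler if isinstance(handler, type) else type(handler)
    save_cls: Optional[type] = None
    restore_cls: Optional[type] = None
    for args_cls, registered, for_save, for_restore in _REGISTRATIONS:
        if registered is handler_cls:
            if for_save:
                save_cls = args_cls
            if for_restore:
                restore_cls = args_cls
    return save_cls, restore_cls


save_cls, restore_cls = args_for(MyHandler())
print(save_cls.__name__, restore_cls.__name__)  # MySave MyRestore
print([c.__name__ for c in args_for(OtherHandler)])  # ['BothArgs', 'BothArgs']
```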