CheckpointArgs

Exported symbols under orbax.checkpoint.args.

CheckpointArgs

class orbax.checkpoint.args.CheckpointArgs

Base class for all checkpoint argument dataclasses.

Subclass this dataclass to define the arguments for your custom CheckpointHandler. Users of that CheckpointHandler then pass instances of these subclasses to specify what to save or restore and how.

Example subclass:

```
import dataclasses
from typing import Any

import orbax.checkpoint as ocp


# MyCheckpointHandler is the custom CheckpointHandler these args belong to.
@ocp.args.register_with_handler(MyCheckpointHandler, for_save=True)
@dataclasses.dataclass
class MyCheckpointSave(ocp.args.CheckpointArgs):
    item: Any
    options: Any


@ocp.args.register_with_handler(MyCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class MyCheckpointRestore(ocp.args.CheckpointArgs):
    options: Any
```

Example usage:

```
# Save with the save-side args class.
ckptr.save(
    path,
    custom_state=MyCheckpointSave(item=..., options=...),
)

# Restore with the restore-side args class.
ckptr.restore(
    path,
    custom_state=MyCheckpointRestore(options=...),
)
```

__eq__(other)

Return self==value.

__hash__ = None
__init__()
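The `__eq__`, `__hash__ = None`, and `__init__` entries are what `@dataclasses.dataclass` generates with its default settings: `eq=True` without `frozen=True` adds a field-wise `__eq__` and sets `__hash__` to `None`, so instances compare by value but cannot be hashed. The sketch below demonstrates this with a plain standard-library dataclass, assuming a subclass declared like the example above. `MyCheckpointSave` here is a hypothetical stand-in that does not subclass the real `CheckpointArgs`, so it runs without Orbax installed.

```python
import dataclasses
from typing import Any


# Hypothetical stand-in for a CheckpointArgs subclass (no Orbax import).
@dataclasses.dataclass
class MyCheckpointSave:
    item: Any
    options: Any = None


a = MyCheckpointSave(item={"w": 1}, options="fast")
b = MyCheckpointSave(item={"w": 1}, options="fast")

# The generated __eq__ compares field values, so these distinct objects are equal.
print(a == b)  # True

# eq=True without frozen=True sets __hash__ to None, so hashing raises TypeError.
print(MyCheckpointSave.__hash__ is None)  # True
try:
    hash(a)
except TypeError as err:
    print("unhashable:", err)
```

In practice this means such args objects can be compared, but cannot be used as dictionary keys or set members.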

Array

orbax.checkpoint.args.ArraySave

alias of ArraySaveArgs

orbax.checkpoint.args.ArrayRestore

alias of ArrayRestoreArgs

Standard PyTree

orbax.checkpoint.args.StandardSave

alias of StandardSaveArgs

orbax.checkpoint.args.StandardRestore

alias of StandardRestoreArgs

Generic PyTree

orbax.checkpoint.args.PyTreeSave

alias of PyTreeSaveArgs

orbax.checkpoint.args.PyTreeRestore

alias of PyTreeRestoreArgs

Composite

orbax.checkpoint.args.Composite

alias of CompositeArgs

JSON

orbax.checkpoint.args.JsonSave

alias of JsonSaveArgs

orbax.checkpoint.args.JsonRestore

alias of JsonRestoreArgs

Proto

orbax.checkpoint.args.ProtoSave

alias of ProtoSaveArgs

orbax.checkpoint.args.ProtoRestore

alias of ProtoRestoreArgs

JaxRandomKey

orbax.checkpoint.args.JaxRandomKeySave

alias of JaxRandomKeySaveArgs

orbax.checkpoint.args.JaxRandomKeyRestore

alias of JaxRandomKeyRestoreArgs

NumpyRandomKey

orbax.checkpoint.args.NumpyRandomKeySave

alias of NumpyRandomKeySaveArgs

orbax.checkpoint.args.NumpyRandomKeyRestore

alias of NumpyRandomKeyRestoreArgs

Utilities

orbax.checkpoint.args.get_registered_handler_cls(arg)

Returns the CheckpointHandler class registered for the given CheckpointArgs.

Return type:

Type[CheckpointHandler]
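Conceptually, this lookup maps a registered args type to its handler class. The following is a minimal, self-contained model of that idea, not Orbax's implementation: `_HANDLER_BY_ARGS`, `PyTreeLikeHandler`, and `PyTreeLikeSave` are hypothetical names, and both accepting either an args instance or an args class and raising `ValueError` for unregistered types are assumptions of this sketch.

```python
import dataclasses
from typing import Any, Dict


class PyTreeLikeHandler:
    """Hypothetical handler class used only in this sketch."""


@dataclasses.dataclass
class PyTreeLikeSave:
    """Hypothetical args dataclass used only in this sketch."""

    item: Any


# Hypothetical registry for this sketch: args class -> handler class.
_HANDLER_BY_ARGS: Dict[type, type] = {PyTreeLikeSave: PyTreeLikeHandler}


def get_registered_handler_cls(arg: Any) -> type:
    """Returns the handler class registered for an args instance or class."""
    # Assumption of this sketch: accept either the args class or an instance.
    args_cls = arg if isinstance(arg, type) else type(arg)
    if args_cls not in _HANDLER_BY_ARGS:
        raise ValueError(f"No handler registered for {args_cls.__name__}.")
    return _HANDLER_BY_ARGS[args_cls]


print(get_registered_handler_cls(PyTreeLikeSave(item={"w": 1})).__name__)
# PyTreeLikeHandler
```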

orbax.checkpoint.args.register_with_handler(handler_cls, for_save=False, for_restore=False)

Registers a CheckpointArgs subclass with a specific handler.

This registration is necessary so that when the user passes this CheckpointArgs class to CompositeCheckpointHandler, the correct handler for it can be found automatically.

Note that for_save and for_restore may both be True, but cannot both be False.

Parameters:
  • handler_cls (Type[CheckpointHandler]) – CheckpointHandler to be associated with this CheckpointArgs subclass.

  • for_save (bool) – indicates whether the CheckpointArgs subclass is registered as a save argument.

  • for_restore (bool) – indicates whether the CheckpointArgs subclass is registered as a restore argument.

Returns:

A class decorator that performs the registration.

orbax.checkpoint.args.get_registered_args_cls(handler)

Returns the registered CheckpointArgs corresponding to the handler.

Parameters:

handler (Union[Type[CheckpointHandler], CheckpointHandler]) – CheckpointHandler instance or class.

Return type:

Tuple[Type[CheckpointArgs], Type[CheckpointArgs]]

Returns:

Tuple of (save, restore) CheckpointArgs classes.

orbax.checkpoint.args.has_registered_args(handler)

Return type:

bool
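The two reverse lookups above can be modeled with a registry keyed by handler class. This is a hypothetical sketch, not Orbax's implementation: the registry, the `_as_class` helper, and the class names are invented for illustration. Accepting a handler instance or class follows the documented `handler` parameter, while raising `ValueError` for an unknown handler and interpreting `has_registered_args` as "any args class is registered" are assumptions of the sketch.

```python
import dataclasses
from typing import Any, Dict, Optional, Tuple


class MyCheckpointHandler:
    """Hypothetical handler used only in this sketch."""


class UnregisteredHandler:
    """Hypothetical handler with nothing registered."""


@dataclasses.dataclass
class MySave:
    item: Any


@dataclasses.dataclass
class MyRestore:
    options: Any = None


# Hypothetical registry: handler class -> (save args class, restore args class).
_ARGS_BY_HANDLER: Dict[type, Tuple[Optional[type], Optional[type]]] = {
    MyCheckpointHandler: (MySave, MyRestore),
}


def _as_class(handler: Any) -> type:
    # Accept either a handler class or a handler instance.
    return handler if isinstance(handler, type) else type(handler)


def get_registered_args_cls(handler: Any) -> Tuple[Optional[type], Optional[type]]:
    """Returns the (save, restore) args classes registered for the handler."""
    handler_cls = _as_class(handler)
    if handler_cls not in _ARGS_BY_HANDLER:
        raise ValueError(f"No args registered for {handler_cls.__name__}.")
    return _ARGS_BY_HANDLER[handler_cls]


def has_registered_args(handler: Any) -> bool:
    """Returns True if any args class is registered for the handler."""
    return _as_class(handler) in _ARGS_BY_HANDLER


save_cls, restore_cls = get_registered_args_cls(MyCheckpointHandler())
print(save_cls.__name__, restore_cls.__name__)  # MySave MyRestore
print(has_registered_args(UnregisteredHandler))  # False
```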