CheckpointArgs

Defines exported CheckpointArgs classes.

CheckpointHandler subclasses define logic used to save and restore an object to and from a checkpoint. Each CheckpointHandler has corresponding SaveArgs and RestoreArgs classes that define the arguments used to call the handler.

The ocp.args module provides a complete definition of these classes. Refer to ocp.handlers for more information on the handlers themselves.

CheckpointArgs

class orbax.checkpoint.args.CheckpointArgs

Base class for all checkpoint argument dataclasses.

All CheckpointHandler implementations should have corresponding CheckpointArgs dataclasses, typically one for save and one for restore.

Use one of the subclasses of CheckpointArgs for your use case to specify how an object should be saved or restored.

Typical usage:

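import orbax.checkpoint as ocp

# `path` is the checkpoint directory; `train_state` is the PyTree to save.
with ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler()) as ckptr:
  ckptr.save(
      path,
      args=ocp.args.StandardSave(train_state)
  )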

Example subclass:

import dataclasses
from typing import Any

import orbax.checkpoint as ocp

# MyCheckpointHandler stands for your own CheckpointHandler subclass.
@ocp.args.register_with_handler(MyCheckpointHandler, for_save=True)
@dataclasses.dataclass
class MyCheckpointSave(ocp.args.CheckpointArgs):
  item: Any
  options: Any

@ocp.args.register_with_handler(MyCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class MyCheckpointRestore(ocp.args.CheckpointArgs):
  options: Any

Example usage:

ckptr.save(
    path,
    custom_state=MyCheckpointSave(item=..., options=...)
)

ckptr.restore(
    path,
    custom_state=MyCheckpointRestore(options=...)
)

__eq__(other)

Return self==value.

__hash__ = None

__init__()
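
The save/restore pairing described above (one handler, one args dataclass per direction) can be illustrated without Orbax. The sketch below is hypothetical: `ToyCheckpointArgs`, `ToyJsonSave`, `ToyJsonRestore`, and `ToyJsonHandler` are invented names, and the JSON-file storage is an assumption chosen only to keep the example self-contained. It is not how any Orbax handler is implemented.

```python
import dataclasses
import json
import pathlib
import tempfile
from typing import Any


class ToyCheckpointArgs:
    """Hypothetical stand-in for ocp.args.CheckpointArgs (no Orbax import)."""


@dataclasses.dataclass
class ToyJsonSave(ToyCheckpointArgs):
    """Save args: carries the object to write."""
    item: Any


@dataclasses.dataclass
class ToyJsonRestore(ToyCheckpointArgs):
    """Restore args: value returned when no checkpoint file exists."""
    default: Any = None


class ToyJsonHandler:
    """Hypothetical handler; each method takes its matching args dataclass."""

    def save(self, directory: pathlib.Path, args: ToyJsonSave) -> None:
        directory.mkdir(parents=True, exist_ok=True)
        (directory / "item.json").write_text(json.dumps(args.item))

    def restore(self, directory: pathlib.Path, args: ToyJsonRestore) -> Any:
        path = directory / "item.json"
        if not path.exists():
            return args.default
        return json.loads(path.read_text())


with tempfile.TemporaryDirectory() as tmp:
    handler = ToyJsonHandler()
    ckpt_dir = pathlib.Path(tmp) / "step_0"
    handler.save(ckpt_dir, ToyJsonSave(item={"step": 0, "lr": 0.1}))
    restored = handler.restore(ckpt_dir, ToyJsonRestore())
    # Restoring from a directory that was never written falls back to default.
    missing = handler.restore(pathlib.Path(tmp) / "absent", ToyJsonRestore(default={}))
```

Bundling the arguments into a typed dataclass per direction keeps the handler's method signatures stable while letting each handler define exactly which options it accepts.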

Composite

orbax.checkpoint.args.Composite

alias of CompositeArgs

Standard PyTree

orbax.checkpoint.args.StandardSave

alias of StandardSaveArgs

orbax.checkpoint.args.StandardRestore

alias of StandardRestoreArgs

Generic PyTree

orbax.checkpoint.args.PyTreeSave

alias of PyTreeSaveArgs

orbax.checkpoint.args.PyTreeRestore

alias of PyTreeRestoreArgs

Array

orbax.checkpoint.args.ArraySave

alias of ArraySaveArgs

orbax.checkpoint.args.ArrayRestore

alias of ArrayRestoreArgs

JSON

orbax.checkpoint.args.JsonSave

alias of JsonSaveArgs

orbax.checkpoint.args.JsonRestore

alias of JsonRestoreArgs

Proto

orbax.checkpoint.args.ProtoSave

alias of ProtoSaveArgs

orbax.checkpoint.args.ProtoRestore

alias of ProtoRestoreArgs

JaxRandomKey

orbax.checkpoint.args.JaxRandomKeySave

alias of JaxRandomKeySaveArgs

orbax.checkpoint.args.JaxRandomKeyRestore

alias of JaxRandomKeyRestoreArgs

NumpyRandomKey

orbax.checkpoint.args.NumpyRandomKeySave

alias of NumpyRandomKeySaveArgs

orbax.checkpoint.args.NumpyRandomKeyRestore

alias of NumpyRandomKeyRestoreArgs

Utilities

orbax.checkpoint.args.get_registered_handler_cls(arg)

Returns the CheckpointHandler class registered for the given CheckpointArgs `arg`.

Return type:

Type[CheckpointHandler]

orbax.checkpoint.args.register_with_handler(handler_cls, for_save=False, for_restore=False)

Registers a CheckpointArgs subclass with a specific handler.

This registration is necessary so that, when a user uses this CheckpointArgs class with CompositeCheckpointHandler, the correct handler for saving or restoring the item can be found automatically.

Note that for_save and for_restore may both be True, but they cannot both be False.

Parameters:
  • handler_cls (Type[CheckpointHandler]) – CheckpointHandler to associate with this CheckpointArgs class.

  • for_save (bool) – whether the CheckpointArgs class is registered as a save argument.

  • for_restore (bool) – whether the CheckpointArgs class is registered as a restore argument.

Returns:

Decorator.
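
The lookup that CompositeCheckpointHandler relies on can be sketched as a registry populated by a decorator. This is a minimal, hypothetical model, not Orbax's implementation: `toy_register_with_handler`, `toy_get_registered_handler_cls`, `toy_dispatch`, and the `Toy*` classes are invented names, and raising `ValueError` when both flags are False is an assumption about how the "cannot both be false" rule might be enforced.

```python
import dataclasses
from typing import Any, Callable, Dict

# Hypothetical registry mapping each args class to its handler class.
_ARGS_TO_HANDLER: Dict[type, type] = {}


def toy_register_with_handler(
    handler_cls: type, for_save: bool = False, for_restore: bool = False
) -> Callable[[type], type]:
    """Hypothetical analogue of register_with_handler."""
    if not (for_save or for_restore):
        raise ValueError("for_save and for_restore cannot both be False.")

    def decorator(args_cls: type) -> type:
        _ARGS_TO_HANDLER[args_cls] = handler_cls
        return args_cls  # the decorated class is returned unchanged

    return decorator


def toy_get_registered_handler_cls(arg: Any) -> type:
    """Looks up the handler class for an args instance (or args class)."""
    args_cls = arg if isinstance(arg, type) else type(arg)
    return _ARGS_TO_HANDLER[args_cls]


def toy_dispatch(**items: Any) -> Dict[str, type]:
    """Resolves a handler class per named item, as a composite might."""
    return {name: toy_get_registered_handler_cls(a) for name, a in items.items()}


class ToyHandler:
    """Placeholder for a CheckpointHandler subclass."""


@toy_register_with_handler(ToyHandler, for_save=True)
@dataclasses.dataclass
class ToySave:
    item: Any


@toy_register_with_handler(ToyHandler, for_restore=True)
@dataclasses.dataclass
class ToyRestore:
    pass


handlers = toy_dispatch(custom_state=ToySave(item={"a": 1}))

try:
    toy_register_with_handler(ToyHandler)  # neither flag set
    both_false_rejected = False
except ValueError:
    both_false_rejected = True
```

Because registration happens at class-definition time, a composite only needs the args object for each named item to find the handler; the user never has to name the handler explicitly.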

orbax.checkpoint.args.get_registered_args_cls(handler)

Returns the registered CheckpointArgs corresponding to the handler.

Parameters:

handler (Union[Type[CheckpointHandler], CheckpointHandler]) – CheckpointHandler instance or class.

Return type:

Tuple[Type[CheckpointArgs], Type[CheckpointArgs]]

Returns:

Tuple of (save, restore) CheckpointArgs classes.

orbax.checkpoint.args.has_registered_args(handler)

Returns whether CheckpointArgs classes have been registered for the given handler.

Return type:

bool
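
The two handler-keyed lookups, get_registered_args_cls and has_registered_args, can be modeled with a registry keyed by handler class that stores a (save, restore) pair and accepts either a handler class or a handler instance. This is a hypothetical sketch: every `toy_*`/`Toy*` name is invented, and returning `None` for an unregistered direction is an assumption, not documented Orbax behavior.

```python
import dataclasses
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical registry: handler class -> [save args class, restore args class].
_HANDLER_TO_ARGS: Dict[type, List[Optional[type]]] = {}


def toy_register(handler_cls: type, for_save: bool = False, for_restore: bool = False):
    """Hypothetical analogue of register_with_handler, keyed by handler."""
    if not (for_save or for_restore):
        raise ValueError("for_save and for_restore cannot both be False.")

    def decorator(args_cls: type) -> type:
        entry = _HANDLER_TO_ARGS.setdefault(handler_cls, [None, None])
        if for_save:
            entry[0] = args_cls
        if for_restore:
            entry[1] = args_cls
        return args_cls

    return decorator


def _as_class(handler: Any) -> type:
    # Accept either a handler class or a handler instance.
    return handler if isinstance(handler, type) else type(handler)


def toy_has_registered_args(handler: Any) -> bool:
    return _as_class(handler) in _HANDLER_TO_ARGS


def toy_get_registered_args_cls(handler: Any) -> Tuple[Optional[type], Optional[type]]:
    save_cls, restore_cls = _HANDLER_TO_ARGS[_as_class(handler)]
    return save_cls, restore_cls


class ToyHandler:
    """Placeholder for a registered CheckpointHandler subclass."""


class UnregisteredHandler:
    """Placeholder for a handler with no registered args."""


@toy_register(ToyHandler, for_save=True)
@dataclasses.dataclass
class ToySave:
    item: Any


@toy_register(ToyHandler, for_restore=True)
@dataclasses.dataclass
class ToyRestore:
    pass


pair = toy_get_registered_args_cls(ToyHandler())  # an instance works too
```

Checking with a has-registered predicate first avoids a lookup error for handlers that were never registered.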