Source code for orbax.checkpoint.checkpoint_args

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CheckpointArgs base class and registration."""

import dataclasses
import inspect
from typing import Tuple, Type, TypeVar, Union

from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import handler_type_registry

CheckpointHandler = checkpoint_handler.CheckpointHandler


@dataclasses.dataclass
class CheckpointArgs:
  """Base class for all checkpoint argument dataclasses.

  All :py:class:`.CheckpointHandler` implementations should have corresponding
  :py:class:`CheckpointArgs` dataclasses, typically one for save and one for
  restore. Use one of the subclasses of :py:class:`CheckpointArgs` for your
  use case to specify how an object should be saved or restored.

  This base class declares no fields of its own; subclasses add the fields
  their handler needs.

  Typical usage::

    with ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler()) as ckptr:
      ckptr.save(
          path,
          args=ocp.args.StandardSave(train_state)
      )

  Example subclass::

    @ocp.args.register_with_handler(MyCheckpointHandler, for_save=True)
    @dataclasses.dataclass
    class MyCheckpointSave(ocp.args.CheckpointArgs):
      item: Any
      options: Any

    @ocp.args.register_with_handler(MyCheckpointHandler, for_restore=True)
    @dataclasses.dataclass
    class MyCheckpointRestore(ocp.args.CheckpointArgs):
      options: Any

  Example usage::

    ckptr.save(
        path,
        custom_state=MyCheckpointSave(item=..., options=...)
    )

    ckptr.restore(
        path,
        custom_state=MyCheckpointRestore(options=...)
    )
  """
# Registry of `CheckpointArgs` subclasses registered as *save* arguments,
# keyed by args class and mapping to the associated `CheckpointHandler` class.
# Populated by `register_with_handler(..., for_save=True)`.
_SAVE_ARG_TO_HANDLER: dict[Type[CheckpointArgs], Type[CheckpointHandler]] = {}
# Registry of `CheckpointArgs` subclasses registered as *restore* arguments,
# keyed the same way. Populated by `register_with_handler(..., for_restore=True)`.
_RESTORE_ARG_TO_HANDLER: dict[Type[CheckpointArgs], Type[CheckpointHandler]] = (
    {}
)
# Type variable bound to `CheckpointArgs`; lets the registration decorator be
# annotated as returning the same class type it was given.
_CheckpointArgsType = TypeVar('_CheckpointArgsType', bound=CheckpointArgs)
def register_with_handler(
    handler_cls: Type[CheckpointHandler],
    for_save: bool = False,
    for_restore: bool = False,
):
  """Creates a decorator binding a :py:class:`CheckpointArgs` class to a handler.

  The registration is what allows the correct handler to be located
  automatically when the decorated :py:class:`CheckpointArgs` class is used
  with :py:class:`.CompositeCheckpointHandler`.

  `for_save` and `for_restore` may both be true, but at least one of them
  must be.

  Args:
    handler_cls: `CheckpointHandler` class to associate with the decorated
      `CheckpointArgs` class.
    for_save: whether to record the decorated class as a save argument.
    for_restore: whether to record the decorated class as a restore argument.

  Returns:
    A class decorator. It records the class in the requested registries,
    passes `handler_cls` to `handler_type_registry.register_handler_type`,
    and returns the class unchanged. It raises `TypeError` when applied to a
    class that does not subclass `CheckpointArgs`.

  Raises:
    ValueError: if neither `for_save` nor `for_restore` is set.
  """
  if not (for_save or for_restore):
    raise ValueError('`for_save` and `for_restore` cannot both be False.')

  def _register(
      cls: Type[_CheckpointArgsType],
  ) -> Type[_CheckpointArgsType]:
    if not issubclass(cls, CheckpointArgs):
      raise TypeError(
          f'{cls} must subclass `CheckpointArgs` in order to be registered.'
      )
    # Pair each flag with the registry it controls; save is handled first.
    selections = (
        (for_save, _SAVE_ARG_TO_HANDLER),
        (for_restore, _RESTORE_ARG_TO_HANDLER),
    )
    for enabled, registry in selections:
      if enabled:
        registry[cls] = handler_cls
    handler_type_registry.register_handler_type(handler_cls)
    return cls

  return _register
def get_registered_handler_cls(
    arg: Union[Type[CheckpointArgs], CheckpointArgs]
) -> Type[CheckpointHandler]:
  """Looks up the :py:class:`.CheckpointHandler` class registered for `arg`.

  Args:
    arg: a `CheckpointArgs` subclass, or an instance of one (in which case its
      type is used for the lookup).

  Returns:
    The handler class recorded in the save registry if `arg` is present
    there; otherwise the handler class recorded in the restore registry.

  Raises:
    TypeError: if the resolved class is not a subclass of `CheckpointArgs`.
    ValueError: if the resolved class appears in neither registry.
  """
  arg = arg if inspect.isclass(arg) else type(arg)
  if not issubclass(arg, CheckpointArgs):
    raise TypeError(f'{arg} must be a subclass of `CheckpointArgs`.')

  # Save registrations take precedence over restore registrations.
  for registry in (_SAVE_ARG_TO_HANDLER, _RESTORE_ARG_TO_HANDLER):
    if arg in registry:
      return registry[arg]

  raise ValueError(
      f'Unable to find registered `CheckpointHandler` for {arg}. Use'
      ' `register_with_handler`.'
  )
def get_registered_args_cls(
    handler: Union[Type[CheckpointHandler], CheckpointHandler]
) -> Tuple[Type[CheckpointArgs], Type[CheckpointArgs]]:
  """Finds the save and restore `CheckpointArgs` classes for a handler.

  Each registry is scanned in insertion order and the first args class whose
  registered handler compares equal to `handler` is selected.

  Args:
    handler: `CheckpointHandler` instance or class. An instance is replaced by
      its type before the lookup.

  Returns:
    Tuple of (save, restore) `CheckpointArgs` classes.

  Raises:
    ValueError: if no save args class is registered for the handler, or (once
      a save args class has been found) no restore args class is registered.
  """
  handler = handler if inspect.isclass(handler) else type(handler)

  save_args = next(
      (
          arg_cls
          for arg_cls, handler_cls in _SAVE_ARG_TO_HANDLER.items()
          if handler_cls == handler
      ),
      None,
  )
  if save_args is None:
    raise ValueError(
        f'Unable to find registered `CheckpointArgs` for save for {handler}.'
    )

  restore_args = next(
      (
          arg_cls
          for arg_cls, handler_cls in _RESTORE_ARG_TO_HANDLER.items()
          if handler_cls == handler
      ),
      None,
  )
  if restore_args is None:
    raise ValueError(
        f'Unable to find registered `CheckpointArgs` for restore for {handler}.'
    )

  return save_args, restore_args
def has_registered_args(
    handler: Union[Type[CheckpointHandler], CheckpointHandler]
) -> bool:
  """Reports whether `handler` has both save and restore args registered.

  Args:
    handler: `CheckpointHandler` instance or class.

  Returns:
    True if `get_registered_args_cls(handler)` succeeds; False if it raises
    `ValueError`. Any other exception propagates to the caller.
  """
  try:
    get_registered_args_cls(handler)
  except ValueError:
    registered = False
  else:
    registered = True
  return registered