# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract class to manage checkpoints: AbstractCheckpointManager."""
import abc
from typing import Any, Mapping, Optional, Protocol, Sequence, Union
from etils import epath
from orbax.checkpoint import args as args_lib
from orbax.checkpoint._src.metadata import checkpoint
PyTree = Any
SaveParams = Mapping[str, Any]
RestoreParams = SaveParams
[docs]
class AbstractCheckpointManager(Protocol):
"""Interface to manage checkpoints.
Allows a user to save and restore objects for which a
:py:class:`.Checkpointer` implementation exists (e.g.
:py:class:`.PyTreeCheckpointer` for PyTrees). The class keeps track of
multiple checkpointable objects in the following structure::
path/to/directory/ (top-level directory)
0/ (step)
params/ (first saveable)
...
metadata/ (second saveable)
...
1/ (step)
...
2/ (step)
...
...
"""
@property
@abc.abstractmethod
def directory(self) -> epath.Path:
"""Returns the top-level directory containing checkpoints for all items."""
[docs]
@abc.abstractmethod
def all_steps(self, read: bool = False) -> Sequence[int]:
"""Returns all steps tracked by the manager.
Args:
read: If True, forces a read directly from the storage location.
Otherwise, a cached result can be returned.
Returns:
A sequence of steps (integers)
"""
[docs]
@abc.abstractmethod
def latest_step(self) -> Optional[int]:
"""Returns the latest step saved.
Returns None if no steps have been saved.
Returns:
A step (int) or None if no steps are present.
"""
[docs]
@abc.abstractmethod
def best_step(self) -> Optional[int]:
"""Returns the best step saved, as defined by `options.best_fn`.
Returns None if no steps have been saved.
Returns:
A step (int) or None if no steps are present.
"""
[docs]
@abc.abstractmethod
def reload(self):
"""Reloads internal properties.
Resets internal cache of checkpoint steps, in case the directory managed
by this object has been updated externally.
"""
[docs]
@abc.abstractmethod
def reached_preemption(self, step: int) -> bool:
"""Returns True if a preemption sync point has been reached."""
[docs]
@abc.abstractmethod
def should_save(self, step: int) -> bool:
"""Returns True if a checkpoint should be saved for the current step.
This depends the previous step and save interval.
Args:
step: int
Returns:
True if the checkpoint should be saved.
"""
[docs]
@abc.abstractmethod
def delete(self, step: int):
"""Deletes a step checkpoint."""
[docs]
@abc.abstractmethod
def save(
self,
step: int,
*args,
**kwargs,
) -> bool:
"""Saves the given step."""
[docs]
@abc.abstractmethod
def restore(
self,
step: Optional[int],
*args,
**kwargs,
) -> Union[Any, Mapping[str, Any], args_lib.Composite]:
"""Restores the given step."""
[docs]
@abc.abstractmethod
def metrics(self, step: int) -> Optional[PyTree]:
"""Returns metrics for step, if present."""
[docs]
@abc.abstractmethod
def wait_until_finished(self):
"""Blocks until any incomplete save operations are completed.
Note that this method will typically be a no-op if all checkpointers are
synchronous, since old checkpoints are already cleaned up immediately after
completing `save`, and there is no background thread to wait for.
If some checkpointers are of type :py:class:`.AsyncCheckpointer`, however,
this method will wait until each of these checkpointers is finished.
"""
[docs]
@abc.abstractmethod
def check_for_errors(self):
"""Checks for any outstanding errors in completed asynchronous save operations.
Delegates to underlying :py:class:`.Checkpointer`.
"""
[docs]
@abc.abstractmethod
def close(self):
"""Waits for outstanding operations to finish and closes :py:class:`.Checkpointer` s."""