Source code for orbax.checkpoint._src.checkpoint_managers.preservation_policy

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines policies for when a checkpoint is preserved."""

import dataclasses
import datetime
from typing import Any, Callable, Dict, Protocol, Sequence, Set
from absl import logging
import numpy as np
from orbax.checkpoint._src.checkpoint_managers import policy_checkpoint_info

# A dictionary with string keys and values of any type.
NestedDict = Dict[str, Any]
# Placeholder alias for tree-structured values; no structure is enforced here.
PyTree = Any
# Local shorthand for the per-checkpoint record that the policies below consume.
PolicyCheckpointInfo = policy_checkpoint_info.PolicyCheckpointInfo


[docs] @dataclasses.dataclass(kw_only=True) class PreservationContext: """Additional properties for making a save decision."""
class PreservationPolicy(Protocol):
  """Structural interface for deciding which checkpoints are preserved."""

  def should_preserve(
      self,
      checkpoints: Sequence[PolicyCheckpointInfo],
      *,
      context: PreservationContext,
  ) -> Sequence[bool]:
    """Indicates which checkpoints should be preserved.

    Args:
      checkpoints: The checkpoints to evaluate.
      context: Additional, keyword-only information for the decision.

    Returns:
      A sequence of booleans. The implementations in this module return one
      flag per entry of `checkpoints`, in the same order, where True means
      "preserve".
    """
    ...
def _log_preservation_decision(
    policy_name: str,
    checkpoints: Sequence[PolicyCheckpointInfo],
    should_preserve_list: Sequence[bool],
):
  """Logs, at verbosity level 1, the decision taken for each checkpoint.

  Nothing is formatted or emitted unless verbose logging at level 1 is on.

  Args:
    policy_name: Label of the policy that made the decisions; used as the
      message prefix.
    checkpoints: The checkpoints that were evaluated.
    should_preserve_list: Decision flags, paired positionally with
      `checkpoints`. If the two sequences differ in length, the unpaired tail
      is ignored.
  """
  # Skip all string formatting when verbose logging is disabled.
  if not logging.vlog_is_on(1):
    return
  for checkpoint, preserve in zip(checkpoints, should_preserve_list):
    verdict = "Preserving" if preserve else "Not preserving"
    logging.vlog(
        1,
        f" {policy_name}: {verdict} checkpoint at step {checkpoint.step}.",
    )
@dataclasses.dataclass
class PreserveAll(PreservationPolicy):
  """Policy under which every checkpoint is kept."""

  def should_preserve(
      self,
      checkpoints: Sequence[PolicyCheckpointInfo],
      *,
      context: PreservationContext,
  ) -> Sequence[bool]:
    """Marks every checkpoint as preserved.

    Args:
      checkpoints: The checkpoints to evaluate.
      context: Unused by this policy.

    Returns:
      A list of `True`, one per checkpoint.
    """
    decisions = [True for _ in range(len(checkpoints))]
    _log_preservation_decision("PreserveAll", checkpoints, decisions)
    return decisions
@dataclasses.dataclass
class LatestN(PreservationPolicy):
  """Preserves the last n checkpoints. Preserves all checkpoints if n is None.

  Attributes:
    n: Number of trailing checkpoints to preserve. None preserves everything.
      Zero or a negative value preserves nothing.
  """

  n: int | None = None

  def should_preserve(
      self,
      checkpoints: Sequence[PolicyCheckpointInfo],
      *,
      context: PreservationContext,
  ) -> Sequence[bool]:
    """Flags the trailing `n` entries of `checkpoints` for preservation.

    Args:
      checkpoints: The checkpoints to evaluate; the last entries are the ones
        that are kept.
      context: Unused by this policy.

    Returns:
      A list with one flag per checkpoint; only the final `n` are True.
    """
    total = len(checkpoints)
    if self.n is None:
      keep = total
    else:
      # Clamp to [0, total] so the result always has exactly `total` entries,
      # even for a negative or oversized n.
      keep = min(max(self.n, 0), total)
    result = [False] * (total - keep) + [True] * keep
    _log_preservation_decision(f"LatestN (n={self.n})", checkpoints, result)
    return result
@dataclasses.dataclass
class EveryNSeconds(PreservationPolicy):
  """Preserves checkpoints spaced at least `interval_secs` seconds apart.

  The first checkpoint is always preserved. Each later checkpoint is preserved
  only if its time is at least `interval_secs` seconds after the time of the
  most recently preserved checkpoint.
  """

  interval_secs: int

  def should_preserve(
      self,
      checkpoints: Sequence[PolicyCheckpointInfo],
      *,
      context: PreservationContext,
  ) -> Sequence[bool]:
    """Selects checkpoints separated by at least the configured interval.

    Args:
      checkpoints: The checkpoints to evaluate, walked in the given order.
      context: Unused by this policy.

    Returns:
      A list with one flag per checkpoint, or an empty list when there are no
      checkpoints.
    """
    if not checkpoints:
      return []

    # `anchor` is the most recently preserved checkpoint; gaps are measured
    # from it rather than from the immediately preceding checkpoint.
    anchor = checkpoints[0]
    decisions = [True]
    for candidate in checkpoints[1:]:
      elapsed = candidate.time - anchor.time
      far_enough = bool(
          elapsed >= datetime.timedelta(seconds=self.interval_secs)
      )
      if far_enough:
        anchor = candidate
      decisions.append(far_enough)

    _log_preservation_decision(
        f"EveryNSeconds (interval_secs={self.interval_secs})",
        checkpoints,
        decisions,
    )
    return decisions
@dataclasses.dataclass
class EveryNSteps(PreservationPolicy):
  """Preserves checkpoints after at least N steps.

  Attributes:
    interval_steps: The step interval. Must not be 0.
    exact_interval: If True, preserves exactly those checkpoints whose step is
      a multiple of `interval_steps`. If False, always preserves the first
      checkpoint and then every checkpoint whose step is at least
      `interval_steps` beyond the previously preserved one.
    max_to_keep: Optional cap on how many checkpoints this policy preserves.
      When the cap is exceeded, the earliest selected checkpoints are dropped.
      None means no cap; 0 or a negative value preserves nothing.
  """

  interval_steps: int
  exact_interval: bool = True
  max_to_keep: int | None = None

  def should_preserve(
      self,
      checkpoints: Sequence[PolicyCheckpointInfo],
      *,
      context: PreservationContext,
  ) -> Sequence[bool]:
    """Selects checkpoints on the configured step interval.

    Args:
      checkpoints: The checkpoints to evaluate, walked in the given order.
      context: Unused by this policy.

    Returns:
      A list with one flag per checkpoint.

    Raises:
      ValueError: If `interval_steps` is 0.
    """
    if self.interval_steps == 0:
      raise ValueError("interval_steps must not be 0.")

    if self.exact_interval:
      result = [ckpt.step % self.interval_steps == 0 for ckpt in checkpoints]
    else:
      result = []
      previous_step = None
      for ckpt in checkpoints:
        # The first checkpoint is always preserved; later ones must be at
        # least `interval_steps` past the last preserved step.
        keep = (
            previous_step is None
            or ckpt.step - previous_step >= self.interval_steps
        )
        if keep:
          previous_step = ckpt.step
        result.append(keep)

    if self.max_to_keep is not None:
      selected = [i for i, flag in enumerate(result) if flag]
      # Use an explicit count: `selected[:-max_to_keep]` is an empty slice
      # when max_to_keep is 0, which would silently ignore the cap.
      excess = len(selected) - self.max_to_keep
      if excess > 0:
        for i in selected[:excess]:
          result[i] = False

    _log_preservation_decision(
        f"EveryNSteps (interval_steps={self.interval_steps},"
        f" max_to_keep={self.max_to_keep})",
        checkpoints,
        result,
    )
    return result
@dataclasses.dataclass
class EveryNStepsClosest(PreservationPolicy):
  """Preserves checkpoints at steps closest to absolute multiples of N.

  This policy maps each checkpoint to its closest nominal target step on a grid
  defined by `interval_steps` (i.e. `k * interval_steps`). For each nominal
  target, the closest available checkpoint is preserved. This avoids the
  error accumulation/drift that can occur with
  `EveryNSteps(exact_interval=False)` when checkpoints are irregular.

  The last checkpoint is always preserved for final model state and efficient
  recovery.

  Attributes:
    interval_steps: Spacing of the target grid. Must not be 0.
    max_to_keep: Optional cap on how many checkpoints this policy preserves.
      The cap is applied last and drops the earliest selected checkpoints, so
      it also covers the always-kept last checkpoint. None means no cap; 0 or
      a negative value preserves nothing.
  """

  interval_steps: int
  max_to_keep: int | None = None

  def should_preserve(
      self,
      checkpoints: Sequence[PolicyCheckpointInfo],
      *,
      context: PreservationContext,
  ) -> Sequence[bool]:
    """Selects, per grid target, the checkpoint closest to it.

    Args:
      checkpoints: The checkpoints to evaluate, walked in the given order.
      context: Unused by this policy.

    Returns:
      A list with one flag per checkpoint, or an empty list when there are no
      checkpoints.

    Raises:
      ValueError: If `interval_steps` is 0.
    """
    if self.interval_steps == 0:
      raise ValueError("interval_steps must not be 0.")
    if not checkpoints:
      return []

    # Grid bucket k -> (distance to k * interval_steps, checkpoint index).
    best_by_bucket = {}
    for i, ckpt in enumerate(checkpoints):
      # Python's round() resolves exact .5 ties to the even integer.
      k = round(ckpt.step / self.interval_steps)
      diff = abs(ckpt.step - k * self.interval_steps)
      # If multiple checkpoints are equally close, keep the later one, which is
      # likely what most users would expect.
      if k not in best_by_bucket or diff <= best_by_bucket[k][0]:
        best_by_bucket[k] = (diff, i)

    result = [False] * len(checkpoints)
    for _, idx in best_by_bucket.values():
      result[idx] = True
    result[-1] = True  # Always keep the last one.

    if self.max_to_keep is not None:
      selected = [i for i, flag in enumerate(result) if flag]
      # Use an explicit count: `selected[:-max_to_keep]` is an empty slice
      # when max_to_keep is 0, which would silently ignore the cap.
      excess = len(selected) - self.max_to_keep
      if excess > 0:
        for i in selected[:excess]:
          result[i] = False

    _log_preservation_decision(
        f"EveryNStepsClosest (interval_steps={self.interval_steps},"
        f" max_to_keep={self.max_to_keep})",
        checkpoints,
        result,
    )
    return result
@dataclasses.dataclass
class CustomSteps(PreservationPolicy):
  """Preserves exactly those checkpoints whose step is in a given collection.

  `steps` is an init-only argument; it is converted to a set that is stored in
  `_steps_set` and used for membership tests.
  """

  steps: dataclasses.InitVar[Sequence[int]]
  _steps_set: Set[int] = dataclasses.field(init=False)

  def __post_init__(self, steps_init: Sequence[int]):
    """Builds the lookup set from the `steps` constructor argument."""
    self._steps_set = {*steps_init}

  def should_preserve(
      self,
      checkpoints: Sequence[PolicyCheckpointInfo],
      *,
      context: PreservationContext,
  ) -> Sequence[bool]:
    """Flags checkpoints whose step was listed at construction time.

    Args:
      checkpoints: The checkpoints to evaluate.
      context: Unused by this policy.

    Returns:
      A list with one flag per checkpoint.
    """
    wanted = self._steps_set
    decisions = [info.step in wanted for info in checkpoints]
    _log_preservation_decision(
        f"CustomSteps (steps={self._steps_set})", checkpoints, decisions
    )
    return decisions
@dataclasses.dataclass
class AnyPreservationPolicy(PreservationPolicy):
  """Applies multiple preservation policies and preserves if any policy preserves.

  Attributes:
    policies: The child policies to combine. With no child policies, nothing
      is preserved.
  """

  policies: Sequence[PreservationPolicy]

  def should_preserve(
      self,
      checkpoints: Sequence[PolicyCheckpointInfo],
      *,
      context: PreservationContext,
  ) -> Sequence[bool]:
    """Combines the child policies' decisions with a logical OR.

    Args:
      checkpoints: The checkpoints to evaluate; passed to every child policy.
      context: Forwarded unchanged to every child policy.

    Returns:
      A list with one flag per checkpoint; a flag is True when at least one
      child policy preserves that checkpoint.
    """
    logging.vlog(
        1, "AnyPreservationPolicy: Preserving if any child policy preserves."
    )
    # Without this guard, np.any over an empty 1-D array collapses to a single
    # scalar False instead of one flag per checkpoint.
    if not self.policies:
      return [False] * len(checkpoints)
    should_preserve_by_policy = np.asarray([
        policy.should_preserve(checkpoints, context=context)
        for policy in self.policies
    ])
    # Rows are policies and columns are checkpoints, so reduce over rows.
    return np.any(should_preserve_by_policy, axis=0).tolist()
@dataclasses.dataclass(kw_only=True)
class BestN(PreservationPolicy):
  """A policy that preserves the best checkpoints based on a metric function.

  Checkpoints that have metrics are sorted by `get_metric_fn` and the last `n`
  entries of that ordering are preserved. With `reverse=False` these are the
  checkpoints with the highest metric values; with `reverse=True`, the lowest.

  Attributes:
    get_metric_fn: A function that accepts a nested tree of metrics and returns
      a scalar value representing the value used for ranking checkpoints.
    reverse: If False (default), checkpoints are sorted in ascending order,
      according to `get_metric_fn`. If True, checkpoints are sorted in
      descending order. Same as the semantics of built-in sorted() function.
    n: The number of checkpoints to preserve. If None, all checkpoints are
      preserved. If 0 (or negative), no checkpoints are preserved.
    keep_checkpoints_without_metrics: If True (default), checkpoints whose
      `metrics` is None are preserved in addition to the best `n`, when
      ranking takes place.
  """

  get_metric_fn: Callable[[PyTree], float]
  reverse: bool = False
  n: int | None = None
  keep_checkpoints_without_metrics: bool = True

  def should_preserve(
      self,
      checkpoints: Sequence[PolicyCheckpointInfo],
      *,
      context: PreservationContext,
  ) -> Sequence[bool]:
    """Selects the `n` best-ranked checkpoints.

    Args:
      checkpoints: The checkpoints to evaluate.
      context: Unused by this policy.

    Returns:
      A list with one flag per checkpoint. Ranking only happens when there are
      more checkpoints than `n`; otherwise every checkpoint is preserved.
    """
    total = len(checkpoints)
    if self.n is None:
      preserve_flags = [True] * total
    elif self.n <= 0:
      preserve_flags = [False] * total
    elif total <= self.n:
      preserve_flags = [True] * total
    else:
      ranked = sorted(
          (
              (i, info)
              for i, info in enumerate(checkpoints)
              if info.metrics is not None
          ),
          key=lambda item: self.get_metric_fn(item[1].metrics),
          reverse=self.reverse,
      )
      # The best entries sit at the end of the ordering.
      preserve_indices = {i for i, _ in ranked[-self.n :]}
      if self.keep_checkpoints_without_metrics:
        preserve_indices.update(
            i for i, info in enumerate(checkpoints) if info.metrics is None
        )
      preserve_flags = [i in preserve_indices for i in range(total)]

    # Log on every path, as the other policies do.
    _log_preservation_decision(
        f"BestN (n={self.n})", checkpoints, preserve_flags
    )
    return preserve_flags
@dataclasses.dataclass
class LatestDuration(PreservationPolicy):
  """Preserves checkpoints that are newer than the given duration.

  E.g. retain checkpoints within the last 24 hours::

    import datetime
    LatestDuration(datetime.timedelta(hours=24))
  """

  duration: datetime.timedelta

  def should_preserve(
      self,
      checkpoints: Sequence[PolicyCheckpointInfo],
      *,
      context: PreservationContext,
  ) -> Sequence[bool]:
    """Preserves the first in-window checkpoint and everything after it.

    Ages are measured against the current UTC time. Only the first checkpoint
    within `duration` is searched for; every later entry is preserved without
    inspecting its own timestamp.

    Args:
      checkpoints: The checkpoints to evaluate, walked in the given order.
      context: Unused by this policy.

    Returns:
      A list with one flag per checkpoint, or an empty list when there are no
      checkpoints.
    """
    if not checkpoints:
      return []

    now = datetime.datetime.now(tz=datetime.timezone.utc)
    total = len(checkpoints)
    # Index of the first checkpoint inside the window; `total` when none is.
    cutoff = next(
        (
            idx
            for idx, info in enumerate(checkpoints)
            if now - info.time <= self.duration
        ),
        total,
    )
    decisions = [False] * cutoff + [True] * (total - cutoff)
    _log_preservation_decision(
        f"LatestDuration (duration={self.duration})", checkpoints, decisions
    )
    return decisions