Source code for orbax.checkpoint._src.checkpoint_managers.save_decision_policy

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines policies for when a checkpoint is saved."""

import dataclasses
import datetime
import typing
from typing import Container, Protocol, Sequence
from absl import logging
from orbax.checkpoint import options as options_lib
from orbax.checkpoint._src.checkpoint_managers import policy_checkpoint_info
from orbax.checkpoint._src.futures import signaling_client
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.multihost import multihost


PolicyCheckpointInfo = policy_checkpoint_info.PolicyCheckpointInfo


[docs] @dataclasses.dataclass(kw_only=True) class DecisionContext: """Additional properties for making a save decision. This dataclass is populated by the checkpointer framework and passed into the `should_save` method of all `SaveDecisionPolicy` implementations. It provides essential external system context, allowing policies to make safe, state-aware decisions. Attributes: is_saving_in_progress: Indicates whether an asynchronous checkpoint save operation is currently running in the background. Policies (like `ContinuousCheckpointingPolicy`) use this to avoid triggering overlapping save operations. reached_preemption: Indicates whether a preemption signal has been received from the cluster manager, meaning the training job is about to be terminated. multiprocessing_options: Configuration details for distributed multihost training. This provides information such as primary host identification for synchronization barriers. """ is_saving_in_progress: bool reached_preemption: bool multiprocessing_options: options_lib.MultiprocessingOptions
[docs] @typing.runtime_checkable class SaveDecisionPolicy(Protocol): """A policy that defines when to save a checkpoint."""
[docs] def should_save( self, step: PolicyCheckpointInfo, previous_steps: Sequence[PolicyCheckpointInfo], *, context: DecisionContext, ) -> bool: """Returns True if a checkpoint should be saved at the given step.""" ...
def _log_save_decision( policy_name: str, step: PolicyCheckpointInfo, is_saving: bool, ) -> None: """Logs the save decision.""" if is_saving: logging.vlog( 1, f"{policy_name}: Saving checkpoint at step" f" {step.step}).", ) else: logging.vlog( 1, f"{policy_name}: Not saving checkpoint at step" f" {step.step}).", )
[docs] @dataclasses.dataclass class FixedIntervalPolicy(SaveDecisionPolicy): """Checkpoint at a fixed interval. This policy evaluates to True whenever the current training step is an exact multiple of the configured `interval` (i.e., `step.step % interval == 0`). It makes its decision purely based on the current step index, strictly ignoring previous save history or external context. Attributes: interval: The frequency at which checkpoints should be saved. For example, an interval of 100 means a save is triggered at steps 100, 200, 300, etc. Methods: should_save(step, previous_steps, *, context): Evaluates whether the current step index is a multiple of the interval. Args: step (PolicyCheckpointInfo): Information about the current training step, primarily using the step index for modulo arithmetic. previous_steps (Sequence[PolicyCheckpointInfo]): Ignored by this policy. context (DecisionContext): Ignored by this policy. """ interval: int
[docs] def should_save( self, step: PolicyCheckpointInfo, previous_steps: Sequence[PolicyCheckpointInfo], *, context: DecisionContext, ) -> bool: del previous_steps del context result = step.step % self.interval == 0 _log_save_decision( f"FixedIntervalPolicy (interval={self.interval})", step, result, ) return result
[docs] @dataclasses.dataclass class SpecificStepsPolicy(SaveDecisionPolicy): """Checkpoint at specific steps. This policy evaluates to True whenever the current training step index exists within the configured `steps` container (i.e., `step.step in steps`). It makes its decision purely based on the current step index, strictly ignoring previous save history or external context. Attributes: steps (Container[int]): A collection (such as a `set`, `list`, or `tuple`) of step indices where checkpoints should be saved. Methods: should_save(step, previous_steps, *, context): Evaluates whether the current step index exists in the steps container. Args: step (PolicyCheckpointInfo): Information about the current training step, primarily using the step index for membership testing. previous_steps (Sequence[PolicyCheckpointInfo]): Ignored by this policy. context (DecisionContext): Ignored by this policy. """ steps: Container[int]
[docs] def should_save( self, step: PolicyCheckpointInfo, previous_steps: Sequence[PolicyCheckpointInfo], *, context: DecisionContext, ) -> bool: del previous_steps del context result = step.step in self.steps _log_save_decision( f"SpecificStepsPolicy (steps={self.steps})", step, result, ) return result
[docs] @dataclasses.dataclass(kw_only=True) class ContinuousCheckpointingPolicy(SaveDecisionPolicy): """Checkpoint as often as possible, as long as a save is not ongoing. This policy evaluates to True as often as possible. It enforces two primary constraints to prevent blocking training or causing other regressions. 1. It will never trigger a new save if a save is currently in progress (checked via the provided `DecisionContext`); this prevents blocking on an ongoing save, which would hurt accelerator utilization. 2. It optionally respects a minimum time interval between saves if `minimum_interval_secs` is configured. This sets a floor on how frequently checkpoints are saved, which can be used to avoid excessive burden on the filesystem, or blocking too frequently (due to synchronous D2H). In a distributed training environment, to ensure perfect synchronization and avoid race conditions, the time and state-based save decision is computed exclusively on the primary host. The result is then broadcast to all other hosts via a blocking barrier. For usage examples, please refer to the parent class `SaveDecisionPolicy`. Attributes: minimum_interval_secs (int | None): The minimum time in seconds that must elapse between the timestamp of the previous checkpoint and the current step. If `None` (the default), back-to-back saves are permitted as soon as the ongoing save completes. Methods: should_save(step, previous_steps, *, context): Evaluates the current state and synchronizes across all hosts to return a boolean indicating whether a checkpoint should be saved. Args: step (PolicyCheckpointInfo): Information about the current training step, including the step index and timestamp. previous_steps (Sequence[PolicyCheckpointInfo]): A chronological list of metadata for all steps where a checkpoint was successfully saved. context (DecisionContext): A container for auxiliary information, such as the current saving state (`is_saving_in_progress`) and multiprocessing configuration used to inform the decision. """ minimum_interval_secs: int | None = None
[docs] def should_save( self, step: PolicyCheckpointInfo, previous_steps: Sequence[PolicyCheckpointInfo], *, context: DecisionContext, ) -> bool: def _get_should_save_result() -> bool: if context.is_saving_in_progress: return False if not previous_steps or self.minimum_interval_secs is None: return True return step.time - previous_steps[-1].time >= datetime.timedelta( seconds=self.minimum_interval_secs ) is_primary_host = multihost.is_primary_host( context.multiprocessing_options.primary_host ) client = signaling_client.get_signaling_client() operation_id = synchronization.OperationIdGenerator.next_operation_id() result_barrier_key = f"{operation_id}_continuous_checkpointing_policy_should_save_{step.step}/" # Make time based and save in progress based decision only on primary host # and broadcast to all hosts. if is_primary_host: save_result = int(_get_should_save_result()) client.key_value_set( result_barrier_key, str(save_result), allow_overwrite=True, ) save_result = int( client.blocking_key_value_get( result_barrier_key, timeout_secs=multihost.coordination_timeout(), ) ) _log_save_decision( "ContinuousCheckpointingPolicy" f" (minimum_interval_secs={self.minimum_interval_secs})", step, bool(save_result), ) return bool(save_result)
[docs] class PreemptionCheckpointingPolicy(SaveDecisionPolicy): """Save a checkpoint when a preemption is detected. This policy evaluates to True strictly when the provided `DecisionContext` indicates that a preemption signal has been received (i.e., `context.reached_preemption` is True). It can be useful for ensuring that training progress is safely stored before the job is killed by the cluster scheduler. Note that saving on preemption is not strictly necessary, however. For example, if continuous checkpointing is employed, and checkpoints are saved frequently enough, the cost of re-computing some amount of steps can be cheaper than the cost of waiting for a checkpoint to complete after preemption. Methods: should_save(step, previous_steps, *, context): Evaluates whether a preemption signal has been registered in the context. Args: step (PolicyCheckpointInfo): Ignored by this policy. previous_steps (Sequence[PolicyCheckpointInfo]): Ignored by this policy. context (DecisionContext): A container for auxiliary information. This policy specifically checks the `reached_preemption` boolean flag to make its decision. """
[docs] def should_save( self, step: PolicyCheckpointInfo, previous_steps: Sequence[PolicyCheckpointInfo], *, context: DecisionContext, ) -> bool: _log_save_decision( "PreemptionCheckpointingPolicy", step, context.reached_preemption, ) del step del previous_steps return context.reached_preemption
[docs] class InitialSavePolicy(SaveDecisionPolicy): """Save a checkpoint as soon as possible if no checkpoints already exist. This policy evaluates to True only if the `previous_steps` sequence is empty. It is highly useful for ensuring a baseline checkpoint is created immediately upon starting a fresh training run, while safely evaluating to False if the job is restarting from an existing checkpoint. Methods: should_save(step, previous_steps, *, context): Evaluates whether the `previous_steps` history is empty. Args: step (PolicyCheckpointInfo): Ignored by this policy. previous_steps (Sequence[PolicyCheckpointInfo]): A chronological list of metadata for previously saved steps. The policy checks if this is empty. context (DecisionContext): Ignored by this policy. """
[docs] def should_save( self, step: PolicyCheckpointInfo, previous_steps: Sequence[PolicyCheckpointInfo], *, context: DecisionContext, ) -> bool: _log_save_decision( "InitialSavePolicy", step, not previous_steps, ) del step del context return not previous_steps
[docs] @dataclasses.dataclass class AnySavePolicy(SaveDecisionPolicy): """Evaluates all policies and saves if any of them returns True. This policy iterates through a provided sequence of child policies. It evaluates each one in order and returns True immediately if any child policy returns True. If all child policies return False, this policy returns False. It is highly useful for combining time-based, step-based, and event-based saving rules into a single, unified checkpointer configuration. Attributes: policies (Sequence[SaveDecisionPolicy]): An ordered collection of underlying policies to evaluate. Methods: should_save(step, previous_steps, *, context): Evaluates the sequence of configured policies. Args: step (PolicyCheckpointInfo): Passed down to each child policy. previous_steps (Sequence[PolicyCheckpointInfo]): Passed down to each child policy. context (DecisionContext): Passed down to each child policy. """ policies: Sequence[SaveDecisionPolicy]
[docs] def should_save( self, step: PolicyCheckpointInfo, previous_steps: Sequence[PolicyCheckpointInfo], *, context: DecisionContext, ) -> bool: logging.vlog( 1, "AnySavePolicy: policies=%s, step=%d.", self.policies, step.step, ) return any( policy.should_save(step, previous_steps=previous_steps, context=context) for policy in self.policies )