ocp.v1.training.save_decision_policies module

Defines policies for when a checkpoint is saved.

SaveDecisionPolicy

class orbax.checkpoint.experimental.v1.training.save_decision_policies.SaveDecisionPolicy(*args, **kwargs)

Bases: Protocol

A policy that defines when to save a checkpoint.

SaveDecisionPolicy is a protocol that defines the interface for making checkpoint save decisions. Implementations decide whether to save a checkpoint based on the current step, the previously saved steps, and the DecisionContext. Before implementing a new policy, check whether one of Orbax’s existing policies (e.g., FixedIntervalPolicy or ContinuousCheckpointingPolicy), or a combination of them via AnySavePolicy, already covers your use case.

Examples:

1. Configuring Checkpointer: SaveDecisionPolicy instances can be passed to Checkpointer to control save frequency. For example:

from orbax.checkpoint.v1 import training
policies = training.save_decision_policies

directory = '/tmp/checkpoints'  # Example location; use your own path.

# Save every 1000 steps, or when a preemption is detected.
policy = policies.AnySavePolicy([
    policies.FixedIntervalPolicy(1000),
    policies.PreemptionCheckpointingPolicy(),
])
checkpointer = training.Checkpointer(directory, save_decision_policy=policy)

2. Implementing a custom policy: To define custom saving rules, implement the SaveDecisionPolicy interface:

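from typing import Sequence

from orbax.checkpoint.v1 import training

policies = training.save_decision_policies


class SaveEveryNSteps(policies.SaveDecisionPolicy):
  """Saves a checkpoint whenever the step index is a multiple of n."""

  def __init__(self, n: int):
    self.n = n

  def should_save(
      self,
      step,  # PolicyCheckpointInfo for the current step.
      previous_steps: Sequence,  # Previously saved steps; unused here.
      *,
      context: policies.DecisionContext,  # Unused by this policy.
  ) -> bool:
    # step.step is the integer index of the current training step.
    return step.step % self.n == 0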

should_save(step, previous_steps, *, context)

Evaluates the current state to return a boolean indicating whether a checkpoint should be saved.

Parameters:
  • step (PolicyCheckpointInfo) – Information about the current training step, including the step index and timestamp.

  • previous_steps (Sequence[PolicyCheckpointInfo]) – A chronological list of metadata for all steps where a checkpoint was successfully saved.

  • context (DecisionContext) – A container for auxiliary information, such as whether a save is already in progress or a preemption signal has been received, used to inform the save decision.

Returns True if a checkpoint should be saved at the given step.

Return type:

bool
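Because SaveDecisionPolicy is a typing.Protocol, any object with a matching should_save method satisfies the interface for static type checkers, even without inheriting from it. The sketch below illustrates the calling contract with hypothetical stand-ins (StepInfo, Context, Policy, EveryNSteps, run_loop), none of which are Orbax APIs. In particular, run_loop is only an assumption about how a checkpointer might consult a policy once per step, passing only the steps that were actually saved as previous_steps.

```python
import dataclasses
import datetime
from typing import List, Protocol, Sequence


# Hypothetical stand-ins for PolicyCheckpointInfo and DecisionContext that
# model only the fields this sketch uses.
@dataclasses.dataclass
class StepInfo:
  step: int
  time: datetime.datetime


@dataclasses.dataclass
class Context:
  is_saving_in_progress: bool = False
  reached_preemption: bool = False


class Policy(Protocol):
  """Structural stand-in for the SaveDecisionPolicy interface."""

  def should_save(
      self,
      step: StepInfo,
      previous_steps: Sequence[StepInfo],
      *,
      context: Context,
  ) -> bool:
    ...


class EveryNSteps:
  """Matches `Policy` structurally, without inheriting from it."""

  def __init__(self, n: int):
    self.n = n

  def should_save(self, step, previous_steps, *, context) -> bool:
    return step.step % self.n == 0


def run_loop(policy: Policy, num_steps: int) -> List[int]:
  """Hypothetical driver: consults the policy once per step."""
  saved: List[StepInfo] = []
  start = datetime.datetime(2024, 1, 1)
  for i in range(num_steps):
    info = StepInfo(step=i, time=start + datetime.timedelta(seconds=i))
    # previous_steps contains only the steps that were actually saved.
    if policy.should_save(info, saved, context=Context()):
      saved.append(info)
  return [s.step for s in saved]


print(run_loop(EveryNSteps(3), 10))  # [0, 3, 6, 9]
```

The driver never mutates the history except by appending saved steps, which keeps previous_steps in chronological order as the parameter documentation describes.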

ContinuousCheckpointingPolicy

class orbax.checkpoint.experimental.v1.training.save_decision_policies.ContinuousCheckpointingPolicy(*, minimum_interval_secs=None)

Checkpoint as often as possible, as long as a save is not ongoing.

This policy evaluates to True as often as possible. It enforces two primary constraints to prevent blocking training or causing other regressions.

  1. It will never trigger a new save if a save is currently in progress (checked via the provided DecisionContext); this prevents blocking on an ongoing save, which would hurt accelerator utilization.

  2. It optionally respects a minimum time interval between saves if minimum_interval_secs is configured. This sets a floor on how frequently checkpoints are saved, which can be used to avoid placing excessive load on the filesystem, or to avoid blocking training too often (each save involves a synchronous device-to-host (D2H) transfer).

In a distributed training environment, the time- and state-based save decision is computed only on the primary host, which keeps all hosts synchronized and avoids race conditions. The result is then broadcast to all other hosts via a blocking barrier.

For usage examples, please refer to the parent class SaveDecisionPolicy.
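The two constraints above can be sketched as a plain function. This is a minimal, single-host illustration under stated assumptions, not Orbax's implementation: StepInfo and Context are hypothetical stand-ins, elapsed time is measured from the last saved step's timestamp to the current step's timestamp, a save is assumed to be allowed when nothing has been saved yet, and the interval comparison is assumed to be inclusive (>=). The primary-host computation and broadcast are omitted.

```python
import dataclasses
import datetime
from typing import Optional, Sequence


# Hypothetical stand-ins; only the fields this sketch reads are modeled.
@dataclasses.dataclass
class StepInfo:
  step: int
  time: datetime.datetime


@dataclasses.dataclass
class Context:
  is_saving_in_progress: bool


def continuous_should_save(
    step: StepInfo,
    previous_steps: Sequence[StepInfo],
    context: Context,
    minimum_interval_secs: Optional[int] = None,
) -> bool:
  # Constraint 1: never start a save while another one is still running.
  if context.is_saving_in_progress:
    return False
  # No floor configured, or nothing saved yet (assumed): save right away.
  if minimum_interval_secs is None or not previous_steps:
    return True
  # Constraint 2: enforce a minimum time between consecutive saves.
  elapsed = (step.time - previous_steps[-1].time).total_seconds()
  return elapsed >= minimum_interval_secs


t0 = datetime.datetime(2024, 1, 1)
last = [StepInfo(step=10, time=t0)]
soon = StepInfo(step=11, time=t0 + datetime.timedelta(seconds=30))
later = StepInfo(step=12, time=t0 + datetime.timedelta(seconds=90))

print(continuous_should_save(soon, last, Context(False), 60))   # False: too soon
print(continuous_should_save(later, last, Context(False), 60))  # True
print(continuous_should_save(later, last, Context(True), 60))   # False: save running
print(continuous_should_save(soon, last, Context(False)))       # True: no floor
```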

minimum_interval_secs

The minimum time in seconds that must elapse between the timestamp of the previous checkpoint and the current step. If None (the default), back-to-back saves are permitted as soon as the ongoing save completes.

Type:

int | None

should_save(step, previous_steps, *, context)

Evaluates the current state and synchronizes across all hosts to return a boolean indicating whether a checkpoint should be saved.

Parameters:
  • step (PolicyCheckpointInfo) – Information about the current training step, including the step index and timestamp.

  • previous_steps (Sequence[PolicyCheckpointInfo]) – A chronological list of metadata for all steps where a checkpoint was successfully saved.

  • context (DecisionContext) – A container for auxiliary information, such as the current saving state (is_saving_in_progress) and multiprocessing configuration used to inform the decision.

Returns True if a checkpoint should be saved at the given step.

Return type:

bool

FixedIntervalPolicy

class orbax.checkpoint.experimental.v1.training.save_decision_policies.FixedIntervalPolicy(interval)

Checkpoint at a fixed interval.

This policy evaluates to True whenever the current training step is an exact multiple of the configured interval (i.e., step.step % interval == 0). It makes its decision purely based on the current step index, strictly ignoring previous save history or external context.
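A minimal sketch of the modulo rule, using a plain integer step index in place of PolicyCheckpointInfo (a simplification; fixed_interval_should_save is a hypothetical name, not an Orbax API). The sketch assumes interval is a positive integer, and it shows that step 0 also satisfies the condition.

```python
def fixed_interval_should_save(step: int, interval: int) -> bool:
  # Only the current step index matters; history and context are ignored.
  return step % interval == 0


saved = [s for s in range(0, 301) if fixed_interval_should_save(s, 100)]
print(saved)  # [0, 100, 200, 300]
```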

interval

The frequency, in steps, at which checkpoints should be saved. For example, an interval of 100 triggers a save at steps 0, 100, 200, 300, and so on, since step 0 also satisfies step.step % interval == 0.

Type:

int

should_save(step, previous_steps, *, context)

Evaluates whether the current step index is a multiple of the interval.

Parameters:
  • step (PolicyCheckpointInfo) – Information about the current training step, primarily using the step index for modulo arithmetic.

  • previous_steps (Sequence[PolicyCheckpointInfo]) – Ignored by this policy.

  • context (DecisionContext) – Ignored by this policy.

Returns True if a checkpoint should be saved at the given step.

Return type:

bool

InitialSavePolicy

class orbax.checkpoint.experimental.v1.training.save_decision_policies.InitialSavePolicy(*args, **kwargs)

Save a checkpoint as soon as possible if no checkpoints already exist.

This policy evaluates to True only if the previous_steps sequence is empty. It ensures that a baseline checkpoint is created as soon as a fresh training run starts, while evaluating to False when the job restarts from an existing checkpoint.
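The rule reduces to an emptiness check on the save history. In this sketch, previous_steps is simplified to a sequence of step indices rather than PolicyCheckpointInfo objects, and initial_save_should_save is a hypothetical name rather than an Orbax API.

```python
from typing import Sequence


def initial_save_should_save(previous_steps: Sequence[int]) -> bool:
  # True only when no checkpoint has been saved yet (a fresh run).
  return not previous_steps


print(initial_save_should_save([]))           # True: fresh run
print(initial_save_should_save([500, 1000]))  # False: resuming a run
```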

should_save(step, previous_steps, *, context)

Evaluates whether the previous_steps history is empty.

Parameters:
  • step (PolicyCheckpointInfo) – Ignored by this policy.

  • previous_steps (Sequence[PolicyCheckpointInfo]) – A chronological list of metadata for previously saved steps. The policy checks if this is empty.

  • context (DecisionContext) – Ignored by this policy.

Returns True if a checkpoint should be saved at the given step.

Return type:

bool

PreemptionCheckpointingPolicy

class orbax.checkpoint.experimental.v1.training.save_decision_policies.PreemptionCheckpointingPolicy(*args, **kwargs)

Save a checkpoint when a preemption is detected.

This policy evaluates to True only when the provided DecisionContext indicates that a preemption signal has been received (i.e., context.reached_preemption is True). This helps ensure that training progress is stored before the cluster scheduler kills the job.

Saving on preemption is not always necessary, however. For example, if continuous checkpointing is used and checkpoints are saved frequently enough, recomputing the steps since the last checkpoint can cost less than waiting for a checkpoint to complete after preemption.
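A minimal sketch of the decision, using a hypothetical Context stand-in that models only the reached_preemption flag. The loop simulates a preemption notice arriving at step 7 and stops after the resulting save; how Orbax actually detects preemption is not modeled here.

```python
import dataclasses


@dataclasses.dataclass
class Context:
  """Hypothetical stand-in for DecisionContext; only the flag used here."""

  reached_preemption: bool


def preemption_should_save(context: Context) -> bool:
  # The step index and save history are ignored; only the flag matters.
  return context.reached_preemption


# Simulate a preemption notice that arrives at step 7.
saved_at = None
for step in range(5, 10):
  ctx = Context(reached_preemption=(step >= 7))
  if preemption_should_save(ctx):
    saved_at = step
    print(f"saving at step {step} before the job exits")
    break
```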

should_save(step, previous_steps, *, context)

Evaluates whether a preemption signal has been registered in the context.

Parameters:
  • step (PolicyCheckpointInfo) – Ignored by this policy.

  • previous_steps (Sequence[PolicyCheckpointInfo]) – Ignored by this policy.

  • context (DecisionContext) – A container for auxiliary information. This policy specifically checks the reached_preemption boolean flag to make its decision.

Returns True if a checkpoint should be saved at the given step.

Return type:

bool

SpecificStepsPolicy

class orbax.checkpoint.experimental.v1.training.save_decision_policies.SpecificStepsPolicy(steps)

Checkpoint at specific steps.

This policy evaluates to True whenever the current training step index exists within the configured steps container (i.e., step.step in steps). It makes its decision purely based on the current step index, strictly ignoring previous save history or external context.
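A minimal sketch of the membership test, using a plain integer step index in place of PolicyCheckpointInfo (specific_steps_should_save is a hypothetical name). Any container that supports the in operator works; a set gives constant-time lookups, and a range can describe a regular pattern without storing every step.

```python
from collections.abc import Container


def specific_steps_should_save(step: int, steps: Container[int]) -> bool:
  # Membership test only; history and context are ignored.
  return step in steps


milestones = {0, 1000, 5000}
print([s for s in range(0, 6000, 500) if specific_steps_should_save(s, milestones)])
# [0, 1000, 5000]
```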

steps

A collection (such as a set, list, or tuple) of step indices where checkpoints should be saved.

Type:

Container[int]

should_save(step, previous_steps, *, context)

Evaluates whether the current step index exists in the steps container.

Parameters:
  • step (PolicyCheckpointInfo) – Information about the current training step, primarily using the step index for membership testing.

  • previous_steps (Sequence[PolicyCheckpointInfo]) – Ignored by this policy.

  • context (DecisionContext) – Ignored by this policy.

Returns True if a checkpoint should be saved at the given step.

Return type:

bool

AnySavePolicy

class orbax.checkpoint.experimental.v1.training.save_decision_policies.AnySavePolicy(policies)

Evaluates all policies and saves if any of them returns True.

This policy iterates through a provided sequence of child policies. It evaluates each one in order and returns True as soon as any child policy returns True, so later child policies are not evaluated for that step. If all child policies return False, this policy returns False. It is useful for combining time-based, step-based, and event-based saving rules into a single checkpointer configuration.
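The short-circuit behavior described above can be sketched with Python's built-in any() over a generator. This is a simplification, not Orbax code: each child "policy" here is a hypothetical callable on the step index only, whereas real child policies also receive previous_steps and context. The calls list records which children were actually evaluated.

```python
from typing import Callable, List, Sequence

# Simplification: a child policy is a callable on the step index only.
ChildPolicy = Callable[[int], bool]

calls: List[str] = []  # Records which children were evaluated.


def any_should_save(step: int, policies: Sequence[ChildPolicy]) -> bool:
  # any() over a generator stops at the first child that returns True.
  return any(policy(step) for policy in policies)


def every_1000(step: int) -> bool:
  calls.append("every_1000")
  return step % 1000 == 0


def at_milestones(step: int) -> bool:
  calls.append("at_milestones")
  return step in {1500}


print(any_should_save(2000, [every_1000, at_milestones]))  # True
print(calls)  # ['every_1000']: the second child was never called
calls.clear()
print(any_should_save(1500, [every_1000, at_milestones]))  # True
print(calls)  # ['every_1000', 'at_milestones']
calls.clear()
print(any_should_save(1234, [every_1000, at_milestones]))  # False
```

Because evaluation stops early, ordering only matters when children have side effects or non-trivial cost; the overall result is the same for any order.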

policies

An ordered collection of underlying policies to evaluate.

Type:

Sequence[SaveDecisionPolicy]

should_save(step, previous_steps, *, context)

Evaluates the sequence of configured policies.

Parameters:
  • step (PolicyCheckpointInfo) – Passed down to each child policy.

  • previous_steps (Sequence[PolicyCheckpointInfo]) – Passed down to each child policy.

  • context (DecisionContext) – Passed down to each child policy.

Returns True if a checkpoint should be saved at the given step.

Return type:

bool

DecisionContext

class orbax.checkpoint.experimental.v1.training.save_decision_policies.DecisionContext(*, is_saving_in_progress, reached_preemption, multiprocessing_options)

Additional properties for making a save decision.

This dataclass is populated by the checkpointer framework and passed into the should_save method of all SaveDecisionPolicy implementations. It provides essential external system context, allowing policies to make safe, state-aware decisions.

is_saving_in_progress

Indicates whether an asynchronous checkpoint save operation is currently running in the background. Policies (like ContinuousCheckpointingPolicy) use this to avoid triggering overlapping save operations.

Type:

bool

reached_preemption

Indicates whether a preemption signal has been received from the cluster manager, meaning the training job is about to be terminated.

Type:

bool

multiprocessing_options

Configuration details for distributed multihost training. This provides information such as primary host identification for synchronization barriers.

Type:

orbax.checkpoint.options.MultiprocessingOptions

__eq__(other)

Return self==value.

__hash__ = None

__init__(*, is_saving_in_progress, reached_preemption, multiprocessing_options)