ocp.v1.training module
Public API for the training package.
Checkpointer
- final class orbax.checkpoint.experimental.v1.training.Checkpointer(directory, *, context=None, save_decision_policy=None, preservation_policy=None, step_name_format=None, custom_metadata=None, cleanup_tmp_directories=False, lightweight_initialize=False)
An object that manages a sequence of checkpoints in a training loop.
- __init__(directory, *, context=None, save_decision_policy=None, preservation_policy=None, step_name_format=None, custom_metadata=None, cleanup_tmp_directories=False, lightweight_initialize=False)
Initializes a Checkpointer.
IMPORTANT: This class is not thread-safe. Call all APIs on all available processes, from the main thread.
The Checkpointer is intended for use in a training loop, where a sequence of checkpoints is saved at regular intervals. Example usage:
```python
from orbax.checkpoint import v1 as ocp

# `directory`, `init_from_scratch`, `rng`, `abstract_model_state`,
# `train_step`, and `num_steps` are placeholders for your own code.

# Configure the frequency at which checkpoints are saved.
save_decision_policies = ocp.training.save_decision_policies
# Save every 1000 steps, or when a preemption is detected.
save_decision_policy = save_decision_policies.AnySavePolicy([
    save_decision_policies.FixedIntervalPolicy(1000),
    save_decision_policies.PreemptionPolicy(),
])

# Configure the checkpoints to preserve (avoid garbage collection).
preservation_policies = ocp.training.preservation_policies
# Avoid garbage collection on the latest 10, or every 10000 steps.
preservation_policy = preservation_policies.AnyPreservationPolicy([
    preservation_policies.LatestN(10),
    preservation_policies.EveryNSteps(10000),
])

with ocp.training.Checkpointer(
    directory,
    save_decision_policy=save_decision_policy,
    preservation_policy=preservation_policy,
) as ckptr:
    if ckptr.latest is None:
        model_state = init_from_scratch(rng)
    else:
        # Loads the latest checkpoint. `ckptr.load()` also works, but prefer
        # to specify the abstract tree if available.
        model_state = ckptr.load(
            ckptr.latest, abstract_state=abstract_model_state)

    start_step = ckptr.latest.step if ckptr.latest else 0
    for step in range(start_step, num_steps):
        model_state = train_step(model_state)
        # Saves a checkpoint if needed (according to `save_decision_policy`).
        ckptr.save(step, model_state)
```
Prefer the context-manager style shown above, which ensures that the Checkpointer is closed properly and that any outstanding async operations complete.
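The save decision policies in the example above compose with "any" semantics: a step is saved if at least one member policy asks for it. The sketch below models that idea in plain Python. `ToyFixedInterval`, `ToyPreemption`, and `ToyAnySavePolicy` are hypothetical stand-ins, not Orbax classes; the real SaveDecisionPolicy interface receives richer inputs than a step number and a preemption flag, and the modulo rule for fixed intervals is an assumption.

```python
from dataclasses import dataclass
from typing import Protocol, Sequence


class ToySavePolicy(Protocol):
    """Hypothetical policy interface: decide from a step and a preemption flag."""

    def should_save(self, step: int, preempted: bool) -> bool: ...


@dataclass
class ToyFixedInterval:
    interval: int

    def should_save(self, step: int, preempted: bool) -> bool:
        # Assumed rule: save on every multiple of `interval`.
        return step % self.interval == 0


@dataclass
class ToyPreemption:
    def should_save(self, step: int, preempted: bool) -> bool:
        # Save whenever a preemption signal has been observed.
        return preempted


@dataclass
class ToyAnySavePolicy:
    policies: Sequence[ToySavePolicy]

    def should_save(self, step: int, preempted: bool) -> bool:
        # "Any" composition: one positive vote is enough.
        return any(p.should_save(step, preempted) for p in self.policies)


policy = ToyAnySavePolicy([ToyFixedInterval(1000), ToyPreemption()])
# Simulate a preemption signal arriving at step 1234.
saved_steps = [
    step for step in range(3001) if policy.should_save(step, preempted=(step == 1234))
]
print(saved_steps)  # [0, 1000, 1234, 2000, 3000]
```

The preemption-triggered save at step 1234 lands between the regular interval saves, which is the behavior the "Save every 1000 steps, or when a preemption is detected" comment describes.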
- Parameters:
directory (path_types.PathLike) – The root directory where checkpoints are stored. The directory will be created if it does not exist.
context (context_lib.Context | None) – A Context object that will be used to wrap all function calls for this Checkpointer.
save_decision_policy (save_decision_policies.SaveDecisionPolicy | None) – A policy used to determine when a checkpoint should be saved. If not provided, the Checkpointer saves as often as possible by default (assuming no checkpoint is currently being saved), and saves when a preemption is detected by the JAX distributed system.
preservation_policy (preservation_policies.PreservationPolicy | None) – A policy used to determine when a checkpoint should be preserved. Any checkpoints not preserved are garbage collected. If not provided,
step_name_format (path_step_lib.NameFormat[CheckpointMetadata[None]] | None) – An object used to specify the format for step paths. By default, steps are rendered as simple integers, like /root/directory/<step>.
custom_metadata (tree_types.JsonType | None) – A JSON dictionary representing user-specified custom metadata. This should be information that is relevant to the entire sequence of checkpoints, rather than to any single checkpoint.
cleanup_tmp_directories (bool) – If True, cleans up any existing temporary directories on Checkpointer creation.
lightweight_initialize (bool) – If True, checkpoint step metadata is not read on Checkpointer initialization during checkpoint info loading. This is useful to improve init performance when there are O(1k) or more existing checkpoint steps present and checkpoint info properties like time and metrics are not needed.
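The preservation policy in the class example keeps a checkpoint if any member policy wants it kept; everything else becomes eligible for garbage collection. A minimal sketch of that union, assuming LatestN keeps the N highest steps and EveryNSteps keeps steps divisible by N. `toy_preserved_steps` is hypothetical and ignores details such as in-progress saves.

```python
def toy_preserved_steps(steps, latest_n, every_n):
    """Returns the steps kept by LatestN(latest_n) OR EveryNSteps(every_n)."""
    ordered = sorted(steps)
    # LatestN: the `latest_n` highest steps (guard against latest_n == 0,
    # since ordered[-0:] would select everything).
    keep = set(ordered[-latest_n:]) if latest_n > 0 else set()
    # EveryNSteps: steps that fall on a multiple of `every_n`.
    keep |= {s for s in ordered if s % every_n == 0}
    return sorted(keep)


all_steps = list(range(0, 25001, 1000))  # 0, 1000, ..., 25000 (26 steps)
kept = toy_preserved_steps(all_steps, latest_n=10, every_n=10000)
# kept: 0 and 10000 (EveryNSteps) plus 16000..25000 (LatestN, which covers 20000).
collected = sorted(set(all_steps) - set(kept))
# collected: the 14 steps that neither policy protects.
```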
- property latest: CheckpointMetadata[None] | None
Returns the latest CheckpointMetadata, or None if no checkpoints exist. See the checkpoints documentation below.
- Return type:
CheckpointMetadata[None] | None
- Returns:
The latest checkpoint, or None if no checkpoints exist.
- property checkpoints: Sequence[CheckpointMetadata[None]]
Returns a list of CheckpointMetadata objects, sorted ascending by step.
Each object contains selected properties describing the checkpoint. Contrast this with the methods metadata() and checkpointables_metadata(), which may perform a more expensive disk read to retrieve additional information. This property only returns cheap, cacheable properties like step and timestamp. The return value is annotated as CheckpointMetadata[None] because the core metadata property is not retrieved, and is therefore None.
The property is cached to avoid repeated disk reads. This is not a problem unless checkpoints are deleted manually, or by some other job or class that the Checkpointer is unaware of, which is discouraged. In that case, call reload() to rescan the root directory.
- Return type:
Sequence[CheckpointMetadata[None]]
- Returns:
A list of checkpoints, sorted ascending by step.
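Because the default step format renders steps as plain integers under the root directory, the ordering and latest contracts can be illustrated with a directory scan. The helpers below are hypothetical and only approximate the listing; the real Checkpointer also tracks timestamps and caches results. Note the numeric sort: a lexical sort of names would place "1000" before "200".

```python
import pathlib
import tempfile


def toy_list_steps(root: pathlib.Path) -> list:
    """Collects integer-named step directories, sorted ascending by step."""
    return sorted(
        int(p.name) for p in root.iterdir() if p.is_dir() and p.name.isdigit()
    )


def toy_latest(root: pathlib.Path):
    """Mirrors the `latest` contract: the highest step, or None if empty."""
    steps = toy_list_steps(root)
    return steps[-1] if steps else None


with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    latest_before = toy_latest(root)  # None: no checkpoints yet
    for step in (200, 5, 1000):
        (root / str(step)).mkdir()
    (root / "notes.txt").write_text("not a checkpoint")  # ignored: a file
    (root / "scratch").mkdir()  # ignored: not an integer name
    steps = toy_list_steps(root)  # [5, 200, 1000], numeric order
    latest = toy_latest(root)  # 1000
```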
- should_save(step)
Returns whether a checkpoint should be saved at the given step.
- Return type:
bool
- save(step, state, *, checkpointable_name='state', force=False, overwrite=False, metrics=None, custom_metadata=None)
Saves a checkpoint, if dictated by the SaveDecisionPolicy.
This method behaves similarly to the standalone free function save() (see its documentation), but performs additional tasks related to managing a sequence of checkpoint steps.
It consists roughly of the following steps:
- Check whether a checkpoint should be saved at the given step.
- Check whether a save is already in progress. If so, wait for it to finish.
- Save to a directory given by root_directory / <step_format>.
- Perform garbage collection if necessary.
- Return whether a checkpoint was saved or not.
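The five steps above can be sketched with a background thread and a join, assuming a fixed-interval save policy and a keep-latest-N garbage collector. `ToySequenceSaver` is hypothetical and far simpler than Orbax (no multi-process coordination, no temporary directories or atomic commits); it only shows how "decide, wait for the in-flight save, write, collect, report" fit together.

```python
import pathlib
import shutil
import tempfile
import threading


class ToySequenceSaver:
    """Hypothetical stand-in that walks through the five save() steps."""

    def __init__(self, root, interval, keep_latest):
        self.root = pathlib.Path(root)
        self.interval = interval
        self.keep_latest = keep_latest
        self._in_flight = None  # Thread running the current background save.

    def _write_then_collect(self, step, payload):
        # Step 3: write the checkpoint to root / <step>.
        step_dir = self.root / str(step)
        step_dir.mkdir(parents=True)  # Raises FileExistsError if the step exists.
        (step_dir / "state.txt").write_text(payload)
        # Step 4: garbage-collect all but the newest `keep_latest` steps.
        steps = sorted(int(p.name) for p in self.root.iterdir() if p.name.isdigit())
        for old in steps[: max(len(steps) - self.keep_latest, 0)]:
            shutil.rmtree(self.root / str(old))

    def save(self, step, payload, force=False):
        # Step 1: consult a (toy) fixed-interval save decision policy.
        if not force and step % self.interval != 0:
            return False
        # Step 2: at most one save at a time, so wait for the previous one.
        self.wait_until_finished()
        self._in_flight = threading.Thread(
            target=self._write_then_collect, args=(step, payload)
        )
        self._in_flight.start()
        # Step 5: report that a save was initiated.
        return True

    def wait_until_finished(self):
        if self._in_flight is not None:
            self._in_flight.join()
            self._in_flight = None


with tempfile.TemporaryDirectory() as tmp:
    saver = ToySequenceSaver(tmp, interval=10, keep_latest=2)
    decisions = [saver.save(step, f"state@{step}") for step in range(41)]
    saver.wait_until_finished()
    remaining = sorted(int(p.name) for p in pathlib.Path(tmp).iterdir())
```

Because each new save joins the previous thread before starting, writes never overlap, which is the one-save-at-a-time guarantee described for the Checkpointer.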
The Checkpointer never saves more than one checkpoint at a time. Depending on the SaveDecisionPolicy, a checkpoint may be saved or skipped at a given step. If the policy initiates a save, it proceeds as normal as long as no other save is in progress; if a save is already in progress, the method blocks until the previous save has finished.
- Example usage:
Basic Usage: Save a PyTree at a specific training step. The checkpointer automatically manages the step-based directory structure inside your root folder:
```python
from orbax.checkpoint.v1 import training

# Initialize the checkpointer for a directory.
ckptr = training.Checkpointer(directory)
# Save the tree at step 0.
saved = ckptr.save(step=0, state=tree)
# Clean up background threads gracefully when the training loop ends.
ckptr.close()
```
Advanced Saving with Metrics and Metadata: Attach JSON-serializable metrics (like loss/accuracy) and custom metadata to a specific step for thorough experiment tracking:
```python
from orbax.checkpoint.v1 import training

ckptr = training.Checkpointer(directory)
ckptr.save(
    step=1,
    state=tree,
    metrics={'loss': 0.12, 'accuracy': 0.95},
    custom_metadata={'description': 'Model after epoch 1'},
)
ckptr.close()
```
- Parameters:
step (int) – The step number to save.
state (tree_types.PyTreeOf[tree_types.Leaf]) – The PyTree to save.
checkpointable_name (str) – The name of the checkpointable to save a pytree under. Defaults to 'state'.
force (bool) – If True, ignores all SaveDecisionPolicy checks, and always decides to save a checkpoint.
overwrite (bool) – If True, deletes any existing checkpoint at the given step before saving. Otherwise, raises an error if the checkpoint already exists.
metrics (tree_types.JsonType | None) – A JSON-serializable PyTree of metrics to be saved with the checkpoint.
custom_metadata (tree_types.JsonType | None) – A JSON dictionary representing user-specified custom metadata. This should be information that is relevant to the checkpoint at the given step, rather than to the entire sequence of checkpoints.
- Return type:
bool
- Returns:
Whether a checkpoint was saved or not.
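Because metrics and custom_metadata are typed as JSON, a cheap pre-flight check with Python's json module can surface problems before calling save(). The helper below is hypothetical, not part of Orbax, and whether Orbax converts array scalars on its own is not stated here. A common pitfall: json.dumps rejects NumPy float32 scalars and JAX arrays, so convert such values with float() first.

```python
import json


def toy_check_json(value, what="metrics"):
    """Returns the JSON encoding of `value`, or raises if it cannot be encoded."""
    try:
        return json.dumps(value)
    except TypeError as err:
        raise ValueError(f"{what} must be JSON-serializable: {err}") from err


encoded = toy_check_json({'loss': 0.12, 'accuracy': 0.95})

try:
    toy_check_json({'seen_classes': {1, 2, 3}})  # A set has no JSON form.
    rejected = False
except ValueError:
    rejected = True
```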
- save_checkpointables(step, checkpointables, *, force=False, overwrite=False, metrics=None, custom_metadata=None)
Saves a dictionary of checkpointable objects at the given step.
This method saves a dictionary of checkpointable objects, mapping string names to values. See the guide on Checkpointables for more details on checkpointables. Also see the documentation for save().
Example
Basic Usage: Save multiple named items (checkpointables) at a specific step. The dictionary keys define the names of the saved components:
```python
from orbax.checkpoint.v1 import training

# Initialize the checkpointer for a directory.
ckptr = training.Checkpointer(directory)
# Save multiple items, such as model weights and optimizer state.
items_to_save = {
    'model': my_model_state,
    'optimizer': my_opt_state,
}
saved = ckptr.save_checkpointables(step=0, checkpointables=items_to_save)
# Clean up background threads gracefully when the training loop ends.
ckptr.close()
```
Advanced Saving with Metrics and Metadata: Attach JSON-serializable metrics and custom metadata to a specific step for thorough experiment tracking:
```python
from orbax.checkpoint.v1 import training

ckptr = training.Checkpointer(directory)
items_to_save = {'model': my_model_state}
ckptr.save_checkpointables(
    step=1,
    checkpointables=items_to_save,
    metrics={'loss': 0.12, 'accuracy': 0.95},
    custom_metadata={'description': 'Model after epoch 1'},
)
ckptr.close()
```
- Parameters:
step (int) – The step number to save.
checkpointables (dict[str, Any]) – A dictionary mapping string names to the corresponding objects (checkpointables) that need to be saved.
force (bool) – If True, ignores all policy checks and always decides to save a checkpoint.
overwrite (bool) – If True, deletes any existing checkpoint at the given step before saving. Otherwise, raises an error if the checkpoint already exists.
metrics (tree_types.JsonType | None) – A dictionary of metrics to be saved with the checkpoint. Must be JSON-serializable.
custom_metadata (tree_types.JsonType | None) – A JSON dictionary representing user-specified custom metadata relevant to the checkpoint at this specific step.
- Returns:
True if the checkpoint was successfully saved, False otherwise.
- Return type:
bool
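One way to picture the relationship between the two save methods: save() stores a single PyTree under one checkpointable name (checkpointable_name, 'state' by default), whereas save_checkpointables() stores several named objects at once. The in-memory `ToyRecorder` below is hypothetical and models only that naming relationship, not Orbax's on-disk layout or its actual implementation.

```python
from typing import Any, Dict


class ToyRecorder:
    """Hypothetical in-memory recorder for save() vs. save_checkpointables()."""

    def __init__(self):
        self.saved: Dict[int, Dict[str, Any]] = {}

    def save_checkpointables(self, step: int, checkpointables: Dict[str, Any]) -> bool:
        # Checkpointables are keyed by string names.
        if not all(isinstance(name, str) for name in checkpointables):
            raise TypeError("checkpointable names must be strings")
        self.saved[step] = dict(checkpointables)
        return True

    def save(self, step: int, state: Any, *, checkpointable_name: str = "state") -> bool:
        # Conceptually, a single PyTree is one checkpointable under one name.
        return self.save_checkpointables(step, {checkpointable_name: state})


recorder = ToyRecorder()
recorder.save(0, {'w': [1.0, 2.0]})
recorder.save_checkpointables(
    1, {'model': {'w': [1.5, 2.5]}, 'optimizer': {'m': [0.1, 0.2]}}
)
```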
- save_async(step, state, *, checkpointable_name='state', force=False, overwrite=False, metrics=None, custom_metadata=None)
Saves a checkpoint asynchronously.
This function is the asynchronous equivalent of save(). It accepts exactly the same arguments; refer to that method for detailed descriptions.
This method executes mostly in the background, blocking the main thread for as little time as possible.
Example
```python
async_response = ckptr.save_async(step=0, state=tree)
# Other work can run here while the save proceeds in the background.
saved = async_response.result()
```
- Parameters:
step (int) – The step number to save.
state (tree_types.PyTreeOf[tree_types.Leaf]) – The PyTree to save.
checkpointable_name (str) – The name of the checkpointable to save a pytree under. Defaults to 'state'.
force (bool) – See save.
overwrite (bool) – See save.
metrics (tree_types.JsonType | None) – See save.
custom_metadata (tree_types.JsonType | None) – See save.
- Return type:
async_types.AsyncResponse[bool]
- Returns:
An AsyncResponse whose result() method blocks until the save completes and then returns a bool indicating whether a checkpoint was saved.
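The AsyncResponse contract (start now, call result() later to block for the outcome) follows the same pattern as a standard-library Future. The sketch below uses concurrent.futures purely as an analogy; it is not how Orbax implements save_async(), and `toy_background_save` is hypothetical.

```python
import concurrent.futures
import time


def toy_background_save(step: int) -> bool:
    time.sleep(0.05)  # Stand-in for writing arrays to storage.
    return True


with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(toy_background_save, 0)  # Returns immediately.
    # Training can continue here while the save runs in the background.
    meanwhile = sum(range(1000))
    saved = future.result()  # Blocks until the save finishes, then returns its bool.
```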
- save_checkpointables_async(step, checkpointables, *, force=False, overwrite=False, metrics=None, custom_metadata=None)
Saves checkpointable objects asynchronously.
This function is the asynchronous equivalent of save_checkpointables(). Refer to that method for detailed instructions and argument descriptions.
Example
Save checkpointable objects asynchronously:
```python
async_response = ckptr.save_checkpointables_async(
    step=0, checkpointables=items_to_save
)
saved = async_response.result()
```
- Parameters:
step (int) – The step number to save.
checkpointables (dict[str, Any]) – A dictionary mapping string names to objects to save.
force (bool) – See save_checkpointables.
overwrite (bool) – See save_checkpointables.
metrics (tree_types.JsonType | None) – See save_checkpointables.
custom_metadata (tree_types.JsonType | None) – See save_checkpointables.
- Return type:
async_types.AsyncResponse[bool]
- Returns:
An object representing the background operation. Call .result() on it to block and return a boolean indicating whether the checkpoint was successfully saved.
- Raises:
StepAlreadyExistsError – If overwrite is False and a checkpoint at the target step already exists.
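The overwrite flag shared by the save methods can be illustrated as follows: an existing step directory is an error unless overwrite=True, in which case the old checkpoint is deleted before saving. Both the exception class and the helper are hypothetical local stand-ins; the documented Orbax error is StepAlreadyExistsError.

```python
import pathlib
import shutil
import tempfile


class ToyStepAlreadyExistsError(Exception):
    """Hypothetical stand-in for StepAlreadyExistsError."""


def toy_prepare_step_dir(root, step, overwrite=False):
    step_dir = pathlib.Path(root) / str(step)
    if step_dir.exists():
        if not overwrite:
            raise ToyStepAlreadyExistsError(f"step {step} already exists: {step_dir}")
        shutil.rmtree(step_dir)  # overwrite=True: delete the old checkpoint first.
    step_dir.mkdir(parents=True)
    return step_dir


with tempfile.TemporaryDirectory() as tmp:
    first = toy_prepare_step_dir(tmp, 5)
    (first / "old.txt").write_text("stale")
    try:
        toy_prepare_step_dir(tmp, 5)  # Same step, overwrite=False.
        raised = False
    except ToyStepAlreadyExistsError:
        raised = True
    second = toy_prepare_step_dir(tmp, 5, overwrite=True)
    old_file_gone = not (second / "old.txt").exists()
```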
- load(step=None, abstract_state=None, *, checkpointable_name='state')
Loads a PyTree checkpoint at the given step.
This method behaves similarly to the standalone free function load().
Note: Loading a PyTree without providing an abstract_state is provided purely for convenience. For production use cases, it is STRONGLY recommended to always provide an abstract_state so that the restored PyTree matches the expected shapes, dtypes, and sharding.
Example
Basic Loading: Load a PyTree without providing an abstract structure. Passing step=None (or omitting it) loads the latest step:
```python
from orbax.checkpoint.v1 import training

# Initialize the checkpointer for the directory.
ckptr = training.Checkpointer(directory)
# Load the saved PyTree from the latest step.
restored_tree = ckptr.load(step=None)
```
Loading with an Abstract PyTree: Provide an abstract structure (such as target shapes and dtypes) to ensure the restored PyTree is safely and correctly formatted:
```python
import jax
import jax.numpy as jnp
from orbax.checkpoint.v1 import training

ckptr = training.Checkpointer(directory)
# Define the expected structure (shapes and dtypes) to restore into.
target_structure = {
    'weights': jax.ShapeDtypeStruct((128, 128), dtype=jnp.float32),
    'bias': jax.ShapeDtypeStruct((128,), dtype=jnp.float32),
}
# Restore, matching the target structure exactly.
restored_tree = ckptr.load(step=1, abstract_state=target_structure)
```
- Parameters:
step (int | CheckpointMetadata | None) – The step number or CheckpointMetadata to load. If None, the checkpointer will attempt to resolve and load the latest existing checkpoint.
abstract_state (PyTreeOf[AbstractArray | AbstractShardedArray | int | float | number | bytes | bool | str] | None) – The abstract PyTree to load.
checkpointable_name (str) – The name of the checkpointable to load a pytree under. Defaults to 'state'.
- Return type:
PyTreeOf[Array | ndarray | int | float | number | bytes | bool | str]
- Returns:
The loaded PyTree.
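To see which kinds of mismatch an abstract tree pins down, the sketch below compares a restored tree against abstract leaves, key by key and leaf by leaf. It is only a post-hoc check: Orbax uses abstract_state to guide restoration itself (including sharding, which this ignores). `ToyShapeDtype` is a hypothetical stand-in; a jax.ShapeDtypeStruct also exposes .shape and .dtype and would work in its place. Requires NumPy.

```python
from typing import Any, NamedTuple

import numpy as np


class ToyShapeDtype(NamedTuple):
    """Hypothetical stand-in for jax.ShapeDtypeStruct (only .shape and .dtype)."""
    shape: tuple
    dtype: Any


def toy_check_against_abstract(restored, abstract, path="state"):
    """Raises ValueError if `restored` deviates from `abstract` in keys, shape, or dtype."""
    if isinstance(abstract, dict):
        if set(restored) != set(abstract):
            raise ValueError(f"{path}: keys {sorted(restored)} != {sorted(abstract)}")
        for key, sub_abstract in abstract.items():
            toy_check_against_abstract(restored[key], sub_abstract, f"{path}.{key}")
        return
    if tuple(restored.shape) != tuple(abstract.shape):
        raise ValueError(f"{path}: shape {restored.shape} != {abstract.shape}")
    got, want = np.dtype(restored.dtype), np.dtype(abstract.dtype)
    if got != want:
        raise ValueError(f"{path}: dtype {got} != {want}")


target = {
    'weights': ToyShapeDtype((128, 128), np.float32),
    'bias': ToyShapeDtype((128,), np.float32),
}
good = {'weights': np.zeros((128, 128), np.float32), 'bias': np.zeros((128,), np.float32)}
bad = {'weights': np.zeros((128, 128), np.float64), 'bias': np.zeros((128,), np.float32)}

toy_check_against_abstract(good, target)  # Passes silently.
try:
    toy_check_against_abstract(bad, target)
    bad_rejected = False
except ValueError as err:
    bad_rejected = True
    message = str(err)
```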
- load_checkpointables(step=None, abstract_checkpointables=None)
Loads a set of checkpointables at the given step.
This method behaves similarly to the standalone free function load_checkpointables().
This function retrieves multiple named items (such as model weights or optimizer states) from a specific checkpoint directory. If no step is provided, it automatically resolves to and loads the most recently saved checkpoint.
Note: Loading without providing an abstract_checkpointables dictionary is provided purely for convenience. For production use cases, it is STRONGLY recommended to always provide abstract_checkpointables so that the restored items match the expected nested structures, shapes, and data types.
Example
Basic Loading: Load multiple named items (such as a model and optimizer) from a specific step. If step is omitted, it resolves to the latest available checkpoint:
```python
from orbax.checkpoint.v1 import training

# Initialize the checkpointer for the directory.
ckptr = training.Checkpointer(directory)
# Load all checkpointables saved at the latest step.
restored_items = ckptr.load_checkpointables(step=None)
# Access the individual components by their original string keys.
my_model = restored_items["model"]
my_opt = restored_items["optimizer"]
```
Loading with Abstract Checkpointables (Recommended): Provide a dictionary of abstract structures to ensure the restored items strictly match your expected shapes and data types:
```python
import jax
import jax.numpy as jnp
from orbax.checkpoint.v1 import training

ckptr = training.Checkpointer(directory)
# Define the expected structure for each named item using JAX arrays.
target_items = {
    "model": {
        'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
        'bias': jax.ShapeDtypeStruct((128,), jnp.float32),
    },
    "optimizer": {
        'momentum': jax.ShapeDtypeStruct((128, 128), jnp.float32),
    },
}
# Restore, matching the target structures exactly.
restored_items = ckptr.load_checkpointables(
    step=1, abstract_checkpointables=target_items
)
```
Partial Loading: If you only need to load a subset of checkpointables (e.g., loading model weights but omitting optimizer state), you can provide an abstract_checkpointables dictionary containing only the keys for the items you wish to restore:
```python
import jax
import jax.numpy as jnp
from orbax.checkpoint.v1 import training

ckptr = training.Checkpointer(directory)
# Define the abstract structure for ONLY the items to load.
target_items = {
    "model": {
        'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
        'bias': jax.ShapeDtypeStruct((128,), jnp.float32),
    },
}
# Load only "model", omitting "optimizer".
restored_items = ckptr.load_checkpointables(
    step=1, abstract_checkpointables=target_items
)
my_model = restored_items["model"]
# restored_items has no "optimizer" key, because it was not requested.
```
- Parameters:
step (int | CheckpointMetadata | None) – The step number or CheckpointMetadata to load. If None, the checkpointer will attempt to resolve and load the latest existing checkpoint.
abstract_checkpointables (dict[str, Any] | None) – A dictionary mapping string names to their corresponding abstract structures (e.g., target PyTrees). This guides the loading process to ensure shape and type compliance. If provided, it can be used to load only a subset of checkpointables by providing only a subset of keys.
- Returns:
A dictionary containing the loaded checkpointable objects, keyed by string names. If abstract_checkpointables was specified, returns only the keys specified in that dict; otherwise, returns all keys saved with save_checkpointables().
- Return type:
dict[str, Any]
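The return-value rule above (all saved keys by default, only the requested keys when abstract_checkpointables is given) can be modeled with plain dictionaries. `toy_select_checkpointables` is hypothetical; the abstract values are placeholders here, and raising KeyError for a name that was never saved is an assumption rather than documented Orbax behavior.

```python
def toy_select_checkpointables(saved, abstract_checkpointables=None):
    """Returns every saved item, or only the names requested in the abstract dict."""
    if abstract_checkpointables is None:
        return dict(saved)
    missing = set(abstract_checkpointables) - set(saved)
    if missing:
        # Assumption: requesting an unknown name is treated as an error.
        raise KeyError(f"not present in checkpoint: {sorted(missing)}")
    return {name: saved[name] for name in abstract_checkpointables}


saved = {'model': {'w': [1.0]}, 'optimizer': {'m': [0.0]}}
everything = toy_select_checkpointables(saved)
model_only = toy_select_checkpointables(saved, {'model': {'w': None}})
```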
- load_async(step=None, abstract_state=None)
Not yet supported.
- Return type:
AsyncResponse[PyTreeOf[UnionType[Array,ndarray,int,float,number,bytes,bool,str]]]
- load_checkpointables_async(step=None, abstract_checkpointables=None)
Loads a set of checkpointables asynchronously at the given step.
- Return type:
AsyncResponse[dict[str,Any]]
- metadata(step=None)
Returns checkpoint metadata for the given step.
Retrieves metadata describing the structure of the PyTree stored at the given step. If no step is provided, the method resolves to the latest available checkpoint.
- Parameters:
step (int | CheckpointMetadata | None) – The step number to retrieve metadata for. If None, the latest step is used. Can also be a CheckpointMetadata object, from which the step is extracted.
- Return type:
CheckpointMetadata[PyTreeOf[AbstractArray | AbstractShardedArray | int | float | number | bytes | bool | str]]
- Returns:
A CheckpointMetadata object containing PyTreeMetadata, along with checkpoint timestamp and metrics information.
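Several methods accept step as int | CheckpointMetadata | None. The resolution rule they describe can be sketched as follows. `ToyCheckpointMetadata` and `toy_resolve_step` are hypothetical, and the FileNotFoundError for an absent step is an assumption, not documented Orbax behavior.

```python
from dataclasses import dataclass
from typing import Sequence, Union


@dataclass(frozen=True)
class ToyCheckpointMetadata:
    """Hypothetical stand-in carrying only the step number."""
    step: int


def toy_resolve_step(
    step: Union[int, ToyCheckpointMetadata, None], available: Sequence[int]
) -> int:
    """Resolves None to the latest step and unwraps metadata objects to their step."""
    if step is None:
        if not available:
            raise FileNotFoundError("no checkpoints exist")
        return max(available)
    resolved = step.step if isinstance(step, ToyCheckpointMetadata) else int(step)
    if resolved not in available:
        raise FileNotFoundError(f"no checkpoint at step {resolved}")
    return resolved


available = [100, 300, 200]
latest = toy_resolve_step(None, available)
from_meta = toy_resolve_step(ToyCheckpointMetadata(step=200), available)
explicit = toy_resolve_step(100, available)
```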
- checkpointables_metadata(step=None)
Returns checkpointables metadata for the given step.
Retrieves metadata describing the structure of the checkpointables stored at the given step. If no step is provided, the method resolves to the latest available checkpoint.
- Parameters:
step (int | CheckpointMetadata | None) – The step number to retrieve metadata for. If None, the latest step is used. Can also be a CheckpointMetadata object, from which the step is extracted.
- Return type:
CheckpointMetadata[dict[str, Any]]
- Returns:
A CheckpointMetadata object containing a dict[str, Any] describing the checkpointables, along with checkpoint timestamp and metrics information.
- reload()
Reloads internal properties from the root directory.
Updates the list of available checkpoints by rescanning the storage location. Use this method to sync the checkpointer with the file system if checkpoints have been added or removed externally.
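Why an explicit reload() is needed follows from the caching described for the checkpoints property: a cached listing does not notice external changes until it is rescanned. The class below is a hypothetical model of that behavior, not the Checkpointer's implementation.

```python
import pathlib
import shutil
import tempfile


class ToyCachedSteps:
    """Hypothetical cache of step directories, refreshed only by reload()."""

    def __init__(self, root):
        self.root = pathlib.Path(root)
        self.reload()

    def reload(self):
        # Rescan the storage location; this is the only place the cache changes.
        self._steps = sorted(
            int(p.name) for p in self.root.iterdir() if p.is_dir() and p.name.isdigit()
        )

    @property
    def steps(self):
        return list(self._steps)


with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    for step in (1, 2):
        (root / str(step)).mkdir()
    cache = ToyCachedSteps(root)
    shutil.rmtree(root / "2")  # Another job deletes a checkpoint...
    (root / "3").mkdir()  # ...and writes a new one.
    before_reload = cache.steps  # Still the stale view: [1, 2]
    cache.reload()
    after_reload = cache.steps  # [1, 3]
```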
- is_saving_in_progress()
Returns whether a checkpoint save operation is currently in progress.
Checks if there are any background persistence operations currently active.
- Return type:
bool
- Returns:
True if a save operation is in progress, False otherwise.
CheckpointMetadata
- final class orbax.checkpoint.experimental.v1.training.CheckpointMetadata(step, path, *, metadata, init_timestamp_nsecs=None, commit_timestamp_nsecs=None, custom_metadata=None, metrics=None)
Represents metadata for a single checkpoint (corresponding to a step).
Like its parent, the class has a metadata attribute with a generic type. The .metadata attribute contains checkpointable-specific metadata: if a PyTree was saved, it contains PyTreeMetadata; if Checkpointables were saved, it is a dictionary mapping names to metadata.
The Orbax checkpointing API provides two symmetric levels of interaction:
- Higher level (sequence-of-steps API): accessed via Checkpointer.
- Lower level (individual path API): accessed via free functions.
Both API levels return CheckpointMetadata objects through the same core methods (metadata() and checkpointables_metadata()), reflecting this symmetry.
See the superclass documentation for more information, and for a list of base attributes. This class defines several additional attributes that are relevant to checkpoints in a sequence, but not necessarily to a single checkpoint in isolation.
Example Usage:
```python
from orbax.checkpoint import v1 as ocp

# Higher level (sequence-of-steps API)
with ocp.training.Checkpointer('/path/to/my/checkpoints') as ckptr:
    ckpt_meta = ckptr.metadata(100)

# Lower level (individual path API)
ckpt_meta = ocp.metadata('/path/to/my/checkpoints/100')

# Inspect checkpoint-level properties
print(f'Init time (ns): {ckpt_meta.init_timestamp_nsecs}')
print(f'Commit time (ns): {ckpt_meta.commit_timestamp_nsecs}')
print(f'Custom metadata: {ckpt_meta.custom_metadata}')

# The `.metadata` field contains checkpointable-specific metadata,
# which will be `PyTreeMetadata` or dict[str, CheckpointableMetadata]
# depending on what was saved.
print(f'Checkpointable metadata: {ckpt_meta.metadata}')
```
See also RootMetadata. See the parent class, CheckpointMetadata, for base attributes.
- Additional Attributes:
step: The step number of the checkpoint.
metrics: An optional dictionary containing user-provided metrics saved alongside the checkpoint.
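init_timestamp_nsecs and commit_timestamp_nsecs are nanosecond counts. Assuming they are measured from the Unix epoch (the reference point is not stated here), a small helper can turn them into a readable start time and the elapsed time between initialization and commit. `toy_commit_summary` is hypothetical; in practice you would pass ckpt_meta.init_timestamp_nsecs and ckpt_meta.commit_timestamp_nsecs.

```python
import datetime
from typing import Optional


def toy_commit_summary(init_ns: Optional[int], commit_ns: Optional[int]):
    """Converts nanosecond timestamps into a UTC start time and elapsed seconds."""
    if init_ns is None or commit_ns is None:
        return None  # Both fields default to None on CheckpointMetadata.
    started = datetime.datetime.fromtimestamp(
        init_ns / 1_000_000_000, tz=datetime.timezone.utc
    )
    elapsed_s = (commit_ns - init_ns) / 1_000_000_000
    return started, elapsed_s


init_ns = 1_700_000_000_000_000_000
commit_ns = init_ns + 2_500_000_000  # Committed 2.5 s later.
started, elapsed_s = toy_commit_summary(init_ns, commit_ns)
missing = toy_commit_summary(init_ns, None)
```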
RootMetadata
- final class orbax.checkpoint.experimental.v1.training.RootMetadata(*, directory, custom_metadata=None)
Metadata of a sequence of checkpoints at the root level (covering all steps).
This class represents the top-level metadata for an entire checkpointing directory, distinct from step-specific metadata. It associates the physical storage location of the sequence with arbitrary, user-defined information that applies to all steps (e.g., experiment configuration).
- Example Usage:
RootMetadata objects are returned by root_metadata(). They can be used to inspect checkpoint-wide information, such as experiment configuration:
```python
import orbax.checkpoint.v1 as ocp

ckptr = ocp.training.Checkpointer('/path/to/my/checkpoints')
root_meta = ckptr.root_metadata()
print(f'Directory: {root_meta.directory}')
print(f'Custom metadata: {root_meta.custom_metadata}')
# Close the checkpointer when done.
ckptr.close()
```
See also CheckpointMetadata.
- directory
The directory of the checkpoint sequence.
- Type:
path_types.Path
- custom_metadata
User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.
- Type:
tree_types.JsonType | None
- __delattr__(name)
Implement delattr(self, name).
- __eq__(other)
Return self==value.
- __hash__()
Return hash(self).
- __init__(*, directory, custom_metadata=None)
- __setattr__(name, value)
Implement setattr(self, name, value).