ocp.v1.training module

Public API for training package.

Checkpointer

final class orbax.checkpoint.experimental.v1.training.Checkpointer(directory, *, context=None, save_decision_policy=None, preservation_policy=None, step_name_format=None, custom_metadata=None, cleanup_tmp_directories=False, lightweight_initialize=False)

An object that manages a sequence of checkpoints in a training loop.

__init__(directory, *, context=None, save_decision_policy=None, preservation_policy=None, step_name_format=None, custom_metadata=None, cleanup_tmp_directories=False, lightweight_initialize=False)

Initializes a Checkpointer.

IMPORTANT: This class is not thread safe. All APIs should be called across all available processes, from the main thread.

The Checkpointer is intended for use in a training loop, where a sequence of checkpoints is saved at regular intervals. Example usage:

from orbax.checkpoint import v1 as ocp

# Configure the frequency at which checkpoints are saved.
save_decision_policies = ocp.training.save_decision_policies
# Save every 1000 steps, or when a preemption is detected.
save_decision_policy = save_decision_policies.AnySavePolicy([
    save_decision_policies.FixedIntervalPolicy(1000),
    save_decision_policies.PreemptionPolicy(),
])

# Configure the checkpoints to preserve (avoid garbage collection).
preservation_policies = ocp.training.preservation_policies
# Avoid garbage collection on the latest 10, or every 10000 steps.
preservation_policy = preservation_policies.AnyPreservationPolicy([
    preservation_policies.LatestN(10),
    preservation_policies.EveryNSteps(10000),
])

with ocp.training.Checkpointer(
    directory,
    save_decision_policy=save_decision_policy,
    preservation_policy=preservation_policy,
) as ckptr:
  if ckptr.latest is None:
    model_state = init_from_scratch(rng)
  else:
    # Load the latest checkpoint. Prefer passing the abstract tree if it is
    # available; `ckptr.load()` with no arguments also loads the latest step.
    model_state = ckptr.load(
        ckptr.latest, abstract_state=abstract_model_state)
  start_step = ckptr.latest.step if ckptr.latest else 0
  for step in range(start_step, num_steps):
    model_state = train_step(model_state)
    # Saves a checkpoint if needed (according to `save_decision_policy`).
    ckptr.save(step, model_state)

Prefer to use the context manager style as shown above, which ensures that the Checkpointer is closed properly and any outstanding async operations are completed.
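
The guarantee described above follows the usual context-manager pattern. The sketch below is a toy stand-in, not Orbax code: `ToyCheckpointer` and its methods are hypothetical names, and a single-worker thread pool plays the role of the background save machinery. It shows only that leaving the `with` block calls `close()`, which blocks until pending work has finished.

```python
import concurrent.futures
import time


class ToyCheckpointer:
  """Hypothetical stand-in that mimics close-on-exit semantics."""

  def __init__(self):
    self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    self._pending = []
    self.completed_steps = []

  def _write(self, step):
    time.sleep(0.05)  # Pretend to write the checkpoint to disk.
    self.completed_steps.append(step)

  def save_async(self, step):
    future = self._executor.submit(self._write, step)
    self._pending.append(future)
    return future

  def close(self):
    # Block until every outstanding background write has finished.
    for future in self._pending:
      future.result()
    self._executor.shutdown(wait=True)

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    self.close()
    return False  # Do not swallow exceptions raised inside the block.


with ToyCheckpointer() as ckptr:
  ckptr.save_async(0)
  ckptr.save_async(1)

# Leaving the `with` block waited for both background writes.
print(ckptr.completed_steps)
```

Without the `with` block, the caller would need to remember to call `close()` on every exit path, including error paths.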

Parameters:
  • directory (path_types.PathLike) – The root directory where checkpoints are stored. The directory will be created if it does not exist.

  • context (context_lib.Context | None) – A Context object that will be used to wrap all function calls for this Checkpointer.

  • save_decision_policy (save_decision_policies.SaveDecisionPolicy | None) – A policy used to determine when a checkpoint should be saved. If not provided, the Checkpointer saves as often as possible by default (assuming no checkpoint is currently being saved), and saves when a preemption is detected by the JAX distributed system.

  • preservation_policy (preservation_policies.PreservationPolicy | None) – A policy used to determine when a checkpoint should be preserved. Any checkpoints not preserved are garbage collected. If not provided,

  • step_name_format (path_step_lib.NameFormat[CheckpointMetadata[None]] | None) – An object used to specify the format for step paths. By default, steps are rendered as simple integers, like /root/directory/<step>.

  • custom_metadata (tree_types.JsonType | None) – A JSON dictionary representing user-specified custom metadata. This should be information that is relevant to the entire sequence of checkpoints, rather than to any single checkpoint.

  • cleanup_tmp_directories (bool) – If True, cleans up any existing temporary directories on Checkpointer creation.

  • lightweight_initialize (bool) – If True, checkpoint step metadata is not read on Checkpointer initialization during checkpoint info loading. This is useful to improve init performance when there are O(1k) or more existing checkpoint steps present and checkpoint info properties like time and metrics are not needed.
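
To make the composition of policies more concrete, here is a toy model of the two policy lists from the example above. It does not use Orbax: the helper names (`fixed_interval`, `any_policy`, `preserved_steps`) are hypothetical, the preemption signal is simulated, and the "any" rule is assumed to be a logical OR over member policies, as the names `AnySavePolicy` and `AnyPreservationPolicy` suggest. The real policies may define their conditions differently.

```python
def fixed_interval(interval):
  """Toy save policy: save when the step is a multiple of `interval`."""
  return lambda step: step % interval == 0


def any_policy(*policies):
  """Toy composition: save if at least one member policy says so."""
  return lambda step: any(policy(step) for policy in policies)


def preserved_steps(saved_steps, latest_n, every_n_steps):
  """Toy preservation rule: keep the latest N or every N-th step."""
  ordered = sorted(saved_steps)
  latest = set(ordered[-latest_n:])
  periodic = {step for step in ordered if step % every_n_steps == 0}
  return sorted(latest | periodic)


# Simulated preemption signal at step 1234 (an assumption for illustration).
preempted = lambda step: step == 1234
should_save = any_policy(fixed_interval(1000), preempted)

saved = [step for step in range(25001) if should_save(step)]
kept = preserved_steps(saved, latest_n=3, every_n_steps=10000)
print(kept)
```

In this toy run the checkpoint written at the simulated preemption is saved, but it is later dropped because neither preservation rule covers it.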

property directory: Path

The root directory where checkpoint steps are located.

Return type:

Path

property latest: CheckpointMetadata[None] | None

Returns the latest CheckpointMetadata, or None if no checkpoints exist.

See checkpoints documentation below.

Return type:

CheckpointMetadata[None] | None

Returns:

The latest checkpoint, or None if no checkpoints exist.

property checkpoints: Sequence[CheckpointMetadata[None]]

Returns a list of CheckpointMetadata, sorted ascending by step.

This property returns a list of CheckpointMetadata objects, which contain selected properties describing each checkpoint. Contrast this with the methods metadata() and checkpointables_metadata(), which may perform a more expensive disk read to retrieve additional information. This property returns only cheap, cacheable properties like step and timestamp. The return value is annotated as CheckpointMetadata[None] because the core metadata property is not retrieved, and is therefore None.

The property is cached to avoid repeated disk reads. The cache can become stale only if checkpoints are deleted manually, or by some other job or class that the Checkpointer is unaware of. Deleting checkpoints this way is discouraged; if it happens, reload() rescans the directory.

Return type:

Sequence[CheckpointMetadata[None]]

Returns:

A list of checkpoints, sorted ascending by step.

should_save(step)

Returns whether a checkpoint should be saved at the given step.

Return type:

bool

save(step, state, *, checkpointable_name='state', force=False, overwrite=False, metrics=None, custom_metadata=None)

Saves a checkpoint, if dictated by SaveDecisionPolicy.

This method behaves similarly to the standalone free function save() (see documentation), but performs additional tasks related to managing a sequence of checkpoint steps.

It consists roughly of the following steps:
  • Check whether a checkpoint should be saved at the given step.

  • Check whether a save is already in progress. If so, wait for it to finish.

  • Save to a directory given by root_directory / <step_format>.

  • Perform garbage collection if necessary.

  • Return whether a checkpoint was saved or not.

It is important to note that the Checkpointer never allows saving more than one checkpoint at a time. Depending on the SaveDecisionPolicy, a checkpoint may be saved or skipped at a given step, but if a save is initiated, as dictated by the policy, then it will proceed as normal as long as no other save is currently in progress. If a save is already in progress, the function will block until the previous save has finished.
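
The steps above can be sketched with a small in-memory model. This is an illustration only: `ToySequenceSaver` is a hypothetical class, a fixed interval stands in for the SaveDecisionPolicy, "keep the newest N" stands in for the preservation policy, and a plain `ValueError` stands in for whatever error the real library raises when a step already exists. The wait for an in-flight save is noted in a comment rather than implemented.

```python
class ToySequenceSaver:
  """Hypothetical model of the bookkeeping around a managed save."""

  def __init__(self, save_interval, keep_latest):
    self._save_interval = save_interval
    self._keep_latest = keep_latest
    self.steps = {}  # Maps step -> saved state.

  def should_save(self, step):
    return step % self._save_interval == 0

  def save(self, step, state, *, force=False, overwrite=False):
    # 1. Consult the policy unless the caller forces a save.
    if not force and not self.should_save(step):
      return False
    # 2. (A real implementation would wait here for any in-flight save.)
    # 3. Write under a step-named entry, refusing to clobber by default.
    if step in self.steps and not overwrite:
      raise ValueError(f'Step {step} already exists.')
    self.steps[step] = state
    # 4. Garbage-collect everything except the newest `keep_latest` steps.
    excess = len(self.steps) - self._keep_latest
    for old_step in sorted(self.steps)[:max(excess, 0)]:
      del self.steps[old_step]
    # 5. Report that a checkpoint was written.
    return True


saver = ToySequenceSaver(save_interval=10, keep_latest=2)
results = [saver.save(step, {'step': step}) for step in range(0, 35, 5)]
forced = saver.save(31, {'step': 31}, force=True)
print(results)
print(forced, sorted(saver.steps))
```

The boolean return value lets a training loop log or react to the steps at which a checkpoint was actually written.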

Example usage:
  1. Basic Usage: Save a PyTree at a specific training step. The checkpointer automatically manages the step-based directory structure inside your root folder:

    from orbax.checkpoint.v1 import training
    
    # Initialize the checkpointer for a directory
    ckptr = training.Checkpointer(directory)
    
    # Save the tree at step 0.
    saved = ckptr.save(step=0, state=tree)
    
    # Clean up background threads gracefully when the training loop ends
    ckptr.close()
    
  2. Advanced Saving with Metrics and Metadata: Attach JSON-serializable metrics (like loss/accuracy) and custom metadata to a specific step for thorough experiment tracking:

    from orbax.checkpoint.v1 import training
    
    ckptr = training.Checkpointer(directory)
    
    ckptr.save(
        step=1,
        state=tree,
        metrics={'loss': 0.12, 'accuracy': 0.95},
        custom_metadata={'description': 'Model after epoch 1'},
    )
    
    ckptr.close()
    
Parameters:
  • step (int) – The step number to save.

  • state (tree_types.PyTreeOf[tree_types.Leaf]) – The PyTree to save.

  • checkpointable_name (str) – The name of the checkpointable to save the PyTree under. Defaults to ‘state’.

  • force (bool) – If True, ignores all SaveDecisionPolicy checks, and always decides to save a checkpoint.

  • overwrite (bool) – If True, deletes any existing checkpoint at the given step before saving. Otherwise, raises an error if the checkpoint already exists.

  • metrics (tree_types.JsonType | None) – A PyTree of metrics to be saved with the checkpoint.

  • custom_metadata (tree_types.JsonType | None) – A JSON dictionary representing user-specified custom metadata. This should be information that is relevant to the checkpoint at the given step, rather than to the entire sequence of checkpoints.

Return type:

bool

Returns:

Whether a checkpoint was saved or not.

save_checkpointables(step, checkpointables, *, force=False, overwrite=False, metrics=None, custom_metadata=None)

Saves a dictionary of checkpointable objects at the given step.

This method saves a dictionary of checkpointable objects, mapping string names to values. See the guide on Checkpointables for more details on checkpointables. Also see documentation for save().

Example

  1. Basic Usage: Save multiple named items (checkpointables) at a specific step. The dictionary keys define the names of the saved components:

    from orbax.checkpoint.v1 import training
    
    # Initialize the checkpointer for a directory
    ckptr = training.Checkpointer(directory)
    
    # Save multiple items, such as model weights and optimizer state
    items_to_save = {
        'model': my_model_state,
        'optimizer': my_opt_state,
    }
    
    saved = ckptr.save_checkpointables(
        step=0,
        checkpointables=items_to_save
    )
    
    # Clean up background threads gracefully when the training loop ends
    ckptr.close()
    
  2. Advanced Saving with Metrics and Metadata: Attach JSON-serializable metrics and custom metadata to a specific step for thorough experiment tracking:

    from orbax.checkpoint.v1 import training
    
    ckptr = training.Checkpointer(directory)
    items_to_save = {'model': my_model_state}
    
    ckptr.save_checkpointables(
        step=1,
        checkpointables=items_to_save,
        metrics={'loss': 0.12, 'accuracy': 0.95},
        custom_metadata={'description': 'Model after epoch 1'},
    )
    
    ckptr.close()
    
Parameters:
  • step (int) – The step number to save.

  • checkpointables (dict[str, Any]) – A dictionary mapping string names to the corresponding objects (checkpointables) that need to be saved.

  • force (bool) – If True, ignores all policy checks and always decides to save a checkpoint.

  • overwrite (bool) – If True, deletes any existing checkpoint at the given step before saving. Otherwise, raises an error if the checkpoint already exists.

  • metrics (tree_types.JsonType | None) – A dictionary of metrics to be saved with the checkpoint. Must be JSON-serializable.

  • custom_metadata (tree_types.JsonType | None) – A JSON dictionary representing user-specified custom metadata relevant to the checkpoint at this specific step.

Returns:

True if the checkpoint was successfully saved, False otherwise.

Return type:

bool

save_async(step, state, *, checkpointable_name='state', force=False, overwrite=False, metrics=None, custom_metadata=None)

Saves a checkpoint asynchronously.

This function is the asynchronous equivalent of save(). It accepts the exact same arguments; please refer to that method for detailed descriptions.

This method executes mostly in the background, blocking the main thread for as little time as possible.

Example

async_response = ckptr.save_async(step=0, state=tree)
saved = async_response.result()

Parameters:
  • step (int) – The step number to save.

  • state (tree_types.PyTreeOf[tree_types.Leaf]) – The PyTree to save.

  • checkpointable_name (str) – The name of the checkpointable to save the PyTree under. Defaults to ‘state’.

  • force (bool) – See save.

  • overwrite (bool) – See save.

  • metrics (tree_types.JsonType | None) – See save.

  • custom_metadata (tree_types.JsonType | None) – See save.

Return type:

async_types.AsyncResponse[bool]

Returns:

An AsyncResponse, which can be awaited via result(), which returns a bool indicating whether a checkpoint was saved or not.
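
The value of this pattern is that training can continue while the write proceeds, with `result()` called only when the outcome is needed. The sketch below imitates that flow with the standard library: `fake_save_async` is a hypothetical function, and a `concurrent.futures.Future` stands in for the AsyncResponse, sharing only the `result()` method described above.

```python
import concurrent.futures
import time

_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)


def fake_save_async(step, state):
  """Hypothetical stand-in: returns a future that resolves to a bool."""

  def _write():
    time.sleep(0.05)  # Pretend to write `state` for `step` to disk.
    return True

  return _executor.submit(_write)


state = {'w': 0}
response = fake_save_async(0, state)

# Training work continues while the background write is in flight.
for _ in range(3):
  state['w'] += 1

# Block only when the outcome of the save is actually needed.
saved = response.result()
_executor.shutdown(wait=True)
print(saved, state['w'])
```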

save_checkpointables_async(step, checkpointables, *, force=False, overwrite=False, metrics=None, custom_metadata=None)

Saves checkpointable objects asynchronously.

This function is the asynchronous equivalent of save_checkpointables(). Please refer to that method for detailed instructions and argument descriptions.

Example

Save checkpointable objects asynchronously:

async_response = ckptr.save_checkpointables_async(
    step=0,
    checkpointables=items_to_save
)
saved = async_response.result()

Parameters:
  • step (int) – The step number to save.

  • checkpointables (dict[str, Any]) – A dictionary mapping string names to objects to save.

  • force (bool) – See save_checkpointables.

  • overwrite (bool) – See save_checkpointables.

  • metrics (tree_types.JsonType | None) – See save_checkpointables.

  • custom_metadata (tree_types.JsonType | None) – See save_checkpointables.

Return type:

async_types.AsyncResponse[bool]

Returns:

An object representing the background operation. Call .result() on it to block and return a boolean indicating whether the checkpoint was successfully saved.

Raises:

StepAlreadyExistsError – If overwrite is False and a checkpoint at the target step already exists.

load(step=None, abstract_state=None, *, checkpointable_name='state')

Loads a PyTree checkpoint at the given step.

This method behaves similarly to the standalone free function load().

Note: Loading a PyTree without providing an abstract_state is provided purely for convenience. For serious or production use cases, it is STRONGLY recommended to always provide an abstract_state to ensure the restored PyTree strictly matches the expected shapes, dtypes, and sharding.
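
One common way to obtain such an abstract tree without allocating real parameters is `jax.eval_shape`, which traces a function and returns `jax.ShapeDtypeStruct` leaves. The sketch below is a suggestion rather than something this API prescribes: `init_params` is a hypothetical initializer assumed to produce the same tree layout as the saved state, and sharding information is not included.

```python
import jax
import jax.numpy as jnp


def init_params(rng):
  """Hypothetical initializer with the same tree layout as the saved state."""
  return {
      'weights': jax.random.normal(rng, (128, 128), dtype=jnp.float32),
      'bias': jnp.zeros((128,), dtype=jnp.float32),
  }


# Trace the initializer abstractly: the parameter arrays are never allocated.
abstract_state = jax.eval_shape(init_params, jax.random.PRNGKey(0))

print(abstract_state['weights'].shape, abstract_state['weights'].dtype)
print(abstract_state['bias'].shape, abstract_state['bias'].dtype)
```

A tree built this way has the same form as the hand-written `jax.ShapeDtypeStruct` dictionaries used elsewhere in this section, so it could be passed as `abstract_state`.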

Example

  1. Basic Loading: Load a PyTree without providing an abstract structure. By passing step=None (or omitting it), it automatically loads the latest step:

    from orbax.checkpoint.v1 import training
    
    # Initialize the checkpointer for the directory
    ckptr = training.Checkpointer(directory)
    
    # Load the saved PyTree from latest step
    restored_tree = ckptr.load(step=None)
    
  2. Loading with an Abstract PyTree: Provide an abstract structure (such as target shapes and dtypes) to ensure the restored PyTree is safely and correctly formatted:

    import jax
    import jax.numpy as jnp
    from orbax.checkpoint.v1 import training
    
    ckptr = training.Checkpointer(directory)
    
    # Define the expected structure (shapes and dtypes) to restore into
    target_structure = {
        'weights': jax.ShapeDtypeStruct((128, 128), dtype=jnp.float32),
        'bias': jax.ShapeDtypeStruct((128,), dtype=jnp.float32)
    }
    
    # Restore exactly matching the target structure
    restored_tree = ckptr.load(
        step=1,
        abstract_state=target_structure
    )
    
Parameters:
  • step (UnionType[int, CheckpointMetadata, None]) – The step number or CheckpointMetadata to load. If None, the checkpointer will attempt to resolve and load the latest existing checkpoint.

  • abstract_state (Optional[PyTreeOf[UnionType[AbstractArray, AbstractShardedArray, int, float, number, bytes, bool, str]]]) – The abstract PyTree describing the expected structure, shapes, and dtypes of the checkpoint to load.

  • checkpointable_name (str) – The name of the checkpointable to load the PyTree from. Defaults to ‘state’.

Return type:

PyTreeOf[UnionType[Array, ndarray, int, float, number, bytes, bool, str]]

Returns:

The loaded PyTree.

load_checkpointables(step=None, abstract_checkpointables=None)

Loads a set of checkpointables at the given step.

This method behaves similarly to the standalone free function load_checkpointables().

This function retrieves multiple named items (such as model weights or optimizer states) from a specific checkpoint directory. If no step is provided, it automatically resolves to and loads the most recently saved checkpoint.

Note: Loading without providing an abstract_checkpointables dictionary is provided purely for convenience. For serious or production use cases, it is STRONGLY recommended to always provide abstract_checkpointables to ensure the restored items strictly match the exact nested structures, shapes, and data types expected.

Example

  1. Basic Loading: Load multiple named items (such as a model and optimizer) from a specific step. If step is omitted, it resolves to the latest available checkpoint:

    from orbax.checkpoint.v1 import training
    
    # Initialize the checkpointer for the directory
    ckptr = training.Checkpointer(directory)
    
    # Load all checkpointables saved at the latest step
    restored_items = ckptr.load_checkpointables(step=None)
    
    # Access the individual components by their original string keys
    my_model = restored_items["model"]
    my_opt = restored_items["optimizer"]
    
  2. Loading with Abstract Checkpointables (Recommended): Provide a dictionary of abstract structures to ensure the restored items strictly match your expected shapes and data types:

    import jax
    import jax.numpy as jnp
    from orbax.checkpoint.v1 import training
    
    ckptr = training.Checkpointer(directory)
    
    # Define the expected structure for each named item using abstract
    # arrays (jax.ShapeDtypeStruct), which carry only shape and dtype.
    target_items = {
        "model": {
            'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
            'bias': jax.ShapeDtypeStruct((128,), jnp.float32)
        },
        "optimizer": {
            'momentum': jax.ShapeDtypeStruct((128, 128), jnp.float32)
        }
    }
    
    # Restore exactly matching the target structures
    restored_items = ckptr.load_checkpointables(
        step=1,
        abstract_checkpointables=target_items
    )
    
  3. Partial Loading: If you only need to load a subset of checkpointables (e.g., loading model weights but omitting optimizer state), you can provide an abstract_checkpointables dictionary containing only the keys for the items you wish to restore:

    import jax
    import jax.numpy as jnp
    from orbax.checkpoint.v1 import training
    
    ckptr = training.Checkpointer(directory)
    
    # Define abstract structure for ONLY the items to load
    target_items = {
        "model": {
            'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
            'bias': jax.ShapeDtypeStruct((128,), jnp.float32)
        },
    }
    
    # Load only "model", omitting "optimizer"
    restored_items = ckptr.load_checkpointables(
        step=1,
        abstract_checkpointables=target_items
    )
    my_model = restored_items["model"]
    # "optimizer" was not requested, so it is not in restored_items.
    
Parameters:
  • step (UnionType[int, CheckpointMetadata, None]) – The step number or CheckpointMetadata to load. If None, the checkpointer will attempt to resolve and load the latest existing checkpoint.

  • abstract_checkpointables (UnionType[dict[str, Any], None]) – A dictionary mapping string names to their corresponding abstract structures (e.g., target PyTrees). This guides the loading process to ensure shape and type compliance. If provided, it can be used to load only a subset of checkpointables by providing only a subset of keys.

Returns:

A dictionary containing the loaded checkpointable objects, keyed by string names. If abstract_checkpointables was specified, returns only the keys specified in that dict, otherwise returns all keys saved with save_checkpointables.

Return type:

dict[str, Any]
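
The key-selection rule described in the return value can be modeled with plain dictionaries. The helper below is hypothetical and performs no I/O; it only mirrors which keys come back. Raising `KeyError` for a requested name that was never saved is an assumption made for this sketch, not documented behavior.

```python
def select_checkpointables(saved, abstract_checkpointables=None):
  """Hypothetical helper mirroring the documented key selection."""
  if abstract_checkpointables is None:
    # No abstract dict: everything that was saved is returned.
    return dict(saved)
  missing = set(abstract_checkpointables) - set(saved)
  if missing:
    # Assumption for this sketch: unknown names are treated as an error.
    raise KeyError(f'Not present in checkpoint: {sorted(missing)}')
  # Abstract dict given: only the requested names are returned.
  return {name: saved[name] for name in abstract_checkpointables}


saved = {
    'model': {'weights': [1.0, 2.0]},
    'optimizer': {'momentum': [0.0, 0.0]},
}
everything = select_checkpointables(saved)
only_model = select_checkpointables(saved, {'model': None})
print(sorted(everything))
print(sorted(only_model))
```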

load_async(step=None, abstract_state=None)

Not yet supported.

Return type:

AsyncResponse[PyTreeOf[UnionType[Array, ndarray, int, float, number, bytes, bool, str]]]

load_checkpointables_async(step=None, abstract_checkpointables=None)

Loads a set of checkpointables asynchronously at the given step.

Return type:

AsyncResponse[dict[str, Any]]

metadata(step=None)

Returns checkpoint metadata for the given step.

Retrieves metadata describing the structure of the PyTree stored at the given step. If no step is provided, the method resolves to the latest available checkpoint.

Parameters:

step (UnionType[int, CheckpointMetadata, None]) – The step number to retrieve metadata for. If None, the latest step is used. Can also be a CheckpointMetadata object, from which the step is extracted.

Return type:

CheckpointMetadata[PyTreeOf[UnionType[AbstractArray, AbstractShardedArray, int, float, number, bytes, bool, str]]]

Returns:

A CheckpointMetadata object containing PyTreeMetadata, along with checkpoint timestamp and metrics information.

checkpointables_metadata(step=None)

Returns checkpoint metadata for the given step.

Retrieves metadata describing the structure of the checkpointables stored at the given step. If no step is provided, the method resolves to the latest available checkpoint.

Parameters:

step (UnionType[int, CheckpointMetadata, None]) – The step number to retrieve metadata for. If None, the latest step is used. Can also be a CheckpointMetadata object, from which the step is extracted.

Return type:

CheckpointMetadata[dict[str, Any]]

Returns:

A CheckpointMetadata object containing a dict[str, Any] describing the checkpointables, along with checkpoint timestamp and metrics information.

reload()

Reloads internal properties from the root directory.

Updates the list of available checkpoints by rescanning the storage location. Use this method to sync the checkpointer with the file system if checkpoints have been added or removed externally.
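
The staleness that reload() addresses can be reproduced with a toy cache over a temporary directory. `ToyStepCache` is a hypothetical class written for this illustration; it assumes steps are stored as integer-named subdirectories, which matches the default step format described earlier but says nothing about how Orbax caches internally.

```python
import pathlib
import tempfile


class ToyStepCache:
  """Hypothetical cache of integer-named step directories under a root."""

  def __init__(self, root):
    self._root = pathlib.Path(root)
    self._steps = None

  def _scan(self):
    return sorted(
        int(path.name)
        for path in self._root.iterdir()
        if path.is_dir() and path.name.isdigit()
    )

  @property
  def steps(self):
    # Scan the directory once, then serve the cached list.
    if self._steps is None:
      self._steps = self._scan()
    return self._steps

  def reload(self):
    # Rescan the directory and replace the cached list.
    self._steps = self._scan()


with tempfile.TemporaryDirectory() as root:
  for step in (0, 1000):
    (pathlib.Path(root) / str(step)).mkdir()
  cache = ToyStepCache(root)
  before = list(cache.steps)

  # Another job adds a step behind the cache's back.
  (pathlib.Path(root) / '2000').mkdir()
  stale = list(cache.steps)

  cache.reload()
  fresh = list(cache.steps)

print(before, stale, fresh)
```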

is_saving_in_progress()

Returns whether a checkpoint save operation is currently in progress.

Checks if there are any background persistence operations currently active.

Return type:

bool

Returns:

True if a save operation is in progress, False otherwise.

wait()

Waits for any outstanding async operations to complete.

This method blocks until all background tasks, such as asynchronous saves, have finished. Use this method to ensure that all operations are finalized before proceeding with dependent actions.

close()

Waits for pending async operations to complete and releases resources.

This method blocks until all background tasks, such as asynchronous saves, have finished. It also performs necessary cleanup, such as closing file handles.

CheckpointMetadata

final class orbax.checkpoint.experimental.v1.training.CheckpointMetadata(step, path, *, metadata, init_timestamp_nsecs=None, commit_timestamp_nsecs=None, custom_metadata=None, metrics=None)

Represents metadata for a single checkpoint (corresponding to a step).

Like its parent, the class has a metadata attribute that is a generic type. The .metadata attribute contains checkpointable-specific metadata. If a PyTree was saved, it will contain PyTreeMetadata; if checkpointables were saved, it will be a dictionary mapping names to metadata.

The Orbax checkpointing API provides two symmetric levels of interaction:

  1. Higher level (sequence-of-steps API): Accessed via Checkpointer.

  2. Lower level (individual path API): Accessed via free functions.

CheckpointMetadata objects are returned by both API levels using the same core methods (metadata() and checkpointables_metadata()), reflecting this inherent symmetry.

See superclass documentation for more information, and for a list of base attributes. This class defines several additional attributes that are relevant to checkpoints in a sequence, but not necessarily to a singular checkpoint in isolation.

Example Usage:

from orbax.checkpoint import v1 as ocp

# Higher level (sequence-of-steps API)
with ocp.training.Checkpointer('/path/to/my/checkpoints') as ckptr:
  ckpt_meta = ckptr.metadata(100)

# Lower level (individual path API)
ckpt_meta = ocp.metadata('/path/to/my/checkpoints/100')

# Inspect checkpoint-level properties
print(f'Init time (ns): {ckpt_meta.init_timestamp_nsecs}')
print(f'Commit time (ns): {ckpt_meta.commit_timestamp_nsecs}')
print(f'Custom metadata: {ckpt_meta.custom_metadata}')

# The `.metadata` field contains checkpointable-specific metadata,
# which will be `PyTreeMetadata` or dict[str, CheckpointableMetadata]
# depending on what was saved.
print(f'Checkpointable metadata: {ckpt_meta.metadata}')

See also RootMetadata.

See the parent class, CheckpointMetadata, for base attributes.

Additional Attributes:

  • step: The step number of the checkpoint.

  • metrics: An optional dictionary containing user-provided metrics saved alongside the checkpoint.

__init__(step, path, *, metadata, init_timestamp_nsecs=None, commit_timestamp_nsecs=None, custom_metadata=None, metrics=None)

RootMetadata

final class orbax.checkpoint.experimental.v1.training.RootMetadata(*, directory, custom_metadata=None)

Metadata of a sequence of checkpoints at the root level (covering all steps).

This class represents the top-level metadata for an entire checkpointing directory, distinct from step-specific metadata. It associates the physical storage location of the sequence with arbitrary, user-defined information that applies to all steps (e.g., experiment configuration).

Example Usage:

RootMetadata objects are returned by root_metadata().

It can be used to inspect checkpoint-wide information, such as experiment configuration:

import orbax.checkpoint.v1 as ocp
ckptr = ocp.training.Checkpointer('/path/to/my/checkpoints')
root_meta = ckptr.root_metadata()

print(f'Directory: {root_meta.directory}')
print(f'Custom metadata: {root_meta.custom_metadata}')

See also CheckpointMetadata.

directory

The directory of the checkpoint sequence.

Type:

path_types.Path

custom_metadata

User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.

Type:

tree_types.JsonType | None

__delattr__(name)

Implement delattr(self, name).

__eq__(other)

Return self==value.

__hash__()

Return hash(self).

__init__(*, directory, custom_metadata=None)

__setattr__(name, value)

Implement setattr(self, name, value).