# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a class for managing a sequence of checkpoints in a training loop."""
from __future__ import annotations

import operator
import typing
from typing import Any, Callable, Iterable, Sequence

from absl import logging
from etils import epy
from orbax.checkpoint import checkpoint_manager
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.loading import loading
from orbax.checkpoint.experimental.v1._src.metadata import loading as metadata_loading
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import step as path_step_lib
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.saving import saving
from orbax.checkpoint.experimental.v1._src.saving import validation
from orbax.checkpoint.experimental.v1._src.synchronization import thread_utils
from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types
from orbax.checkpoint.experimental.v1._src.training import errors
from orbax.checkpoint.experimental.v1._src.training import preservation_policies
from orbax.checkpoint.experimental.v1._src.training import save_decision_policies
from orbax.checkpoint.experimental.v1._src.training.metadata import types as training_metadata_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
from typing_extensions import deprecated # pytype: disable=not-supported-yet
# Default checkpointable name under which `Checkpointer.save` / `load` store a
# single PyTree.
STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY

# Module-level aliases for the training metadata types used throughout this
# file.
RootMetadata = training_metadata_types.RootMetadata
CheckpointMetadata = training_metadata_types.CheckpointMetadata
class _AsyncSaveResponse(async_types.AsyncResponse[bool]):
  """AsyncResponse that resolves once the manager's pending saves complete.

  Waiting happens on a `thread_utils.BackgroundThreadRunner`, which runs a
  coroutine that calls `manager.wait_until_finished()` and then yields True.
  If a background save failed, `wait_until_finished()` re-raises that error;
  the runner is presumably expected to surface it from `result()` — confirm in
  `thread_utils`.
  """

  def __init__(
      self, manager: checkpoint_manager.CheckpointManager
  ):
    waiter = self._wait_for_manager(manager)
    self._thread_runner = thread_utils.BackgroundThreadRunner[bool](waiter)

  @staticmethod
  async def _wait_for_manager(
      manager: checkpoint_manager.CheckpointManager,
  ) -> bool:
    # Re-raises any exception from a failed background operation.
    manager.wait_until_finished()
    return True

  def result(self, timeout: float | None = None) -> bool:
    """Blocks (up to `timeout` seconds) until the wait completes."""
    runner = self._thread_runner
    return runner.result(timeout=timeout)

  def on_complete(self, callback: Callable[[bool], None]) -> None:
    """Registers `callback` to be invoked with the result on completion."""
    runner = self._thread_runner
    runner.on_complete(callback)
def _resolve_integer_step(
    step: int | CheckpointMetadata,
) -> int:
  """Returns the integer step number for `step`.

  Args:
    step: A step number, or a `CheckpointMetadata` whose `step` attribute is
      used. Integer-like scalars implementing `__index__` (e.g.
      `numpy.int64`) are also accepted and converted to a Python `int`; plain
      `int` values are returned unchanged.

  Returns:
    The step as an integer.

  Raises:
    AttributeError: If `step` is neither integer-like nor has a `step`
      attribute.
  """
  if isinstance(step, int):
    return step
  try:
    # Training loops often carry the step as a NumPy/JAX integer scalar, which
    # is not an `int` instance but does implement `__index__`.
    return operator.index(step)
  except TypeError:
    # Not integer-like; fall through to treat it as a `CheckpointMetadata`.
    pass
  return step.step
[docs]
@typing.final
class Checkpointer(epy.ContextManager):
"""An object that manages a sequence of checkpoints in a training loop."""
[docs]
def __init__(
    self,
    directory: path_types.PathLike,
    *,
    context: context_lib.Context | None = None,
    save_decision_policy: (
        save_decision_policies.SaveDecisionPolicy | None
    ) = None,
    preservation_policy: (
        preservation_policies.PreservationPolicy | None
    ) = None,
    step_name_format: (
        path_step_lib.NameFormat[CheckpointMetadata[None]] | None
    ) = None,
    custom_metadata: tree_types.JsonType | None = None,
    cleanup_tmp_directories: bool = False,
    lightweight_initialize: bool = False,
):
  """Initializes a Checkpointer.

  IMPORTANT: This class is not thread safe. All APIs should be called across
  all available processes, from the main thread.

  The Checkpointer is intended for use in a training loop, where a sequence
  of checkpoints are saved at regular intervals. Example usage::

    # Configure the frequency at which checkpoints are saved.
    save_decision_policies = ocp.training.save_decision_policies
    # Save every 1000 steps, or when a preemption is detected.
    save_decision_policy = save_decision_policies.AnySavePolicy([
        save_decision_policies.FixedIntervalPolicy(1000),
        save_decision_policies.PreemptionCheckpointingPolicy(),
    ])

    # Configure the checkpoints to preserve (avoid garbage collection).
    preservation_policies = ocp.training.preservation_policies
    # Avoid garbage collection on the latest 10, or every 10000 steps.
    preservation_policy = preservation_policies.AnyPreservationPolicy([
        preservation_policies.LatestN(10),
        preservation_policies.EveryNSteps(10000),
    ])

    with ocp.training.Checkpointer(
        directory,
        save_decision_policy=save_decision_policy,
        preservation_policy=preservation_policy,
    ) as ckptr:
      if ckptr.latest is None:
        model_state = init_from_scratch(rng)
        start_step = 0
      else:
        # Loads the latest checkpoint. `ckptr.load()` also works, but prefer
        # to specify the abstract tree if available.
        model_state = ckptr.load(
            ckptr.latest, abstract_state=abstract_model_state)
        # Resume *after* the restored step: saving again at an existing step
        # raises `StepAlreadyExistsError` unless `overwrite=True` is passed.
        start_step = ckptr.latest.step + 1

      for step in range(start_step, num_steps):
        model_state = train_step(model_state)
        # Saves a checkpoint if needed (according to `save_decision_policy`).
        ckptr.save(step, model_state)

  Prefer to use the context manager style as shown above, which ensures that
  the Checkpointer is closed properly and any outstanding async operations are
  completed.

  Args:
    directory: The root directory where checkpoints are stored. The directory
      will be created if it does not exist.
    context: A :py:class:`~orbax.checkpoint.v1.Context` object that will be
      used to wrap all function calls for this `Checkpointer`. If not
      provided, the currently active context is used.
    save_decision_policy: A policy used to determine when a checkpoint should
      be saved. If not provided, defaults to an `AnySavePolicy` combining
      `InitialSavePolicy`, `ContinuousCheckpointingPolicy` and
      `PreemptionCheckpointingPolicy`: the `Checkpointer` saves as often as
      possible (assuming no checkpoint is currently being saved), and saves
      when a preemption is detected by the JAX distributed system.
    preservation_policy: A policy used to determine when a checkpoint should
      be preserved. Any checkpoints not preserved are garbage collected. If
      not provided, defaults to `PreserveAll`, i.e. every checkpoint is kept
      and nothing is garbage collected.
    step_name_format: An object used to specify the format for step paths. By
      default, steps are rendered as simple integers, like
      `/root/directory/<step>`.
    custom_metadata: A JSON dictionary representing user-specified custom
      metadata. This should be information that is relevant to the entire
      sequence of checkpoints, rather than to any single checkpoint.
    cleanup_tmp_directories: If True, cleans up any existing temporary
      directories on Checkpointer creation.
    lightweight_initialize: If True, checkpoint step metadata is not read on
      Checkpointer initialization during checkpoint info loading. This is
      useful to improve init performance when there are O(1k) or more existing
      checkpoint steps present and checkpoint info properties like `time` and
      `metrics` are not needed.
  """
  self._context = context_lib.Context(context or context_lib.get_context())
  # Build default policies only when the caller did not supply one.
  if save_decision_policy is None:
    save_decision_policy = save_decision_policies.AnySavePolicy([
        save_decision_policies.InitialSavePolicy(),
        save_decision_policies.ContinuousCheckpointingPolicy(),
        save_decision_policies.PreemptionCheckpointingPolicy(),
    ])
  if preservation_policy is None:
    preservation_policy = preservation_policies.PreserveAll()
  # Used by this class to build step paths (`checkpoints`, loading).
  self._step_name_format = (
      step_name_format
      if step_name_format is not None
      else path_step_lib.standard_name_format()
  )
  options = checkpoint_manager.CheckpointManagerOptions(
      save_decision_policy=save_decision_policy,
      preservation_policy=preservation_policy,
      # NOTE(review): the raw argument (possibly None) is forwarded, so the
      # manager applies its own default format; confirm it matches
      # `self._step_name_format`, which this class uses to build paths.
      step_name_format=step_name_format,
      cleanup_tmp_directories=cleanup_tmp_directories,
      lightweight_initialize=lightweight_initialize,
      max_to_keep=None,  # Unlimited.
      todelete_full_path=self._context.deletion_options.gcs_deletion_options.todelete_full_path,
      async_options=self._context.async_options.v0(),
      file_options=self._context.file_options.v0(),
      multiprocessing_options=self._context.multiprocessing_options.v0(),
      temporary_path_class=self._context.file_options.temporary_path_class,
      # Prevent the checkpoint manager from writing metrics on its own. This
      # class will take responsibility for writing metrics.
      prevent_write_metrics=True,
  )
  self._manager = checkpoint_manager.CheckpointManager(
      directory,
      options=options,
      metadata=custom_metadata,
  )
@property
def directory(self) -> path_types.Path:
  """Root directory under which the per-step checkpoint directories live."""
  manager = self._manager
  return manager.directory
@property
def latest(self) -> CheckpointMetadata[None] | None:
  """Returns the latest :py:class:`.CheckpointMetadata`, or None if no checkpoints exist.

  See the `checkpoints` property for details on what the returned metadata
  contains.

  Returns:
    The checkpoint with the highest step, or None if no checkpoints exist.
  """
  # `checkpoints` re-sorts and rebuilds its list on every access, so read it
  # once instead of once for the emptiness check and again for indexing.
  checkpoints = self.checkpoints
  return checkpoints[-1] if checkpoints else None
@property
def checkpoints(self) -> Sequence[CheckpointMetadata[None]]:
  """Returns a list of :py:class:`.CheckpointMetadata`, sorted ascending by step.

  Each entry exposes only cheap properties: step, path, metrics and commit
  timestamp. The more expensive :py:func:`.metadata` and
  :py:func:`.checkpointables_metadata` may read from disk to fetch more
  information. Because that core `metadata` is not fetched here, it is always
  `None`, hence the :py:class:`.CheckpointMetadata[None]` annotation.

  The list is rebuilt on each access from the step information the
  underlying manager holds in memory, so no directory scan happens here. That
  information can become stale if checkpoints are created or deleted by some
  other job or class that `Checkpointer` is unaware of. This is discouraged;
  call `reload()` to rescan storage if it happens.

  Returns:
    A list of checkpoints, sorted ascending by step.
  """
  manager_infos = self._manager._checkpoints  # pylint: disable=protected-access
  ordered_infos = sorted(manager_infos, key=lambda entry: entry.step)
  root = self.directory
  result = []
  for entry in ordered_infos:
    step_path = root / self._step_name_format.build_name(entry.step)
    # Convert the commit `datetime` to integer nanoseconds since the epoch.
    commit_nsecs = int(entry.time.timestamp() * 1e9)
    result.append(
        CheckpointMetadata[None](
            entry.step,
            path=step_path,
            metadata=None,
            metrics=entry.metrics,
            commit_timestamp_nsecs=commit_nsecs,
        )
    )
  return result
def _resolve_existing_checkpoint(
    self, step: int | CheckpointMetadata | None
) -> CheckpointMetadata[None]:
  """Returns the existing checkpoint identified by `step`.

  Args:
    step: A step number or `CheckpointMetadata` to look up, or None to select
      the latest checkpoint.

  Returns:
    The matching entry from `checkpoints`.

  Raises:
    errors.StepNotFoundError: If `step` is None and there are no checkpoints,
      or if no checkpoint exists at the requested step.
  """
  # Build the (sorted) checkpoint list once. Previously `latest` built it twice
  # and the lookup loop built it again.
  checkpoints = self.checkpoints
  if step is None:
    if not checkpoints:
      raise errors.StepNotFoundError(
          'Specified `step=None`, but no checkpoints were found.'
      )
    return checkpoints[-1]
  step = _resolve_integer_step(step)
  for checkpoint in checkpoints:
    if checkpoint.step == step:
      return checkpoint
  raise errors.StepNotFoundError(f'No checkpoint found at step {step}.')
[docs]
def should_save(self, step: int) -> bool:
  """Reports whether the save decision policy would save at `step`.

  The decision is delegated to the underlying manager and evaluated under
  this Checkpointer's context.

  Args:
    step: The step number (or a `CheckpointMetadata`) to evaluate.

  Returns:
    True if a checkpoint should be saved at `step`.
  """
  active_context = context_lib.get_context(self._context)
  with active_context:
    integer_step = _resolve_integer_step(step)
    decision = self._manager.should_save(integer_step)
  return decision
[docs]
def save(
    self,
    step: int,
    state: tree_types.PyTreeOf[tree_types.Leaf],
    *,
    checkpointable_name: str = STATE_CHECKPOINTABLE_KEY,
    force: bool = False,
    overwrite: bool = False,
    metrics: tree_types.JsonType | None = None,
    custom_metadata: tree_types.JsonType | None = None,
) -> bool:
  """Saves a checkpoint, if dictated by :py:class:`.SaveDecisionPolicy`.

  Behaves like the standalone free function
  :py:func:`~orbax.checkpoint.v1.save` (see its documentation). In addition,
  it handles the bookkeeping for a sequence of checkpoint steps:

  - Decide whether a checkpoint should be saved at the given step.
  - If a save is already in progress, wait for it to finish.
  - Save to a directory given by `root_directory / <step_format>`.
  - Perform garbage collection if necessary.
  - Report whether a checkpoint was saved.

  The `Checkpointer` never saves more than one checkpoint at a time. The
  :py:class:`.SaveDecisionPolicy` decides whether a given step is saved or
  skipped. A save it initiates proceeds normally, but if a previous save is
  still running, this call blocks until that save has finished.

  This is the blocking counterpart of `save_async`: it starts the save and
  then waits for its result.

  Example usage:

    1. Basic Usage:
    Save a PyTree at a specific training step. The checkpointer
    automatically manages the step-based directory structure inside your
    root folder::

      from orbax.checkpoint.v1 import training

      # Initialize the checkpointer for a directory
      ckptr = training.Checkpointer(directory)
      # Save the tree at step 0.
      saved = ckptr.save(step=0, state=tree)
      # Clean up background threads gracefully when the training loop ends
      ckptr.close()

    2. Advanced Saving with Metrics and Metadata:
    Attach JSON-serializable metrics (like loss/accuracy) and custom
    metadata to a specific step for thorough experiment tracking::

      from orbax.checkpoint.v1 import training

      ckptr = training.Checkpointer(directory)
      ckptr.save(
          step=1,
          state=tree,
          metrics={'loss': 0.12, 'accuracy': 0.95},
          custom_metadata={'description': 'Model after epoch 1'},
      )
      ckptr.close()

  Args:
    step: The step number to save.
    state: The PyTree to save.
    checkpointable_name: The name of the checkpointable to save the PyTree
      under. Defaults to `STATE_CHECKPOINTABLE_KEY`.
    force: If True, ignores all :py:class:`.SaveDecisionPolicy` checks, and
      always decides to save a checkpoint.
    overwrite: If True, deletes any existing checkpoint at the given step
      before saving. Otherwise, raises an error if the checkpoint already
      exists.
    metrics: JSON-serializable metrics to be saved with the checkpoint.
    custom_metadata: A JSON dictionary representing user-specified custom
      metadata. This should be information that is relevant to the checkpoint
      at the given step, rather than to the entire sequence of checkpoints.

  Returns:
    Whether a checkpoint was saved or not.
  """
  save_options = dict(
      checkpointable_name=checkpointable_name,
      force=force,
      overwrite=overwrite,
      metrics=metrics,
      custom_metadata=custom_metadata,
  )
  pending = self.save_async(step, state, **save_options)
  # `None` means the save was skipped by policy.
  return False if pending is None else pending.result()
[docs]
def save_checkpointables(
    self,
    step: int,
    checkpointables: dict[str, Any],
    *,
    force: bool = False,
    overwrite: bool = False,
    metrics: tree_types.JsonType | None = None,
    custom_metadata: tree_types.JsonType | None = None,
) -> bool:
  """Saves a dictionary of checkpointable objects at the given step.

  The dictionary maps string names to the objects to save. See `the guide on
  Checkpointables
  <https://orbax.readthedocs.io/en/latest/guides/checkpoint/v1/checkpointables.html>`_
  for more details on checkpointables, and the documentation for
  :py:func:`~orbax.checkpoint.v1.save`.

  This is the blocking counterpart of `save_checkpointables_async`: it starts
  the save and then waits for its result.

  Example:
    1. Basic Usage:
    Save multiple named items (checkpointables) at a specific step. The
    dictionary keys define the names of the saved components::

      from orbax.checkpoint.v1 import training

      # Initialize the checkpointer for a directory
      ckptr = training.Checkpointer(directory)
      # Save multiple items, such as model weights and optimizer state
      items_to_save = {
          'model': my_model_state,
          'optimizer': my_opt_state,
      }
      saved = ckptr.save_checkpointables(
          step=0,
          checkpointables=items_to_save
      )
      # Clean up background threads gracefully when the training loop ends
      ckptr.close()

    2. Advanced Saving with Metrics and Metadata:
    Attach JSON-serializable metrics and custom metadata to a specific step
    for thorough experiment tracking::

      from orbax.checkpoint.v1 import training

      ckptr = training.Checkpointer(directory)
      items_to_save = {'model': my_model_state}
      ckptr.save_checkpointables(
          step=1,
          checkpointables=items_to_save,
          metrics={'loss': 0.12, 'accuracy': 0.95},
          custom_metadata={'description': 'Model after epoch 1'},
      )
      ckptr.close()

  Args:
    step: The step number to save.
    checkpointables: A dictionary mapping string names to the corresponding
      objects (checkpointables) that need to be saved.
    force: If True, ignores all policy checks and always decides to save a
      checkpoint.
    overwrite: If True, deletes any existing checkpoint at the given step
      before saving. Otherwise, raises an error if the checkpoint already
      exists.
    metrics: A dictionary of metrics to be saved with the checkpoint. Must be
      JSON-serializable.
    custom_metadata: A JSON dictionary representing user-specified custom
      metadata relevant to the checkpoint at this specific step.

  Returns:
    bool: True if the checkpoint was successfully saved, False otherwise
    (including when the save was skipped by policy).
  """
  pending = self.save_checkpointables_async(
      step,
      checkpointables,
      **dict(
          force=force,
          overwrite=overwrite,
          metrics=metrics,
          custom_metadata=custom_metadata,
      ),
  )
  if pending is not None:
    return pending.result()
  return False
[docs]
def save_async(
    self,
    step: int,
    state: tree_types.PyTreeOf[tree_types.Leaf],
    *,
    checkpointable_name: str = STATE_CHECKPOINTABLE_KEY,
    force: bool = False,
    overwrite: bool = False,
    metrics: tree_types.JsonType | None = None,
    custom_metadata: tree_types.JsonType | None = None,
) -> async_types.AsyncResponse[bool] | None:
  """Saves a checkpoint asynchronously.

  Asynchronous equivalent of :py:meth:`~.save`. It takes exactly the same
  arguments; see that method for their detailed descriptions. Most of the
  work runs in the background, so the main thread is blocked as briefly as
  possible.

  Example:
    ::

      async_response = ckptr.save_async(step=0, state=tree)
      if async_response is not None:
        saved = async_response.result()

  Args:
    step: The step number to save.
    state: The PyTree to save.
    checkpointable_name: The name of the checkpointable to save the PyTree
      under. Defaults to `STATE_CHECKPOINTABLE_KEY`.
    force: See `save`.
    overwrite: See `save`.
    metrics: See `save`.
    custom_metadata: See `save`.

  Returns:
    An `AsyncResponse` whose `result()` returns a bool indicating whether a
    checkpoint was saved, or None if the save was skipped by policy.
  """
  # A single PyTree is saved as a one-entry checkpointables dictionary.
  single_checkpointable = {checkpointable_name: state}
  return self.save_checkpointables_async(
      step,
      single_checkpointable,
      force=force,
      overwrite=overwrite,
      metrics=metrics,
      custom_metadata=custom_metadata,
  )
[docs]
def save_checkpointables_async(
    self,
    step: int,
    checkpointables: dict[str, Any],
    *,
    force: bool = False,
    overwrite: bool = False,
    metrics: tree_types.JsonType | None = None,
    custom_metadata: tree_types.JsonType | None = None,
) -> async_types.AsyncResponse[bool] | None:
  """Saves checkpointable objects asynchronously.

  This function is the asynchronous equivalent of
  :py:meth:`~.save_checkpointables`. Please refer to that
  method for detailed instructions and argument descriptions.

  Note that the existing-step handling below (`overwrite` deletion, or the
  `StepAlreadyExistsError` check) runs *before* the save decision policy is
  consulted. In particular, with `overwrite=True` and `force=False`, an
  existing checkpoint is deleted even if the policy then skips this step.

  Example:
    Save checkpointable objects asynchronously::

      async_response = ckptr.save_checkpointables_async(
          step=0,
          checkpointables=items_to_save
      )
      if async_response is not None:
        saved = async_response.result()

  Args:
    step: The step number to save.
    checkpointables: A dictionary mapping string names to objects to save.
    force: See `save_checkpointables`.
    overwrite: See `save_checkpointables`. Deletion only happens when a
      checkpoint at `step` currently exists.
    metrics: See `save_checkpointables`.
    custom_metadata: See `save_checkpointables`.

  Returns:
    An object representing the background operation, or None if the save was
    skipped by policy. Call `.result()` on it to block and return a boolean
    indicating whether the checkpoint was successfully saved.

  Raises:
    StepAlreadyExistsError: If `overwrite` is False and a checkpoint at the
      target `step` already exists.
  """
  context = context_lib.get_context(self._context)
  # Run validation and existing-step handling under this Checkpointer's
  # context too, consistent with the save itself and the other public methods.
  with context:
    validation.validate_save_checkpointables(checkpointables)
    step_exists = any(c.step == step for c in self.checkpoints)
    if overwrite:
      logging.info(
          'Specified `overwrite`: deleting existing checkpoint %d if it'
          ' exists.',
          step,
      )
      # Only ask the manager to delete steps it knows about; deleting an
      # unknown step is not a "file not found" condition and must not fail an
      # overwrite of a brand-new step.
      if step_exists:
        try:
          self._manager.delete(step)
        except FileNotFoundError:
          # Already gone from storage; nothing left to delete.
          pass
    elif step_exists:
      raise errors.StepAlreadyExistsError(f'Step {step} already exists.')

    checkpointer, args = saving.get_v0_checkpointer_and_args(
        checkpointables, metrics=metrics
    )
    # The manager's checkpointer is swapped per call so it matches the
    # handlers required by these particular checkpointables.
    self._manager._checkpointer = checkpointer  # pylint: disable=protected-access
    save_initiated = self._manager.save(
        step,
        args=args,
        metrics=metrics,
        force=force,
        custom_metadata=custom_metadata,
    )
    if not save_initiated:
      # Skipped by the save decision policy.
      return None
    return _AsyncSaveResponse(self._manager)
[docs]
def load(
    self,
    step: int | CheckpointMetadata | None = None,
    abstract_state: (
        tree_types.PyTreeOf[tree_types.AbstractLeaf] | None
    ) = None,
    *,
    checkpointable_name: str = STATE_CHECKPOINTABLE_KEY,
) -> tree_types.PyTreeOf[tree_types.Leaf]:
  """Loads a PyTree checkpoint at the given step.

  Behaves like the standalone free function
  :py:func:`~orbax.checkpoint.v1.load`. Internally this loads a single
  checkpointable via `load_checkpointables` and returns it.

  **Note:** Loading a PyTree without providing an `abstract_state` is
  provided purely for convenience. For serious or production use cases, it is
  STRONGLY recommended to always provide an `abstract_state` to ensure the
  restored PyTree strictly matches the expected shapes, dtypes, and sharding.

  Example:
    1. Basic Loading:
    Load a PyTree without providing an abstract structure. By passing
    `step=None` (or omitting it), it automatically loads the latest step::

      from orbax.checkpoint.v1 import training

      # Initialize the checkpointer for the directory
      ckptr = training.Checkpointer(directory)
      # Load the saved PyTree from latest step
      restored_tree = ckptr.load(step=None)

    2. Loading with an Abstract PyTree:
    Provide an abstract structure (such as target shapes and dtypes)
    to ensure the restored PyTree is safely and correctly formatted::

      import jax
      import jax.numpy as jnp
      from orbax.checkpoint.v1 import training

      ckptr = training.Checkpointer(directory)
      # Define the expected structure (shapes and dtypes) to restore into
      target_structure = {
          'weights': jax.ShapeDtypeStruct((128, 128), dtype=jnp.float32),
          'bias': jax.ShapeDtypeStruct((128,), dtype=jnp.float32)
      }
      # Restore exactly matching the target structure
      restored_tree = ckptr.load(
          step=1,
          abstract_state=target_structure
      )

  Args:
    step: The step number or :py:class:`.CheckpointMetadata` to load. If None,
      the checkpointer will attempt to resolve and load the latest existing
      checkpoint.
    abstract_state: The abstract PyTree to load.
    checkpointable_name: The name of the checkpointable the PyTree was saved
      under. Defaults to `STATE_CHECKPOINTABLE_KEY`.

  Returns:
    The loaded PyTree.
  """
  request = {checkpointable_name: abstract_state}
  loaded = self.load_checkpointables(step, request)
  return loaded[checkpointable_name]
[docs]
def load_checkpointables(
    self,
    step: int | CheckpointMetadata | None = None,
    abstract_checkpointables: dict[str, Any] | None = None,
) -> dict[str, Any]:
  """Loads a set of checkpointables at the given step.

  Behaves like the standalone free function
  :py:func:`~orbax.checkpoint.v1.load_checkpointables`.

  Retrieves multiple named items (such as model weights or optimizer states)
  from one checkpoint step. If no step is given, the most recently saved
  checkpoint is used.

  **Note:** Loading without providing an `abstract_checkpointables`
  dictionary is provided purely for convenience. For serious or production
  use cases, it is STRONGLY recommended to always provide
  `abstract_checkpointables` to ensure the restored items strictly match
  the exact nested structures, shapes, and data types expected.

  Example:
    1. Basic Loading:
    Load multiple named items (such as a model and optimizer) from a
    specific step. If step is omitted, it resolves to the latest
    available checkpoint::

      from orbax.checkpoint.v1 import training

      # Initialize the checkpointer for the directory
      ckptr = training.Checkpointer(directory)
      # Load all checkpointables saved at the latest step
      restored_items = ckptr.load_checkpointables(step=None)
      # Access the individual components by their original string keys
      my_model = restored_items["model"]
      my_opt = restored_items["optimizer"]

    2. Loading with Abstract Checkpointables (Recommended):
    Provide a dictionary of abstract structures to ensure the restored
    items strictly match your expected shapes and data types::

      import jax
      import jax.numpy as jnp
      from orbax.checkpoint.v1 import training

      ckptr = training.Checkpointer(directory)
      # Define the expected structure for each named item using JAX arrays
      target_items = {
          "model": {
              'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
              'bias': jax.ShapeDtypeStruct((128,), jnp.float32)
          },
          "optimizer": {
              'momentum': jax.ShapeDtypeStruct((128, 128), jnp.float32)
          }
      }
      # Restore exactly matching the target structures
      restored_items = ckptr.load_checkpointables(
          step=1,
          abstract_checkpointables=target_items
      )

    3. Partial Loading:
    To load only a subset of checkpointables (e.g. model weights without the
    optimizer state), provide an `abstract_checkpointables` dictionary that
    contains only the keys you wish to restore::

      import jax
      import jax.numpy as jnp
      from orbax.checkpoint.v1 import training

      ckptr = training.Checkpointer(directory)
      # Define abstract structure for ONLY the items to load
      target_items = {
          "model": {
              'weights': jax.ShapeDtypeStruct((128, 128), jnp.float32),
              'bias': jax.ShapeDtypeStruct((128,), jnp.float32)
          },
      }
      # Load only "model", omitting "optimizer"
      restored_items = ckptr.load_checkpointables(
          step=1,
          abstract_checkpointables=target_items
      )
      my_model = restored_items["model"]
      # restored_items has no "optimizer" key.

  Args:
    step: The step number or :py:class:`.CheckpointMetadata` to load. If None,
      the checkpointer will attempt to resolve and load the latest existing
      checkpoint.
    abstract_checkpointables: A dictionary mapping string names to their
      corresponding abstract structures (e.g., target PyTrees). This guides
      the loading process to ensure shape and type compliance. If provided, it
      can be used to load only a subset of checkpointables by providing only a
      subset of keys.

  Returns:
    dict[str, Any]: The loaded checkpointable objects keyed by name. If
    `abstract_checkpointables` was specified, only its keys are returned;
    otherwise all keys saved with `save_checkpointables` are returned.

  Raises:
    errors.StepNotFoundError: If the requested (or latest) checkpoint does
      not exist.
  """
  with context_lib.get_context(self._context):
    resolved = self._resolve_existing_checkpoint(step)
    step_dir_name = self._step_name_format.build_name(resolved.step)
    step_directory = self.directory / step_dir_name
    return loading.load_checkpointables(
        step_directory,
        abstract_checkpointables,
    )
[docs]
def load_async(
    self,
    step: int | CheckpointMetadata | None = None,
    abstract_state: (
        tree_types.PyTreeOf[tree_types.AbstractLeaf] | None
    ) = None,
) -> async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.Leaf]]:
  """Not yet supported.

  Raises:
    NotImplementedError: Always.
  """
  del step, abstract_state  # Unused until asynchronous loading is supported.
  raise NotImplementedError
[docs]
def load_checkpointables_async(
    self,
    step: int | CheckpointMetadata | None = None,
    abstract_checkpointables: dict[str, Any] | None = None,
) -> async_types.AsyncResponse[dict[str, Any]]:
  """Asynchronous variant of `load_checkpointables`; not yet supported.

  Raises:
    NotImplementedError: Always.
  """
  del step, abstract_checkpointables  # Unused until implemented.
  raise NotImplementedError
def root_metadata(
    self,
) -> training_metadata_types.RootMetadata:
  """Returns metadata describing the whole sequence of checkpoints.

  Reads the manager's root-level metadata (requested with step `None`) under
  this Checkpointer's context.

  Returns:
    A `RootMetadata` holding this Checkpointer's `directory` and the root-level
    `custom_metadata`, e.g. the value passed to the constructor.
  """
  with context_lib.get_context(self._context):
    manager_root = self._manager.metadata(None)
    root_custom_metadata = manager_root.custom_metadata
    return RootMetadata(
        directory=self.directory,
        custom_metadata=root_custom_metadata,
    )
[docs]
def reload(self):
  """Refreshes internal state by rescanning the root directory.

  Re-reads the available checkpoints from storage via the underlying manager.
  Call this when checkpoints were added or removed outside this Checkpointer,
  so that its view matches the file system again.
  """
  manager = self._manager
  manager.reload()
[docs]
def is_saving_in_progress(self) -> bool:
  """Reports whether a background checkpoint save is still running.

  Returns:
    `True` if the underlying manager reports an in-flight save, `False`
    otherwise.
  """
  in_progress = self._manager.is_saving_in_progress()
  return in_progress
[docs]
def wait(self):
  """Blocks until every outstanding asynchronous operation has finished.

  Use this before any action that depends on in-flight work, such as an
  asynchronous save, being complete. Delegates to the underlying manager's
  `wait_until_finished()`.
  """
  manager = self._manager
  manager.wait_until_finished()
[docs]
def close(self):
  """Finishes pending asynchronous work and releases resources.

  Delegates to the underlying manager's `close()`, which waits for
  outstanding background operations such as asynchronous saves. Called
  automatically when the Checkpointer is used as a context manager.
  """
  manager = self._manager
  manager.close()
def __contextmanager__(
    self,
) -> Iterable[Checkpointer]:
  """Yields this Checkpointer and closes it when the `with` block exits.

  Backs `with Checkpointer(...) as ckptr:` usage via the `epy.ContextManager`
  base class.
  """
  try:
    yield self
  finally:
    # Runs both on normal exit and when the `with` body raises, so pending
    # async operations are always waited on and resources released.
    self.close()
# Deprecated aliases kept for backward compatibility. Each forwards all
# arguments unchanged to its replacement.
@deprecated('Use `save` instead.')
def save_pytree(self, *args, **kwargs):
  """Deprecated alias of `save`."""
  return self.save(*args, **kwargs)
@deprecated('Use `save_async` instead.')
def save_pytree_async(self, *args, **kwargs):
  """Deprecated alias of `save_async`."""
  return self.save_async(*args, **kwargs)
@deprecated('Use `load` instead.')
def load_pytree(self, *args, **kwargs):
  """Deprecated alias of `load`."""
  return self.load(*args, **kwargs)
@deprecated('Use `load_async` instead.')
def load_pytree_async(self, *args, **kwargs):
  """Deprecated alias of `load_async` (which currently always raises)."""
  return self.load_async(*args, **kwargs)
@deprecated('Use `metadata` instead.')
def pytree_metadata(self, *args, **kwargs):
  """Deprecated alias of `metadata`."""
  # NOTE(review): no `metadata` method is visible in this part of the class;
  # confirm it is defined, otherwise this alias raises AttributeError.
  return self.metadata(*args, **kwargs)