# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A class providing functionalities for managing a series of checkpoints."""
import concurrent.futures
import dataclasses
import datetime
import threading
import time
import typing
from typing import Any, Callable, Container, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union
from absl import logging
from etils import epath
from etils import epy
import jax
from jax.experimental.array_serialization import serialization
from orbax.checkpoint import abstract_checkpoint_manager
from orbax.checkpoint import abstract_checkpointer
from orbax.checkpoint import args as args_lib
from orbax.checkpoint import async_checkpointer
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import checkpoint_handler
from orbax.checkpoint import checkpointer as checkpointer_lib
from orbax.checkpoint import composite_checkpoint_handler
from orbax.checkpoint import json_checkpoint_handler
from orbax.checkpoint import multihost
from orbax.checkpoint import proto_checkpoint_handler
from orbax.checkpoint import utils
from orbax.checkpoint.logging import abstract_logger
from orbax.checkpoint.logging import step_statistics
from orbax.checkpoint.path import deleter
from orbax.checkpoint.path import step as step_lib
from typing_extensions import Self # for Python version < 3.11
# --- Type aliases used throughout this module. ---
PyTree = Any
CheckpointDirs = Tuple[str, str]
SaveParams = Mapping[str, Any]
RestoreParams = SaveParams
# --- Convenience re-exports so users can reference these names directly from
# this module instead of their defining submodules. ---
AbstractCheckpointer = abstract_checkpointer.AbstractCheckpointer
CheckpointersDict = Mapping[str, AbstractCheckpointer]
AbstractCheckpointManager = (
    abstract_checkpoint_manager.AbstractCheckpointManager
)
AsyncCheckpointer = async_checkpointer.AsyncCheckpointer
Checkpointer = checkpointer_lib.Checkpointer
JsonCheckpointHandler = json_checkpoint_handler.JsonCheckpointHandler
ProtoCheckpointHandler = proto_checkpoint_handler.ProtoCheckpointHandler
CompositeCheckpointHandler = (
    composite_checkpoint_handler.CompositeCheckpointHandler
)
CheckpointHandler = checkpoint_handler.CheckpointHandler
CheckpointArgs = checkpoint_args.CheckpointArgs
CheckpointHandlersDict = Mapping[str, CheckpointHandler]
# --- Well-known item names. Items listed in RESERVED_ITEM_NAMES may not be
# used as user-provided item keys. ---
DEFAULT_ITEM_NAME = 'default'
DESCRIPTOR_ITEM_NAME = 'descriptor'
METRIC_ITEM_NAME = 'metrics'
METADATA_ITEM_NAME = 'metadata'
RESERVED_ITEM_NAMES = [DESCRIPTOR_ITEM_NAME, METRIC_ITEM_NAME]
# Wall-clock time (UTC) captured at module import.
_INIT_TIME = datetime.datetime.now(tz=datetime.timezone.utc)
def _metrics_file_exists(metrics_item_path: epath.Path) -> bool:
  """True if item directory AND actual file both exist."""
  if not metrics_item_path.exists():
    return False
  return (metrics_item_path / METRIC_ITEM_NAME).exists()
def _descriptor_file_exists(descriptor_item_path: epath.Path) -> bool:
  """True if item directory AND actual file both exist."""
  if not descriptor_item_path.exists():
    return False
  return (descriptor_item_path / f'{DESCRIPTOR_ITEM_NAME}.pbtxt').exists()
class StepAlreadyExistsError(ValueError):
  """Raised when a save is requested for a step that already exists."""
class _FinalizeThread(threading.Thread):
"""Thread wrapper that raises an exception if encountered."""
exception = None
def run(self):
try:
super().run()
except BaseException as e: # pylint:disable=broad-exception-caught
self.exception = e
def join(self, *args, **kwargs):
super().join(*args, **kwargs)
if self.exception:
exception = self.exception
self.exception = None
raise exception
# TODO(b/268051457) Clean up when no longer depended upon by internal users.
def is_async_checkpointer(checkpointer: AbstractCheckpointer):
  """Returns True if `checkpointer` performs saves asynchronously.

  Args:
    checkpointer: the checkpointer instance to inspect.

  Returns:
    True if `checkpointer` is an `AsyncCheckpointer` or a
    `GlobalAsyncCheckpointManagerBase`.
  """
  # Single isinstance with a tuple replaces the previous `or` of two calls.
  return isinstance(
      checkpointer,
      (
          async_checkpointer.AsyncCheckpointer,
          serialization.GlobalAsyncCheckpointManagerBase,
      ),
  )
@dataclasses.dataclass
class AsyncOptions:
  """Options used to configure async behavior.

  See `AsyncCheckpointer` for details.
  """

  # Maximum time to wait for an async save to complete before erroring.
  timeout_secs: int = 300
  barrier_sync_fn: Optional[async_checkpointer.BarrierSyncFn] = None
  post_finalization_callback: Optional[Callable[[], None]] = None
@dataclasses.dataclass
class MultiprocessingOptions:
  """Options used to configure multiprocessing behavior.

  primary_host: the host id of the primary host. Defaults to 0. If it is set
    to None, then all hosts will be considered as primary. It's useful in
    the case that all hosts are only working with local storage.
  active_processes: A set of process indices (corresponding to
    `multihost.process_index()`) over which `CheckpointManager` is expected to
    be called. This makes it possible to have a `CheckpointManager` instance
    that runs over a subset of processes, rather than all processes as it is
    normally expected to do. If specified, `primary_host` must belong to
    `active_processes`.
  """

  primary_host: Optional[int] = 0
  active_processes: Optional[Set[int]] = None
# TODO(b/309965339) Set todelete_subdir defaults if directory is on CNS.
@dataclasses.dataclass
class CheckpointManagerOptions:
  """Optional arguments for CheckpointManager.

  save_interval_steps:
    The interval at which checkpoints should be saved.
    Ensures checkpoints will only be saved every n steps. Defaults to 1.
  max_to_keep:
    If provided, specifies the maximum number of checkpoints to
    keep. Older checkpoints are removed. By default, does not remove any old
    checkpoints. Must be None or non-negative. When set, checkpoints
    may be considered for deletion when there are more than `max_to_keep`
    checkpoints present. Checkpoints are kept if they meet any of the conditions
    below, such as `keep_time_interval`, `keep_period`, etc. Any remaining
    checkpoints that do not meet these conditions are garbage-collected.
  keep_time_interval:
    When more than max_to_keep checkpoints are present,
    an older checkpoint that would ordinarily be deleted will be preserved if it
    has been at least `keep_time_interval` since the previous preserved
    checkpoint. The default setting of `None` does not preserve any checkpoints
    in this way. For example, this may be used to ensure checkpoints are
    retained at a frequency of approximately one per hour.
  keep_period:
    If set, any existing checkpoints matching checkpoint_step % keep_period == 0
    will not be deleted.
  best_fn:
    If set, maintains checkpoints based on the quality of given
    metrics rather than recency. The function should accept a PyTree of metrics,
    and return a scalar value that can be used to determine the quality score
    of the checkpoint. If `max_to_keep` is also set, then the retained
    checkpoints will be kept based on their quality, as measured by this
    function.
  best_mode:
    One of ['max', 'min']. The best metric is determined on the basis of this
    value.
  keep_checkpoints_without_metrics:
    If False, checkpoints without metrics present
    are eligible for cleanup. Otherwise, they will never be deleted.
  step_prefix:
    If provided, step directories will take the form
    f'{step_prefix}_<step>'. Otherwise, they will simply be an integer <step>.
  step_format_fixed_length:
    If set, formats step with n digits (leading zeros).
    This makes sorting steps easier. Otherwise, step has no leading zeros.
  step_name_format:
    NameFormat to build or find steps under input root directory. If provided,
    `step_prefix`, `step_format_fixed_length` are ignored.
  create:
    If True, creates the top-level directory if it does not already exist.
  cleanup_tmp_directories:
    If True, cleans up any existing temporary directories
    on CheckpointManager creation.
  save_on_steps:
    Optional set of steps at which checkpoints should be saved.
    Useful to save checkpoints on a fixed set of steps that are not multiple of
    `save_interval_steps`.
  single_host_load_and_broadcast:
    If True, calling `all_steps(read=True)` will load on only a single host, and
    will then be broadcast to other hosts. Otherwise, I/O will be performed on
    every host. This can be helpful to reduce QPS to the filesystem if there
    are a large number of hosts.
  todelete_subdir: If set, checkpoints to be deleted will be only renamed into a
    subdirectory with the provided string. Otherwise, they will be directly
    deleted from the file system. Useful if checkpoint deletion is time
    consuming. By default, delete the checkpoint assets. Ignored if file system
    is Google Cloud Storage (directory is prefixed with gs://)
  enable_background_delete: If True, old checkpoint deletions will be done in a
    background thread, otherwise, it will be done at the end of each save. When
    it's enabled, make sure to call CheckpointManager.close() or use context to
    make sure all old steps are deleted before exit.
  read_only: If True, then checkpoints save and delete are skipped. However,
    checkpoints restore works as usual.
  enable_async_checkpointing:
    If True, enables async checkpointing.
  async_options:
    Used to configure properties of async behavior. See above.
  multiprocessing_options: MultiprocessingOptions instance to configure
    multiprocessing behavior.
  should_save_fn:
    Predicate callable to check if given step can be saved. This callable
    accepts step number and optional latest step number as param and returns
    bool. If present then `save_interval_steps` and `save_on_steps` options are
    ignored.
  """

  save_interval_steps: int = 1
  max_to_keep: Optional[int] = None
  keep_time_interval: Optional[datetime.timedelta] = None
  keep_period: Optional[int] = None
  best_fn: Optional[Callable[[PyTree], float]] = None
  best_mode: str = 'max'
  keep_checkpoints_without_metrics: bool = True
  step_prefix: Optional[str] = None
  step_format_fixed_length: Optional[int] = None
  step_name_format: Optional[step_lib.NameFormat] = None
  create: bool = True
  cleanup_tmp_directories: bool = False
  save_on_steps: Optional[Container[int]] = None
  single_host_load_and_broadcast: bool = False
  todelete_subdir: Optional[str] = None
  enable_background_delete: bool = False
  read_only: bool = False
  enable_async_checkpointing: bool = True
  async_options: Optional[AsyncOptions] = None
  multiprocessing_options: MultiprocessingOptions = dataclasses.field(
      default_factory=MultiprocessingOptions
  )
  should_save_fn: Optional[Callable[[int, Optional[int]], bool]] = None

  def __post_init__(self):
    if self.best_mode not in ('min', 'max'):
      # Bug fix: the previous message was a plain (non-f) string, so the
      # placeholder text was emitted literally, and it referenced a
      # non-existent `self.dtype` attribute.
      raise ValueError(
          "`CheckpointManagerOptions.best_mode` must be one of 'min' or"
          f" 'max'. Got {self.best_mode}."
      )
    if self.max_to_keep is not None and self.max_to_keep < 0:
      raise ValueError('Setting of `max_to_keep` must be None or non-negative.')
    # `read_only` overrides every option that could mutate the directory; warn
    # on each explicitly-set option that is being ignored.
    if self.read_only and self.save_interval_steps > 0:
      self.save_interval_steps = 0
      logging.warning(
          'CheckpointManagerOptions.read_only=True, setting'
          ' save_interval_steps=0.'
      )
    if self.read_only and self.max_to_keep is not None:
      self.max_to_keep = None
      logging.warning(
          'CheckpointManagerOptions.read_only=True, setting max_to_keep=None.'
      )
    if self.read_only and self.keep_time_interval is not None:
      self.keep_time_interval = None
      logging.warning(
          'CheckpointManagerOptions.read_only=True, setting'
          ' keep_time_interval=None.'
      )
    if self.read_only and self.keep_period is not None:
      self.keep_period = None
      logging.warning(
          'CheckpointManagerOptions.read_only=True, setting keep_period=None.'
      )
    if self.read_only and self.create:
      self.create = False
      logging.warning(
          'CheckpointManagerOptions.read_only=True, setting create=False.'
      )
    if self.read_only and self.cleanup_tmp_directories:
      self.cleanup_tmp_directories = False
      logging.warning(
          'CheckpointManagerOptions.read_only=True, setting'
          ' cleanup_tmp_directories=False.'
      )
    if self.read_only and self.save_on_steps:
      self.save_on_steps = None
      logging.warning(
          'CheckpointManagerOptions.read_only=True, setting save_on_steps=None.'
      )
    if self.read_only and self.todelete_subdir is not None:
      self.todelete_subdir = None
      logging.warning(
          'CheckpointManagerOptions.read_only=True, setting'
          ' todelete_subdir=None.'
      )
    if self.read_only and self.should_save_fn is not None:
      self.should_save_fn = None
      logging.warning(
          'CheckpointManagerOptions.read_only=True, setting'
          ' should_save_fn=None.'
      )
    # Normalize to an immutable set for fast, hashable membership tests.
    self.save_on_steps = frozenset(self.save_on_steps or ())
@dataclasses.dataclass
class CheckpointInfo:
  """Metadata about a checkpoint."""

  step: int
  time: datetime.datetime
  metrics: Optional[Any]  # PyTree of metrics, if present.
  is_locked: Optional[bool] = None

  def __post_init__(self):
    # Users may provide step as a jax.Array.
    if isinstance(self.step, jax.Array):
      self.step = int(self.step)

  def __str__(self) -> str:
    return (
        f'Checkpoint[step={self.step} | time={self.time} |'
        f' is_locked={self.is_locked}]'
    )

  def __eq__(self, other: Any) -> bool:
    # Robustness fix: return NotImplemented for foreign types instead of
    # raising AttributeError, so Python can fall back to the other operand.
    if not isinstance(other, CheckpointInfo):
      return NotImplemented
    # `metrics` and `is_locked` are deliberately excluded from equality.
    return self.step == other.step and self.time == other.time

  def __hash__(self) -> int:
    # Consistent with __eq__: only step and time participate.
    return hash((self.step, self.time))
def _get_args_for_key(
    handler: CheckpointHandler, item_name: str
) -> Tuple[Type[CheckpointArgs], Type[CheckpointArgs]]:
  """Returns the registered (save, restore) args classes for `item_name`."""
  if not isinstance(handler, CompositeCheckpointHandler):
    raise ValueError(
        'Expected handler to be a `CompositeCheckpointHandler`, but got'
        f' {type(handler)}.'
    )
  known_handlers = handler._known_handlers  # pylint: disable=protected-access
  for key, item_handler in known_handlers.items():
    if key == item_name:
      return checkpoint_args.get_registered_args_cls(item_handler)
  raise ValueError(f'Unknown key "{item_name}" in CompositeCheckpointHandler.')
[docs]class CheckpointManager(AbstractCheckpointManager, epy.ContextManager):
"""A generic, synchronous AbstractCheckpointManager implementation."""
[docs] def __init__(
self,
directory: epath.PathLike,
checkpointers: Optional[
Union[AbstractCheckpointer, CheckpointersDict]
] = None,
options: Optional[CheckpointManagerOptions] = None,
metadata: Optional[Mapping[str, Any]] = None,
item_names: Optional[Sequence[str]] = None,
item_handlers: Optional[
Union[CheckpointHandler, CheckpointHandlersDict]
] = None,
logger: Optional[abstract_logger.AbstractLogger] = None,
):
"""CheckpointManager constructor.
IMPORTANT: `CheckpointManager` has been refactored to provide a new API.
Please ensure you have migrated all existing use cases to the newer style by
August 1st, 2024. Please see
https://orbax.readthedocs.io/en/latest/api_refactor.html
for technical details.
The `CheckpointManager` is ultimately backed by a single `Checkpointer`, to
which saving and restoring is delegated. Behind step management options,
metrics-related logic, and other frills, saving and restoring with
`CheckpointManager` is quite similar to using
`Checkpointer(CompositeCheckpointHandler)`.
Example::
with CheckpointManager(
'path/to/dir/',
# Multiple items.
item_names=('train_state', 'custom_metadata'),
metadata={'version': 1.1, 'lang': 'en'},
) as mngr:
mngr.save(0, args=args.Composite(
train_state=args.StandardSave(train_state),
custom_metadata=args.JsonSave(custom_metadata),
)
)
restored = mngr.restore(0)
print(restored.train_state)
print(restored.custom_metadata)
restored = mngr.restore(0, args=args.Composite(
train_state=args.StandardRestore(abstract_train_state),
)
)
print(restored.train_state)
print(restored.custom_metadata) # Error, not restored
# Single item, no need to specify `item_names`.
with CheckpointManager(
'path/to/dir/',
options = CheckpointManagerOptions(max_to_keep=5, ...),
) as mngr:
mngr.save(0, args=StandardSave(train_state))
train_state = mngr.restore(0)
train_state = mngr.restore(0,
args=StandardRestore(abstract_train_state))
IMPORTANT: Don't forget to use the keyword `args=...` for save and restore!
Otherwise you will get the legacy API. This will not be necessary forever,
but only until the legacy API is removed.
IMPORTANT: The CheckpointManager is designed to be used as a context
manager. Use `with CheckpointManager` schematic for automatic cleanup. If
you can't use a context manager, always call `close()` to release resources
properly. Otherwise, background operations such as deleting old checkpoints
might not finish before your program exits.
CheckpointManager:
- is NOT thread-safe.
- IS multi-process-safe.
- is NOT multi-job-safe.
This means that CheckpointManager is intended to be created and called
across all processes within a single job, where each process is
single-threaded, but is not safe to use when multiple jobs each have
CheckpointManager instances pointing to the same root directory. Concretely,
this means that if you have a trainer job and one or more evaluator jobs,
the CheckpointManager should be created and called across all processes
in the trainer, but a CheckpointManager cannot be used in the evaluators.
Instead, utilities used during evaluation can be found in
`checkpoint_utils` (see
https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_utils.html).
Args:
directory: the top level directory in which to save all files.
checkpointers: a mapping of object name to Checkpointer object. For
example, `items` provided to `save` below should have keys matching the
keys in this argument. Alternatively, a single Checkpointer may be
provided, in which case `save` and `restore` should always be called
with a single item rather than a dictionary of items. See below for more
details. `item_names` and `checkpointers` are mutually exclusive - do
not use together. Also, please don't use `checkpointers` and
`item_handlers` together.
options: CheckpointManagerOptions. May be provided to specify additional
arguments. If None, uses default values of CheckpointManagerOptions.
metadata: High-level metadata that does not depend on step number. If
`directory` is write enabled then given metadata is saved only once. A
new CheckpointManager instance with that `directory` does not overwrite
the existing metadata and ignores the current given metadata. If
`directory` is read-only then the current given metadata is not saved as
expected. A CheckpointManager instance with a read-only `directory` uses
the metadata if already present, otherwise always uses the current given
metadata.
item_names: Names of distinct items that may be saved/restored with this
`CheckpointManager`. `item_names` and `checkpointers` are mutually
exclusive - do not use together. Also see `item_handlers` below.
item_handlers: A mapping of item name to `CheckpointHandler`. The mapped
CheckpointHandler must be registered against the `CheckpointArgs` input
in save/restore operations. Please don't use `checkpointers` and
`item_handlers` together. It can be used with or without `item_names`.
The item name key may or may not be present in `item_names`.
Alternatively, a single CheckpointHandler may be provided, in which case
`save` and `restore` should always be called in a single item context.
logger: A logger to log checkpointing events.
"""
jax.monitoring.record_event('/jax/orbax/checkpoint_manager/init')
self._options = options or CheckpointManagerOptions()
if self._options.best_mode not in ['min', 'max']:
raise ValueError('`best_mode` must be one of: "min", "max"')
self._multiprocessing_options = (
self._options.multiprocessing_options or MultiprocessingOptions()
)
if (logger is not None) and (
not isinstance(logger, abstract_logger.AbstractLogger)
):
raise ValueError('`logger` must be an AbstractLogger instance.')
self._logger = logger
if checkpointers and item_names:
raise ValueError(
'`item_names` and `checkpointers` are mutually exclusive - do not use'
' together.'
)
if checkpointers and item_handlers:
raise ValueError(
'`item_handlers` and `checkpointers` are mutually exclusive - do not'
' use together.'
)
if item_names and isinstance(item_handlers, CheckpointHandler):
raise ValueError(
'`item_handlers` in single item mode and `item_names` should not be'
' provided together.'
)
if checkpointers:
logging.warning(
'Configured `CheckpointManager` using deprecated legacy API. Please'
' follow the instructions at'
' https://orbax.readthedocs.io/en/latest/api_refactor.html to'
' migrate by August 1st, 2024.'
)
self._single_item = isinstance(checkpointers, AbstractCheckpointer)
self._checkpointer = self._configure_checkpointer_legacy_init(
checkpointers, self._options
)
else:
self._single_item = isinstance(item_handlers, CheckpointHandler) or (
item_names is None and item_handlers is None
)
self._checkpointer = self._configure_checkpointer(
item_names, item_handlers, self._options, self._single_item
)
self._directory = epath.Path(directory)
if self._options.read_only:
logging.warning('Given directory is read only=%s', self._directory)
if self._options.create:
if (
utils.is_primary_host(self._multiprocessing_options.primary_host)
and not self._directory.exists()
):
self._directory.mkdir(parents=True)
utils.sync_global_processes(
'CheckpointManager:create_directory',
processes=self._multiprocessing_options.active_processes,
)
# Cleanup directories from previous runs that may not have been finalized.
if self._options.cleanup_tmp_directories:
self._cleanup_tmp_directories()
self._step_name_format = (
self._options.step_name_format
or step_lib.standard_name_format(
step_prefix=self._options.step_prefix,
step_format_fixed_length=self._options.step_format_fixed_length,
)
)
self._checkpoints = self._load_checkpoint_infos()
self._metadata_checkpointer = Checkpointer(
JsonCheckpointHandler(
primary_host=self._multiprocessing_options.primary_host
),
primary_host=self._multiprocessing_options.primary_host,
active_processes=self._multiprocessing_options.active_processes,
)
if self._options.read_only and not self._metadata_path().exists():
self._metadata = {} if metadata is None else metadata
else:
self._metadata = None
if metadata is not None and not self._options.read_only:
self._save_metadata(metadata)
self._finalize_thread = None
self._checkpoint_deleter: deleter.CheckpointDeleter = (
deleter.create_checkpoint_deleter(
self._multiprocessing_options.primary_host,
self._directory,
self._options.todelete_subdir,
self._step_name_format,
self._options.enable_background_delete,
)
)
logging.info(
'CheckpointManager created, jax.process_index=%s, primary_host=%s,'
' CheckpointManagerOptions=%s: %s',
multihost.process_index(),
self._multiprocessing_options.primary_host,
self._options,
self,
)
def _configure_checkpointer_common(
self,
handler: CompositeCheckpointHandler,
options: CheckpointManagerOptions,
use_async: bool,
) -> Checkpointer:
if use_async:
if options.async_options is not None:
return async_checkpointer.AsyncCheckpointer(
handler,
timeout_secs=options.async_options.timeout_secs,
primary_host=self._multiprocessing_options.primary_host,
barrier_sync_fn=options.async_options.barrier_sync_fn,
active_processes=self._multiprocessing_options.active_processes,
post_finalization_callback=options.async_options.post_finalization_callback,
)
else:
return async_checkpointer.AsyncCheckpointer(
handler,
primary_host=self._multiprocessing_options.primary_host,
active_processes=self._multiprocessing_options.active_processes,
)
else:
return Checkpointer(
handler,
primary_host=self._multiprocessing_options.primary_host,
active_processes=self._multiprocessing_options.active_processes,
)
  def _configure_checkpointer_legacy_init(
      self,
      checkpointers: Union[AbstractCheckpointer, CheckpointersDict],
      options: CheckpointManagerOptions,
  ) -> Checkpointer:
    """Initializes _CompositeCheckpointer with legacy style checkpointers.

    Accepts either a single `Checkpointer` (single-item mode) or a dict of
    item name -> `Checkpointer`, extracts each checkpointer's handler, and
    delegates to `_configure_checkpointer_common`.
    """
    # The legacy API predates configurable primary hosts; only host 0 works.
    if self._multiprocessing_options.primary_host != 0:
      raise ValueError(
          f'`primary_host`={self._multiprocessing_options.primary_host} is not'
          ' supported in legacy API.'
      )
    item_handlers = {}
    if isinstance(checkpointers, Checkpointer):
      # Single-item mode: reuse the checkpointer's handler under the default
      # item name, carrying over its async timeout if it has one.
      use_async = is_async_checkpointer(checkpointers)
      if isinstance(checkpointers, async_checkpointer.AsyncCheckpointer):
        async_timeout = checkpointers._async_manager._timeout_secs  # pylint: disable=protected-access
      else:
        async_timeout = None
      item_handlers[DEFAULT_ITEM_NAME] = checkpointers.handler
    elif isinstance(checkpointers, dict):
      individual_use_async = []
      async_timeout = 0
      for item_name, checkpointer in checkpointers.items():
        if not isinstance(checkpointer, Checkpointer):
          raise ValueError(
              f'Value corresponding to {item_name} in `checkpointers` is not a'
              f' Checkpointer. Found {type(checkpointer)}.'
          )
        individual_use_async.append(is_async_checkpointer(checkpointer))
        # Take the maximum timeout across all async checkpointers.
        if isinstance(checkpointer, async_checkpointer.AsyncCheckpointer):
          async_timeout = max(
              async_timeout, checkpointer._async_manager._timeout_secs  # pylint: disable=protected-access
          )
        if item_name in RESERVED_ITEM_NAMES:
          raise ValueError(
              f'Found {item_name} in `checkpointers`; this is a reserved key.'
          )
        item_handlers[item_name] = checkpointer.handler
      # A mix of sync and async checkpointers is coerced to all-async; warn
      # loudly since callers may assume saves complete synchronously.
      if any(individual_use_async) and not all(individual_use_async):
        logging.error(
            'Orbax `CheckpointManager` is transitioning toward using'
            ' asynchronous saving logic under the hood in all cases. Users that'
            ' configure `CheckpointManager` with some `Checkpointer`s and some'
            ' `AsyncCheckpointer`s will now see asynchronous logic used to save'
            ' all items. This may result in breakages if the code is assuming'
            ' that certain objects will be available immediately after saving.'
            ' Ensure that if you depend on the result of `save` being fully'
            ' written at a particular moment, use `wait_until_finished()`.'
        )
      use_async = any(individual_use_async)
      # A timeout of 0 (no async checkpointers seen) becomes None.
      async_timeout = async_timeout or None
    else:
      raise ValueError(
          f'Invalid type for `checkpointers`. Found {checkpointers}.'
      )
    # NOTE(review): the metrics handler is registered unconditionally; the
    # commented-out condition below suggests it was once gated on
    # `options.best_fn` - confirm whether this is intentional.
    # if options.best_fn:
    item_handlers[METRIC_ITEM_NAME] = JsonCheckpointHandler(
        filename=METRIC_ITEM_NAME
    )
    options.async_options = options.async_options or AsyncOptions(
        timeout_secs=async_timeout
    )
    return self._configure_checkpointer_common(
        CompositeCheckpointHandler(
            composite_options=composite_checkpoint_handler.CompositeOptions(
                primary_host=self._multiprocessing_options.primary_host,
            ),
            **item_handlers,
        ),
        options,
        use_async,
    )
def _validate_handler(self, handler):
if (
hasattr(handler, '_primary_host')
and handler._primary_host != self._multiprocessing_options.primary_host # pylint: disable=protected-access
):
raise ValueError(
'Inconsistent primary_host,'
f' CheckpointManager={self._multiprocessing_options.primary_host}, '
f'handler[{type(handler)}]={handler._primary_host} ' # pylint: disable=protected-access
)
  def _configure_checkpointer(
      self,
      item_names: Optional[Sequence[str]],
      item_handlers: Optional[Union[CheckpointHandler, CheckpointHandlersDict]],
      options: CheckpointManagerOptions,
      single_item: bool,
  ) -> Checkpointer:
    """Initializes _CompositeCheckpointer given `item_names`."""
    # Without a primary host every process writes; that only works when the
    # caller supplies explicit handlers configured for it.
    if (
        self._multiprocessing_options.primary_host is None
        and item_handlers is None
    ):
      raise ValueError(
          'When primary_host is set to None, item_handlers must be provided to'
          ' match with the primary_host setting.'
      )
    if single_item:
      # Single-item mode: store everything under the reserved default item
      # name. The handler may be None, to be resolved from CheckpointArgs at
      # save/restore time.
      item_handler = (
          item_handlers
          if isinstance(item_handlers, CheckpointHandler)
          else None
      )
      if item_handler:
        self._validate_handler(item_handler)
      all_item_handlers = {DEFAULT_ITEM_NAME: item_handler}
    else:
      # Initialize all_item_handlers with None or empty.
      if item_names:
        all_item_handlers = {item_name: None for item_name in item_names}
      else:
        all_item_handlers = {}
      # Update all_item_handlers with provided CheckpointHandlers.
      if item_handlers and isinstance(item_handlers, Mapping):
        for item_name, handler in item_handlers.items():
          self._validate_handler(handler)
          all_item_handlers[item_name] = handler
    # Reserved names would collide with internally-managed items.
    for item_name in all_item_handlers:
      if item_name in RESERVED_ITEM_NAMES:
        raise ValueError(
            f'Found {item_name} in `checkpointers`; this is a reserved key.'
        )
    # When best-checkpoint tracking is enabled, metrics are saved as an extra
    # JSON item alongside user items.
    if options.best_fn:
      all_item_handlers[METRIC_ITEM_NAME] = JsonCheckpointHandler(
          filename=METRIC_ITEM_NAME,
          primary_host=self._multiprocessing_options.primary_host,
      )
    # CompositeCheckpointHandler defers per-item handler creation until
    # save/restore time.
    return self._configure_checkpointer_common(
        CompositeCheckpointHandler(
            composite_options=composite_checkpoint_handler.CompositeOptions(
                primary_host=self._multiprocessing_options.primary_host,
            ),
            **all_item_handlers,
        ),
        options,
        options.enable_async_checkpointing,
    )
  @property
  def directory(self) -> epath.Path:
    """The top-level directory containing all managed checkpoints."""
    return self._directory
[docs] def all_steps(self, read: bool = False) -> Sequence[int]:
"""See superclass documentation."""
if read:
logging.warning(
'`read` option is deprecated. Use `reload` to read from disk.'
)
return utils.checkpoint_steps(
self.directory, self._options.single_host_load_and_broadcast
)
return [ckpt.step for ckpt in self._checkpoints]
[docs] def latest_step(self) -> Optional[int]:
"""See superclass documentation."""
return self._checkpoints[-1].step if self._checkpoints else None
[docs] def best_step(self) -> Optional[int]:
"""See superclass documentation."""
if not self._track_best:
return self.latest_step()
if not self._checkpoints:
return None
_, sorted_checkpoints = self._sort_checkpoints_by_metrics(self._checkpoints)
if not sorted_checkpoints:
return None
return sorted_checkpoints[-1].step
[docs] def reload(self):
"""Reloads internal properties.
Resets internal cache of checkpoint steps, in case the directory managed
by this object has been updated externally.
"""
self._checkpoints = self._load_checkpoint_infos()
[docs] def reached_preemption(self, step: int) -> bool:
"""See superclass documentation."""
return utils.reached_preemption(step)
[docs] def should_save(self, step: int) -> bool:
"""See superclass documentation."""
if self._options.read_only:
logging.warning('%s is read only, save will be skipped', self.directory)
return False
if self.reached_preemption(step):
return True
last_checkpoint_step = self.latest_step()
# Ensure current step is between the last step and next step (accounting for
# save interval). The `last_checkpoint_step` may not be initialized, in
# which case we should save. Otherwise, step must fall on the specified
# save interval. This condition accounts for the possibility of saving
# on preemption, in which case we want to maintain the same save period as
# if preemption had not happened.
if last_checkpoint_step is not None and last_checkpoint_step >= step:
return False
# If present then prefer should_save_fn over other 'save_*' options.
if self._options.should_save_fn is not None:
logging.info(
'CheckpointManagerOptions.should_save_fn is available, following save'
' options will be ignored: save_interval_steps=%s and'
' save_on_steps=%s',
self._options.save_interval_steps,
self._options.save_on_steps,
)
return self._options.should_save_fn(step, last_checkpoint_step)
return last_checkpoint_step is None or (
step % self._options.save_interval_steps == 0
or step in self._options.save_on_steps
)
  def _get_save_directory(
      self,
      step: int,
      directory: epath.Path,
  ) -> epath.Path:
    """Returns the standardized path to a save directory for a single item."""
    # Delegates naming (prefix / zero-padding) to the configured name format.
    return step_lib.build_step_path(directory, self._step_name_format, step)
  def _get_write_step_directory(
      self, step: int, root_dir: epath.Path
  ) -> epath.Path:
    """Returns the path under which a new step checkpoint will be written."""
    return self._get_save_directory(step, root_dir)
def _get_read_step_directory(
self, step: int, root_dir: epath.Path
) -> epath.Path:
if self._options.step_name_format is not None:
return self._options.step_name_format.find_step(root_dir, step).path
else:
return self._get_save_directory(step, root_dir)
[docs] def delete(self, step: int):
"""See superclass documentation.
Delete can be run asynchronously if
CheckpointManagerOptions.enable_background_delete is set to True.
Args:
step: The step to delete.
"""
if self._options.read_only:
logging.warning('%s is read only, delete will be skipped', self.directory)
return
if step not in self.all_steps():
raise ValueError(f'Requested deleting a non-existent step: {step}.')
self._checkpoint_deleter.delete(step)
utils.sync_global_processes(
'CheckpointManager:deleted_step',
processes=self._multiprocessing_options.active_processes,
)
for i, info in enumerate(self._checkpoints):
if info.step == step:
self._checkpoints.pop(i)
def _validate_args(
    self,
    items: Optional[Union[Any, Mapping[str, Any]]],
    args: Optional[args_lib.CheckpointArgs],
):
  """Rejects invalid combinations of legacy `items` and new-style `args`.

  Raises:
    ValueError: if `items` is itself a `CheckpointArgs` (API misuse), if
      `args` is not a `CheckpointArgs`, or if the `Composite`-ness of `args`
      does not match single- vs. multi-item mode.
  """
  if isinstance(items, args_lib.CheckpointArgs):
    raise ValueError(
        'Found an instance of `CheckpointArgs` provided for `items`. This may'
        ' be due to misuse of the newer API - make sure to specify the'
        ' argument keyword (e.g. `args=args`).'
    )
  if args is None:
    return
  if not isinstance(args, args_lib.CheckpointArgs):
    raise ValueError(
        f'Expected args of type `CheckpointArgs`; found {type(args)}.'
    )
  is_composite = isinstance(args, args_lib.Composite)
  if self._single_item and is_composite:
    raise ValueError(
        'Cannot provide `args` of type `Composite` when dealing with a'
        ' single checkpointable object.'
    )
  if not self._single_item and not is_composite:
    raise ValueError(
        'Must provide `args` of type `Composite` when dealing with'
        ' multiple checkpointable objects.'
    )
def save(
    self,
    step: int,
    items: Optional[Union[Any, Mapping[str, Any]]] = None,
    save_kwargs: Optional[Union[SaveParams, Mapping[str, SaveParams]]] = None,
    metrics: Optional[PyTree] = None,
    force: Optional[bool] = False,
    args: Optional[args_lib.CheckpointArgs] = None,
) -> bool:
  """See superclass documentation.

  Saves a checkpoint for `step`, unless `should_save` declines (and `force`
  is False). Supports both the legacy `items`/`save_kwargs` API and the
  newer `args` API; exactly one of `items`/`args` must be provided.

  Returns:
    True if a checkpoint save was performed, False if it was skipped.

  Raises:
    ValueError: if neither `items` nor `args` is provided, or args are
      otherwise invalid.
    StepAlreadyExistsError: if `step` already has a checkpoint.
  """
  # Telemetry for this save; emitted through self._logger at the end.
  step_stats = step_statistics.StepStatistics()
  step_stats.step = step
  step_stats.start_time = time.time()
  step_stats.event_type = 'save'
  if items is None and args is None:
    raise ValueError('Must provide `args` for `save`.')
  self._validate_args(items, args)
  if not force and not self.should_save(step):
    return False
  if self.reached_preemption(step):
    logging.info('Saving checkpoint at step %d due to preemption.', step)
    step_stats.reached_preemption = True
    step_stats.preemption_received_at = time.time()
  # Wait for ongoing saves to complete. Only applicable if some of the
  # checkpointers are AsyncCheckpointers.
  # Must happen after `should_save` to avoid blocking callers.
  wait_for_prev_start_time = time.time()
  self.wait_until_finished()
  jax.monitoring.record_event_duration_secs(
      '/jax/checkpoint/write/wait_for_prev_duration_secs',
      time.time() - wait_for_prev_start_time,
  )
  step_stats.wait_for_prev_start_time = wait_for_prev_start_time
  step_stats.wait_for_prev_end_time = time.time()
  if step in self.all_steps():
    raise StepAlreadyExistsError(
        f'Checkpoint for step {step} already exists.'
    )
  if items is None:
    items = {}
  if save_kwargs is None:
    save_kwargs = {}
  if self._single_item:
    # Normalize the single-item API onto the multi-item (dict) code path.
    items = {DEFAULT_ITEM_NAME: items}
    save_kwargs = {DEFAULT_ITEM_NAME: save_kwargs}
  if self._track_best and metrics is None:
    logging.warning('Requested `tracked_metric`; did not provide metrics.')
  if args is None:
    # Legacy path: synthesize per-item CheckpointArgs from items/save_kwargs.
    args_dict = {}
    for key, item in items.items():
      save_ckpt_arg_cls, _ = _get_args_for_key(
          self._checkpointer.handler,
          key,
      )
      extra_args = save_kwargs[key] if key in save_kwargs else {}
      extra_args = extra_args or {}
      args_dict[key] = save_ckpt_arg_cls(item, **extra_args)  # pytype: disable=wrong-arg-count
    args = args_lib.Composite(**args_dict)
  else:
    if self._single_item:
      args = args_lib.Composite(**{DEFAULT_ITEM_NAME: args})
    else:
      if not isinstance(args, args_lib.Composite):
        raise ValueError(
            f'Expected args of type `Composite`; found {type(args)}.'
        )
      args = typing.cast(args_lib.Composite, args)
  # Metrics are persisted as an extra JSON item alongside the user's items so
  # that best-checkpoint tracking survives restarts.
  args_dict = dict(args.items())
  if metrics is not None and self._track_best:
    args_dict['metrics'] = args_lib.JsonSave(metrics)
  args = args_lib.Composite(**args_dict)
  save_directory = self._get_write_step_directory(step, self.directory)
  # If a folder for the step to save exists and is not finalized, remove the
  # existing folder.
  if step_lib.is_gcs_path(self.directory):
    if (
        utils.is_primary_host(self._multiprocessing_options.primary_host)
        and save_directory.exists()
        and utils.is_tmp_checkpoint(save_directory)
    ):
      logging.warning(
          'Attempting to save on GCS at step %s which has an unfinalized'
          ' checkpoint from previous runs. Removing the unfinalized'
          ' checkpoint before saving.',
          step,
      )
      # make sure to use a synchronous deleter here
      deleter.create_checkpoint_deleter(
          self._multiprocessing_options.primary_host,
          self._directory,
          self._options.todelete_subdir,
          self._step_name_format,
          enable_background_delete=False,  # no background thread
      ).delete(step)
    # All processes sync after the (primary-host-only) cleanup.
    utils.sync_global_processes(
        'CheckpointManager:delete_unfinalized_step_gcs',
        processes=self._multiprocessing_options.active_processes,
    )
  logging.info('Saving checkpoint at step %d', step)
  step_stats.checkpoint_start_time = time.time()
  self._checkpointer.save(save_directory, args=args)
  step_stats.checkpoint_end_time = time.time()
  self._add_checkpoint_info(step, metrics)
  get_old_steps_start_time = time.time()
  steps_to_remove = self._get_old_steps_to_remove()
  jax.monitoring.record_event_duration_secs(
      '/jax/checkpoint/write/get_old_steps_duration_secs',
      time.time() - get_old_steps_start_time,
  )
  step_stats.get_old_steps_start_time = get_old_steps_start_time
  step_stats.get_old_steps_end_time = time.time()
  # Drop the to-be-removed steps from bookkeeping now; actual directory
  # deletion happens later in `_finalize`.
  self._checkpoints = [
      info for info in self._checkpoints if info.step not in steps_to_remove
  ]
  # Sync needed to ensure that old steps to remove are retrieved before
  # actually deleting them during finalize, since retrieval can involve
  # looking at the directory.
  utils.sync_global_processes(
      'CheckpointManager:old_steps_to_remove',
      processes=self._multiprocessing_options.active_processes,
  )
  assert self._finalize_thread is None
  if is_async_checkpointer(self._checkpointer):
    # Finalization (commit wait + old-step deletion) runs in the background.
    logging.info('Beginning async checkpoint finalize.')
    t = _FinalizeThread(
        target=self._finalize, args=(save_directory, steps_to_remove)
    )
    t.start()
    self._finalize_thread = t
  else:
    self._finalize(save_directory, steps_to_remove)
    logging.info('Finished synchronous save.')
    utils.sync_global_processes(
        'CheckpointManager:finalize',
        processes=self._multiprocessing_options.active_processes,
    )
  # NOTE(review): `synchronous` is assigned the *async* predicate (True when
  # the checkpointer is async) — this looks inverted; confirm intended
  # semantics of StepStatistics.synchronous.
  step_stats.synchronous = is_async_checkpointer(self._checkpointer)
  step_stats.end_time = time.time()
  if self._logger is not None:
    self._logger.log_entry(dataclasses.asdict(step_stats))
  return True
def restore(
    self,
    step: int,
    items: Optional[Union[Any, Mapping[str, Any]]] = None,
    restore_kwargs: Optional[
        Union[RestoreParams, Mapping[str, RestoreParams]]
    ] = None,
    directory: Optional[epath.PathLike] = None,
    args: Optional[args_lib.CheckpointArgs] = None,
) -> Union[Any, Mapping[str, Any]]:
  """See superclass documentation."""
  directory = epath.Path(directory or self.directory)
  self._validate_args(items, args)
  # Normalize the legacy `items`/`restore_kwargs` arguments to dict form.
  if items is None:
    items = {}
  elif self._single_item:
    items = {DEFAULT_ITEM_NAME: items}
  if restore_kwargs is None:
    restore_kwargs = {}
  elif self._single_item:
    restore_kwargs = {DEFAULT_ITEM_NAME: restore_kwargs}
  if args is None:
    # Legacy path: build per-item restore args from items/restore_kwargs.
    args_dict = {}
    for key in set(items) | set(restore_kwargs):
      _, restore_ckpt_arg_cls = _get_args_for_key(
          self._checkpointer.handler,
          key,
      )
      extra = restore_kwargs.get(key) or {}
      args_dict[key] = restore_ckpt_arg_cls(items.get(key), **extra)  # pytype: disable=wrong-arg-count
    args = args_lib.Composite(**args_dict)
  elif self._single_item:
    args = args_lib.Composite(**{DEFAULT_ITEM_NAME: args})
  else:
    args = typing.cast(args_lib.Composite, args)
  read_dir = self._get_read_step_directory(step, directory)
  restored = self._checkpointer.restore(read_dir, args=args)
  # Unwrap the synthetic single-item key for single-item managers.
  return restored[DEFAULT_ITEM_NAME] if self._single_item else restored
def metrics(self, step: int) -> Optional[PyTree]:
  """Restores the saved metrics for `step`; None if untracked or missing."""
  if not self._track_best:
    return None
  try:
    restored = self._checkpointer.restore(
        self._get_read_step_directory(step, self.directory),
        args=args_lib.Composite(
            **{METRIC_ITEM_NAME: args_lib.JsonRestore()}
        ),
    )
  except FileNotFoundError:
    # Best-effort: a step may legitimately have been saved without metrics.
    logging.warning('Missing metrics for step %d', step)
    return None
  return restored[METRIC_ITEM_NAME]
@property
def _track_best(self):
"""Returns true if we should track the best checkpoints by given metric."""
return self._options.best_fn is not None
def _load_checkpoint_infos(self) -> List[CheckpointInfo]:
  """Loads a list of CheckpointInfo for existing checkpoints.

  If none are present, returns empty list.

  Returns:
    a list of CheckpointInfo, sorted by increasing step.
  """
  start = time.time()
  steps = utils.checkpoint_steps(
      self.directory, self._options.single_host_load_and_broadcast
  )
  steps.sort()  # In-place; steps drive the output ordering below.

  def _info_for_step(step: int) -> CheckpointInfo:
    # One directory lookup + metrics restore per step; parallelized below.
    metadata = self._step_name_format.find_step(self.directory, step)
    return CheckpointInfo(
        step=step,
        time=metadata.commit_timestamp,
        metrics=self.metrics(step),
    )

  if steps:
    # `Executor.map` preserves input order and re-raises worker exceptions,
    # matching the futures-dict approach it replaces.
    with concurrent.futures.ThreadPoolExecutor() as pool:
      checkpoint_infos = list(pool.map(_info_for_step, steps))
  else:
    checkpoint_infos = []
  jax.monitoring.record_event_duration_secs(
      '/jax/checkpoint/read/load_all_step_metadata_duration_secs',
      time.time() - start,
  )
  logging.info(
      'Found %d checkpoint steps in %s', len(checkpoint_infos), self.directory
  )
  return checkpoint_infos
def _get_interval_preserved_checkpoints(
self, checkpoints: List[CheckpointInfo]
) -> List[CheckpointInfo]:
"""Gets which checkpoints should be kept based on keep_time_interval."""
if not checkpoints:
return []
interval_preserved_checkpoints = [checkpoints[0]]
if self._options.keep_time_interval is not None:
for info in checkpoints[1:]:
if (
info.time
>= interval_preserved_checkpoints[-1].time
+ self._options.keep_time_interval
):
interval_preserved_checkpoints.append(info)
return interval_preserved_checkpoints
def _add_checkpoint_info(self, step: int, metrics: Optional[PyTree]):
  """Records a freshly written step in the in-memory checkpoint list."""
  # Timestamps are always timezone-aware UTC, so interval comparisons are
  # well-defined across hosts.
  now = datetime.datetime.now(tz=datetime.timezone.utc)
  self._checkpoints.append(CheckpointInfo(step, now, metrics))
def _metadata_path(self) -> epath.Path:
  """Path of the manager-level metadata item inside the root directory."""
  return self.directory / METADATA_ITEM_NAME
def _save_metadata(self, metadata: Mapping[str, Any]):
  """Saves CheckpointManager level metadata, skips if already present."""
  path = self._metadata_path()
  if path.exists():
    # A previous run already wrote it; metadata is write-once.
    return
  self._metadata_checkpointer.save(path, metadata)
def _sort_checkpoints_by_metrics(
self, checkpoints: List[CheckpointInfo]
) -> Tuple[List[CheckpointInfo], List[CheckpointInfo]]:
"""Sorts `checkpoints` in order of increasing metric quality.
Checkpoints without corresponding metrics set will be at the beginning.
Args:
checkpoints: a list of CheckpointInfo.
Returns:
Tuple of CheckpointInfo lists:
(checkpoints_without_metrics, checkpoints_sorted_by_metrics)
"""
without_metrics = [info for info in checkpoints if info.metrics is None]
with_metrics = [info for info in checkpoints if info.metrics is not None]
return without_metrics, sorted(
with_metrics,
key=lambda info: self._options.best_fn(info.metrics),
reverse=(self._options.best_mode == 'min'),
)
def _cleanup_tmp_directories(self):
  """Removes leftover temporary (unfinalized) step directories."""
  mp_options = self._multiprocessing_options
  utils.cleanup_tmp_directories(
      self.directory,
      primary_host=mp_options.primary_host,
      active_processes=mp_options.active_processes,
  )
def _get_old_steps_to_remove(self) -> List[int]:
  """Returns checkpoints that should be deleted.

  Applies the retention policy: `max_to_keep` bounds the number of active
  checkpoints, while locked checkpoints, `keep_time_interval`, and
  `keep_period` each exempt candidates from deletion. Also refreshes each
  tracked CheckpointInfo's `is_locked` flag as a side effect.
  """
  # Must have set max_to_keep in order to remove any checkpoints.
  if self._options.max_to_keep is None:
    return []
  # Not enough checkpoints accumulated to consider deletion.
  if len(self._checkpoints) <= self._options.max_to_keep:
    return []
  # This isn't a duration but there isn't a general counter that we can use so
  # we abuse a duration metric to count the number of steps examined.
  jax.monitoring.record_event_duration_secs(
      '/jax/checkpoint/write/old_steps_examined_count',
      len(self._checkpoints),
  )
  # Exclude the latest checkpoint, since it is not finalized.
  are_locked = utils.are_locked(
      self.directory,
      steps=tuple([info.step for info in self._checkpoints[:-1]]),
      step_name_format=self._step_name_format,
  )
  # Refresh the lock flag on every finalized CheckpointInfo in place.
  self._checkpoints[:-1] = [
      dataclasses.replace(info, is_locked=is_locked)
      for info, is_locked in zip(self._checkpoints[:-1], are_locked)
  ]
  if self._track_best:
    # Best steps (to keep) are at the end, after sorting.
    (
        checkpoints_without_metrics,
        sorted_checkpoints,
    ) = self._sort_checkpoints_by_metrics(self._checkpoints)
  else:
    # checkpoints already sorted by ascending step
    checkpoints_without_metrics = []
    sorted_checkpoints = self._checkpoints
  keep = int(self._options.max_to_keep)
  if self._options.keep_checkpoints_without_metrics:
    # Metric-less checkpoints never compete for the `keep` slots and are
    # always considered active.
    maybe_delete = (
        sorted_checkpoints[:-keep] if keep > 0 else sorted_checkpoints
    )
    active_checkpoints = set(
        checkpoints_without_metrics
        + (sorted_checkpoints[-keep:] if keep > 0 else [])
    )
  else:
    # Metric-less checkpoints are placed first so they are deleted first.
    all_checkpoints = checkpoints_without_metrics + sorted_checkpoints
    maybe_delete = all_checkpoints[:-keep] if keep > 0 else sorted_checkpoints
    active_checkpoints = set(all_checkpoints[-keep:] if keep > 0 else [])
  interval_preserved_checkpoints = self._get_interval_preserved_checkpoints(
      self._checkpoints
  )
  # Walk the deletion candidates; any candidate matching a preservation rule
  # is moved to `kept_checkpoints`, the rest fall through and get deleted.
  kept_checkpoints = set()
  for info in maybe_delete:
    if info.is_locked:
      logging.info(
          'Preserving %s: (Reason: checkpoint is locked).',
          info,
      )
      kept_checkpoints.add(info)
      continue
    if (
        self._options.keep_time_interval is not None
        and interval_preserved_checkpoints
    ):
      if info in interval_preserved_checkpoints:
        logging.info(
            'Preserving %s: (Reason: older falling on keep_time_interval).',
            info,
        )
        kept_checkpoints.add(info)
        continue
      elif info.time >= (
          interval_preserved_checkpoints[-1].time
          + self._options.keep_time_interval
      ):
        # Extend the interval chain so later candidates compare against the
        # newest preserved timestamp.
        interval_preserved_checkpoints.append(info)
        logging.info(
            'Preserving %s: (Reason: latest falling on keep_time_interval).',
            info,
        )
        kept_checkpoints.add(info)
        continue
    if (
        self._options.keep_period is not None
        and info.step % self._options.keep_period == 0
    ):
      logging.info(
          'Preserving %s: (Reason: on keep_period=%s).',
          info,
          self._options.keep_period,
      )
      kept_checkpoints.add(info)
      continue
  # Active (most recent / best) checkpoints are always kept.
  kept_checkpoints.update(active_checkpoints)
  steps_to_remove = []
  for info in self._checkpoints:
    if info not in kept_checkpoints:
      reason = 'worse metric' if self._track_best else 'old checkpoint'
      logging.info('Deleting %s: (Reason: %s).', info, reason)
      steps_to_remove.append(info.step)
  return steps_to_remove
def _wait_for_checkpointers(self):
  """Blocks until any in-flight async checkpoint write has committed."""
  checkpointer = self._checkpointer
  if is_async_checkpointer(checkpointer):
    checkpointer.wait_until_finished()  # pytype: disable=attribute-error
def wait_until_finished(self):
  """See superclass documentation.

  Joins the background finalize thread (if one is running), propagating any
  exception it raised, then synchronizes all hosts.
  """
  t = self._finalize_thread
  if t is not None:
    # Clear the handle before joining so the manager is ready for the next
    # save regardless of the join outcome.
    self._finalize_thread = None
    try:
      t.join()
    except BaseException as e:  # pylint:disable=broad-exception-caught
      # If an exception occurred in the finalization of the previous
      # save, we clean up since that checkpoint was never actually saved.
      assert self._checkpoints
      self._checkpoints = self._checkpoints[:-1]
      raise e
    # Additional work is being done on process 0 of the finalize threads.
    # When joining the threads, we must wait for all threads to complete
    # before proceeding.
    utils.sync_global_processes(
        'CheckpointManager:join_finalize_thread',
        processes=self._multiprocessing_options.active_processes,
    )
def check_for_errors(self):
  """See superclass documentation."""
  checkpointer = self._checkpointer
  # Only async checkpointers accumulate background errors to surface.
  if is_async_checkpointer(checkpointer):
    checkpointer.check_for_errors()  # pytype: disable=attribute-error
def _finalize_checkpoint(self, step: int):
  """Executes final actions just before the checkpoint write completes.

  * Logs error if any.
  * Records duration saved due to preemption if any.

  Only runs on the primary host; other hosts return immediately.

  Args:
    step: finalized checkpoint step.
  """
  if not utils.is_primary_host(self._multiprocessing_options.primary_host):
    return
  try:
    self.check_for_errors()
  except Exception as e:  # pylint: disable=broad-except
    logging.error(
        (
            'Received error: %s from Checkpointer. One or more items may'
            ' not be finalized. Skipping finalization of step checkpoint.'
        ),
        e,
    )
    return None
  if not self.reached_preemption(step):
    return
  # If at a preemption step, record the time since the previous checkpoint.
  # This represents training time that would otherwise have been wasted.
  # If another checkpoint has not been previously saved, measures the time
  # since program start.
  assert self._checkpoints
  if len(self._checkpoints) > 1:
    previous_time = self._checkpoints[-2].time
  else:
    previous_time = _INIT_TIME
  saved = self._checkpoints[-1].time - previous_time
  jax.monitoring.record_event_duration_secs(
      '/jax/checkpoint/write/preempt/duration_saved_secs',
      saved.total_seconds(),
  )
def _finalize(self, directory: epath.Path, steps_to_remove: List[int]):
  """Cleans up old checkpoints and synchronizes hosts."""
  self._wait_for_checkpointers()
  # Past this point all commit futures have completed without error.
  step = utils.step_from_checkpoint_name(directory.name)
  self._finalize_checkpoint(step)
  delete_start = time.time()
  self._checkpoint_deleter.delete_steps(steps_to_remove)
  jax.monitoring.record_event_duration_secs(
      '/jax/checkpoint/write/remove_steps_duration_secs',
      time.time() - delete_start,
  )
def close(self):
  """See superclass documentation."""
  # Drain any pending async save before tearing down resources.
  self.wait_until_finished()
  self._checkpointer.close()
  self._checkpoint_deleter.close()
def __contextmanager__(
    self,
) -> Iterable[Self]:
  """Yields the manager and guarantees `close()` on exit.

  NOTE(review): this is a bare generator method — presumably wrapped by
  `contextlib.contextmanager` (or used via a helper) elsewhere, since by
  itself it does not implement the context-manager protocol; confirm.
  """
  try:
    yield self
  finally:
    self.close()