# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Orbax step storage entities.
Please reserve this file for entities related to `NameFormat`. Other existing
functions should be moved out over time.
"""
import abc
from collections.abc import Iterable
from concurrent import futures
import dataclasses
import datetime
import functools
import os
import re
import time
from typing import Callable, Generic, Iterator, List, Optional, Protocol, Sequence, TypeVar
from absl import logging
from etils import epath
import numpy as np
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.path import temporary_paths
# Allowed checkpoint step naming using any non empty `step_prefix`, i.e.
# `<step_prefix>_<digits>`. Group 1 captures the prefix and group 2 captures
# the step digits.
ALLOWED_STEP_NAME_PATTERN = r'(.+)_(\d+)'
# Re-exported from the `temporary_paths` module.
TMP_DIR_SUFFIX = temporary_paths.TMP_DIR_SUFFIX
# Pattern for temporary checkpoint directory names; group 1 captures the step
# digits. Example names:
# prefix_1000.orbax-checkpoint-tmp-1010101
# OR
# 1000.orbax-checkpoint-tmp-1010101
TMP_DIR_STEP_PATTERN = r'.*?_*?(\d+)\.orbax-checkpoint-tmp'
# Type variable bound to `Metadata` (given as a forward reference).
MetadataT = TypeVar('MetadataT', bound='Metadata')
def is_path_finalized(*args, **kwargs):
  """Blocking wrapper around `temporary_paths.is_path_finalized`.

  All arguments are forwarded unchanged; the object returned by
  `temporary_paths.is_path_finalized` is handed to `asyncio_utils.run_sync` and
  that call's result is returned.
  """
  pending = temporary_paths.is_path_finalized(*args, **kwargs)
  return asyncio_utils.run_sync(pending)
def is_path_temporary(*args, **kwargs):
  """Blocking wrapper around `temporary_paths.is_path_temporary`.

  All arguments are forwarded unchanged; the object returned by
  `temporary_paths.is_path_temporary` is handed to `asyncio_utils.run_sync` and
  that call's result is returned.
  """
  pending = temporary_paths.is_path_temporary(*args, **kwargs)
  return asyncio_utils.run_sync(pending)
def all_temporary_paths(*args, **kwargs):
  """Blocking wrapper around `temporary_paths.all_temporary_paths`.

  All arguments are forwarded unchanged; the object returned by
  `temporary_paths.all_temporary_paths` is handed to `asyncio_utils.run_sync`
  and that call's result is returned.
  """
  pending = temporary_paths.all_temporary_paths(*args, **kwargs)
  return asyncio_utils.run_sync(pending)
# Deprecated aliases, use the above functions, or use the temporary_paths module
# directly instead.
is_checkpoint_finalized = is_path_finalized
is_tmp_checkpoint = is_path_temporary


def tmp_checkpoints(*args, **kwargs):
  """Returns the names of all temporary paths (deprecated helper).

  Arguments are forwarded to `all_temporary_paths`; for each returned entry the
  `name` of its `get()` path is collected.
  """
  names = []
  for temporary_path in all_temporary_paths(*args, **kwargs):
    names.append(temporary_path.get().name)
  return names
def _is_valid_base_path(base_path: epath.PathLike) -> bool:
  """Returns True if `base_path` exists and is a directory, else False."""
  directory = epath.Path(base_path)
  if not directory.exists():
    return False
  return directory.is_dir()
[docs]
def build_step_path(
    base_path: epath.PathLike, name_format: NameFormat[Metadata], step: int
) -> epath.Path:
  """Builds the path of `step` under `base_path`, named per `name_format`.

  Args:
    base_path: directory under which step directories live.
    name_format: NameFormat whose `build_name` produces the step's name.
    step: step number.

  Returns:
    `base_path` joined with the name built for `step`. No file system lookup is
    performed here.
  """
  root = epath.Path(base_path)
  step_name = name_format.build_name(step)
  return root / step_name
[docs]
def step_prefix_with_underscore(step_prefix: Optional[str]) -> str:
  """Returns `step_prefix` followed by an underscore, or '' if it is None.

  Note that an empty (but not None) prefix still gets the underscore appended.
  """
  if step_prefix is None:
    return ''
  return f'{step_prefix}_'
# TODO(b/337858698): Works with CompositeNameFormat.write_name_format only. Also
# support read_name_formats.
[docs]
def find_step_path(
    base_path: epath.PathLike,
    name_format: NameFormat[Metadata],
    *,
    step: int,
    include_uncommitted: bool = False,
) -> epath.Path:
  """Returns `step` path under `base_path` for step `name_format`.

  NOTE: Experimental function, subject to change.

  Args:
    base_path: directory path containing step subdirs.
    name_format: NameFormat of the target `step`.
    step: target step number.
    include_uncommitted: if True then uncommitted steps are considered in search
      too, otherwise only committed steps are looked up.

  Returns:
    If `include_uncommitted` is True and a temporary path whose final path
    equals the step path exists, that temporary path. Otherwise the path of the
    metadata returned by `name_format.find_step`.

  Raises:
    ValueError if the target step path does not exist.
  """
  base_path = epath.Path(base_path)
  if not include_uncommitted:
    return name_format.find_step(base_path, step).path

  # First try finding uncommitted step. The final path being searched for does
  # not depend on the temporary path, so it is built once outside the loop.
  final_step_path = build_step_path(base_path, name_format, step)
  uncommitted_step_path = None
  for tmp_path in all_temporary_paths(base_path):
    if tmp_path.get_final() == final_step_path:
      uncommitted_step_path = tmp_path.get()
      break
  if uncommitted_step_path and uncommitted_step_path.exists():
    return uncommitted_step_path

  # Uncommitted step not found, return committed one or raise error.
  return name_format.find_step(base_path, step).path
def maybe_find_step_metadata(
    base_path: epath.PathLike,
    name_format: NameFormat[Metadata],
    *,
    step: int,
) -> Metadata | None:
  """Looks up `step` with `name_format`, returning None instead of raising.

  Args:
    base_path: directory path containing step subdirs.
    name_format: NameFormat used to find the step.
    step: target step number.

  Returns:
    The result of `name_format.find_step`, or None if that call raises
    ValueError.
  """
  try:
    metadata = name_format.find_step(base_path, step)
  except ValueError:
    return None
  return metadata
@dataclasses.dataclass(frozen=True)
class _StandardNameFormat(NameFormat[Metadata]):
  """NameFormat for 'standard' steps for common Orbax use cases.

  NOTE: Ignores uncommitted checkpoints.

  Naming examples:
   * step_prefix=None    step_format_fixed_length=None  ->  23
   * step_prefix=None    step_format_fixed_length=4     ->  0023
   * step_prefix=step    step_format_fixed_length=None  ->  step_23
   * step_prefix=step    step_format_fixed_length=4     ->  step_0023

  Attributes:
    step_prefix: Optional fixed string prefixed to step. Note an *underscore* is
      appended before applying it.
    step_format_fixed_length: Optional length of the zero padded step. e.g. 6
      for 000123.
    single_host_load_and_broadcast: If True, the jax process=0 will list all
      steps and broadcast them to all other processes. NOTE: Ignored if jax
      backend is not multi controller.
  """

  step_prefix: Optional[str] = None
  step_format_fixed_length: Optional[int] = None
  single_host_load_and_broadcast: bool = False

  def __str__(self):
    # Uses a sample step so the rendered name shows prefix and padding.
    return f'StandardNameFormat("{self.build_name(1234)}")'

  def build_name(self, step: int) -> str:
    """Returns `(prefix_)?(zero padding)?step` name."""
    if self.step_format_fixed_length is not None:
      step_str = f'{step:0{self.step_format_fixed_length}d}'
    else:
      step_str = f'{step}'

    # [prefix]step
    return f'{step_prefix_with_underscore(self.step_prefix)}{step_str}'

  def _build_metadata(
      self, step_path: epath.Path, step: Optional[int] = None
  ) -> Optional[Metadata]:
    """Returns metadata for given `step_path` if it is valid or None.

    Args:
      step_path: Candidate step path.
      step: Step number if already known by the caller. When provided and
        `step_path` exists, the name is not parsed.

    Returns:
      `Metadata` for the step, or None if `step_path` is not finalized or its
      name does not match this name format.
    """
    if not is_path_finalized(step_path):
      return None

    if step is not None:
      # step already known, just check exists.
      if step_path.exists():
        return Metadata(step=step, path=step_path)

    # Regex: [prefix]*(step)
    if self.step_format_fixed_length and self.step_format_fixed_length > 0:
      zero_present = rf'0\d{{{self.step_format_fixed_length-1}}}'
      zero_not_present = rf'[1-9]\d{{{self.step_format_fixed_length-1}}}\d*'
      zero_padded_step_group = rf'({zero_present}|{zero_not_present})'
    else:
      zero_padded_step_group = r'(0|[1-9]\d*)'
    # `build_name` emits the prefix verbatim, so it must be matched as literal
    # text. Without escaping, a prefix containing regex metacharacters (e.g.
    # '+', '(' or '.') would either fail to match names produced by
    # `build_name` or add capture groups and break the unpacking below.
    prefix_regex = re.escape(step_prefix_with_underscore(self.step_prefix))
    name_regex = f'^{prefix_regex}{zero_padded_step_group}$'

    match = re.search(name_regex, step_path.name)
    if match is None:
      return None
    (step_,) = match.groups()
    step_ = int(step_)

    return Metadata(step=step_, path=step_path)

  def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]:
    """Returns step paths under `base_path`.

    The result is only filtered by name prefix; callers validate each path
    (see `_build_metadata`).

    Args:
      base_path: Root path under which step folders are placed.
    """
    base_path = epath.Path(base_path)
    # <step_prefix>_?<0 padding>?*
    if gcs_utils.is_hierarchical_namespace_enabled(base_path):
      logging.vlog(
          1,
          'HNS enabled. Using GCS API to list step paths at %s',
          base_path.as_posix(),
      )
      bucket_name, path_prefix = gcs_utils.parse_gcs_path(base_path)
      bucket = gcs_utils.get_bucket(bucket_name)
      result = bucket.list_blobs(
          prefix=path_prefix,
          delimiter='/',
          include_folders_as_prefixes=True,
      )
      # Iterate over pages to force a fetch from the server, after which
      # `result.prefixes` will be populated.
      for _ in result.pages:
        pass
      return [
          epath.Path(f'gs://{bucket_name}/{folder}')
          for folder in result.prefixes
          if folder.startswith(
              os.path.join(path_prefix, self.step_prefix or '')
          )
      ]
    else:
      prefix = step_prefix_with_underscore(self.step_prefix)
      return [x for x in base_path.iterdir() if x.name.startswith(prefix)]

  def _get_step_paths_and_total_steps(
      self, base_path: epath.PathLike, is_primary_host: bool
  ) -> tuple[list[epath.Path], int]:
    """Broadcasts total steps and get step paths from the primary host.

    Args:
      base_path: Root path under which step folders are placed.
      is_primary_host: Whether the current process lists the step paths and is
        the source of the broadcast.

    Returns:
      A tuple of (step_paths, total_steps). `step_paths` is only populated on
      the primary host and is empty elsewhere; `total_steps` is the broadcast
      count of listed paths.
    """
    step_paths = []
    total_steps = -1  # Placeholder on non-primary hosts until the broadcast.
    if is_primary_host:
      # <step_prefix>_?<0 padding>?*
      step_paths = self._glob_step_paths(base_path)
      logging.info(
          '[process=%s][single_host_load_and_broadcast] step_paths=%s,'
          ' root_dir=%s',
          multihost.process_index(),
          step_paths,
          base_path,
      )
      total_steps = len(step_paths)
    total_steps = int(
        multihost.broadcast_one_to_all(total_steps, is_source=is_primary_host)
    )
    return step_paths, total_steps

  def _get_broadcasted_step_list(
      self,
      step_paths: list[epath.Path],
      total_steps: int,
      is_primary_host: bool,
  ) -> np.ndarray:
    """Builds the step list on the primary host and broadcasts it.

    Args:
      step_paths: Candidate step paths; only read on the primary host.
      total_steps: Length of the broadcast array.
      is_primary_host: Whether the current process validates the paths and is
        the source of the broadcast.

    Returns:
      The broadcast array of length `total_steps`: sorted valid steps first,
      remaining slots filled with -1.
    """
    # Fixed-size buffer; -1 marks unused slots.
    padded_step_list = np.array([-1] * total_steps)
    if is_primary_host:
      steps = []
      with futures.ThreadPoolExecutor() as executor:
        metadata_futures = [
            executor.submit(self._build_metadata, step_path)  # File IO
            for step_path in step_paths
        ]
        for future in futures.as_completed(metadata_futures):
          metadata = future.result()
          if metadata is not None:
            steps.append(metadata.step)
      # Completion order is arbitrary, so sort before broadcasting.
      steps.sort()
      steps = np.array(steps)
      assert (
          len(steps) <= total_steps
      ), f'len(steps)={len(steps)} > total_steps={total_steps}'
      padded_step_list[0 : len(steps)] = steps
    return multihost.broadcast_one_to_all(
        padded_step_list, is_source=is_primary_host
    )

  def _find_all_with_single_host_load_and_broadcast(
      self, base_path: epath.PathLike
  ) -> Iterator[Metadata]:
    """Returns metadata by validating and broadcasting from process_index=0.

    Step paths are listed, validated and parsed into step numbers only on
    process_index=0 (this involves multiple file I/O operations and arguably
    improves performance by reducing QPS to underlying storages like GCS). The
    step numbers are then broadcast to all hosts. In the end, all hosts build
    the metadata from step numbers and return.

    Args:
      base_path: Root path under which step folders are placed.

    Returns:
      The result of `build_step_metadatas` over the broadcast steps, or an
      empty iterator if no step paths were listed.
    """
    process_index = multihost.process_index()
    is_primary_host = process_index == 0
    time_start = time.time()
    step_paths, total_steps = self._get_step_paths_and_total_steps(
        base_path, is_primary_host
    )
    logging.info(
        '[process=%s][single_host_load_and_broadcast] total_steps=%s,'
        ' step_paths=%s, root_dir=%s',
        process_index,
        total_steps,
        step_paths,
        base_path,
    )
    if total_steps <= 0:
      logging.info(
          '[process=%s][single_host_load_and_broadcast] No steps found,'
          ' root_dir=%s',
          process_index,
          base_path,
      )
      return iter([])
    time_glob_list = time.time()

    padded_step_list = self._get_broadcasted_step_list(
        step_paths, total_steps, is_primary_host
    )

    base_path = epath.Path(base_path)
    # Negative entries are padding. Elements of the broadcast array are array
    # scalars, so convert to Python `int` to keep `Metadata.step` the same type
    # as on the non-broadcast path.
    paths_to_step_dict: dict[epath.Path, int] = {
        base_path / self.build_name(int(step)): int(step)
        for step in padded_step_list
        if step >= 0
    }
    metadatas = build_step_metadatas(
        paths_to_step_dict.keys(),
        lambda step_path: Metadata(
            step=paths_to_step_dict[step_path], path=step_path
        ),
    )
    time_build_metadata = time.time()
    logging.info(
        '[process=%s][single_host_load_and_broadcast] Found #steps=%s,'
        ' root_dir=%s, time_taken_to_glob=%ss,'
        ' time_taken_to_build_metadata=%ss, total_time=%ss',
        process_index,
        len(paths_to_step_dict),
        base_path,
        time_glob_list - time_start,
        time_build_metadata - time_glob_list,
        time_build_metadata - time_start,
    )
    return metadatas

  def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
    """Returns metadata of all steps matching with name_format attributes.

    Args:
      base_path: Root path under which step folders are placed. If it does not
        exist or is not a directory, an empty iterator is returned.
    """
    if not _is_valid_base_path(base_path):
      return iter([])
    # Note: the order of conjuncts is important here; we should not call
    # `multihost.process_count()` when `single_host_load_and_broadcast` is False
    # as this has the possible side effect of initializing the jax backend. See
    # b/454565916 for details.
    if self.single_host_load_and_broadcast and multihost.process_count() > 1:
      return self._find_all_with_single_host_load_and_broadcast(base_path)

    # <step_prefix>_?<0 padding>?*
    step_paths = self._glob_step_paths(base_path)
    return build_step_metadatas(step_paths, self._build_metadata)

  def find_step(self, base_path: epath.PathLike, step: int) -> Metadata:
    """Returns the metadata for `step` or raises ValueError.

    Args:
      base_path: Root path under which step folders are placed.
      step: Target step number.

    Raises:
      ValueError: If `base_path` does not exist or is not a directory, or if no
        valid step path is found for `step`.
    """
    if not _is_valid_base_path(base_path):
      raise ValueError(
          f'Invalid base_path: {base_path} does not exist or is not a'
          ' directory.'
      )
    step_path = build_step_path(base_path, self, step)
    metadata = self._build_metadata(step_path, step=step)
    if metadata is not None:
      return metadata

    # Raise detailed error message.
    raise ValueError(
        f'No step path found with name={self.build_name(step)},'
        f' NameFormat={self} for step={step} under {base_path}.'
    )
@dataclasses.dataclass(frozen=True)
class _CompositeNameFormat(NameFormat[Metadata]):
  """Reads steps in any of several name formats, writes in exactly one.

  Attributes:
    write_name_format: NameFormat used to build step names meant for writing
      checkpoints. Must be present in `read_name_formats` at a preferred
      priority position.
    read_name_formats: Sequence (ordered) of NameFormats used to find steps for
      reading checkpoints. It acts like an *or*, where the first one to match is
      returned.
  """

  write_name_format: NameFormat[Metadata]
  read_name_formats: Sequence[NameFormat[Metadata]]

  def __post_init__(self):
    # The write format must also be readable.
    if self.write_name_format in self.read_name_formats:
      return
    formats_str = ','.join(map(str, self.read_name_formats))
    raise ValueError(
        f'write_name_format: {self.write_name_format} must be present in'
        f' read_name_formats: [{formats_str}].'
    )

  def __str__(self):
    formats_str = ','.join(map(str, self.read_name_formats))
    return f'Composite([{formats_str}])'

  def build_name(self, step: int) -> str:
    """Returns the name of `step` as built by `write_name_format`."""
    writer = self.write_name_format
    return writer.build_name(step)

  def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
    """Yields metadata of all steps found by any of `read_name_formats`.

    Formats are queried in order; a path already yielded for an earlier format
    is skipped.
    """
    seen_paths = set()
    for name_format in self.read_name_formats:
      for metadata in name_format.find_all(base_path):
        step_path = metadata.path
        if step_path in seen_paths:
          continue
        seen_paths.add(step_path)
        yield metadata

  def find_step(self, base_path: epath.PathLike, step: int) -> Metadata:
    """Returns the metadata for `step` or raises ValueError.

    Each of `read_name_formats` is tried in order and the first successful
    result is returned. If all fail, a ValueError is raised whose message joins
    the individual error messages, one per line.
    """
    failures = []  # One entry per read format that failed.
    for name_format in self.read_name_formats:
      try:
        return name_format.find_step(base_path, step)
      except Exception as error:  # pylint: disable=broad-exception-caught
        logging.info(
            'Failed to find step=%s with NameFormat=%s under %s. Error: %s',
            step,
            name_format,
            base_path,
            error,
        )
        failures.append(error)

    # Nothing matched: report every collected error.
    raise ValueError('\n'.join(f'{error}' for error in failures))
# TODO(b/337137764) Can't move it to path/utils due to cyclic dependency.
# Explore other options.
[docs]
def get_save_directory(
    step: int,
    directory: epath.PathLike,
    name: Optional[str] = None,
    step_prefix: Optional[str] = None,
    override_directory: Optional[epath.PathLike] = None,
    step_format_fixed_length: Optional[int] = None,
    step_name_format: Optional[NameFormat[Metadata]] = None,
) -> epath.Path:
  """Computes the standardized save directory for a single item.

  Args:
    step: Step number.
    directory: Top level checkpoint directory.
    name: Item name ('params', 'state', 'dataset', etc.). When given, it is
      appended as the last path component.
    step_prefix: Prefix applied to `step` (e.g. 'checkpoint').
    override_directory: If provided, step, directory, and step_prefix are
      ignored.
    step_format_fixed_length: Uses a fixed number of digits with leading zeros
      to represent the step number. If None, there are no leading zeros.
    step_name_format: NameFormat used to define step name for step and under
      given root directory. If provided, `step_prefix` and
      `step_format_fixed_length` are ignored.

  Returns:
    A directory.

  Raises:
    ValueError: If `directory` is None.
  """
  if directory is None:
    raise ValueError('Directory cannot be None.')
  directory = epath.Path(directory)

  if override_directory is None:
    # Fall back to a standard name format when none was supplied.
    name_format = step_name_format or standard_name_format(
        step_prefix=step_prefix,
        step_format_fixed_length=step_format_fixed_length,
    )
    save_dir = build_step_path(directory, name_format, step)
  else:
    save_dir = epath.Path(override_directory)

  if name is None:
    return save_dir
  return save_dir / name
def _is_legacy_step_checkpoint(path: epath.Path) -> bool:
  """Checks whether `path` looks like an Orbax step directory.

  Note that this is not foolproof, and users should not add extra files to the
  checkpoint directory beyond what is done by CheckpointManager.

  Args:
    path: path to check.

  Returns:
    True if `path` is a directory whose name is all digits, or whose last
    '_'-separated component is all digits; False otherwise.
  """
  if not path.is_dir():
    return False
  dir_name = path.name
  if dir_name.isdigit():
    return True
  last_component = dir_name.split('_')[-1]
  return last_component.isdigit()
def _is_legacy_finalized_step_checkpoint(path: epath.Path) -> bool:
  """Returns whether `path` resembles a step directory and is finalized."""
  # Cheap name/dir check first; the finalized check is only run if it passes.
  if not _is_legacy_step_checkpoint(path):
    return False
  return is_path_finalized(path)
[docs]
def step_from_checkpoint_name(name: str) -> int:
  """Extracts the step number from a checkpoint name, tmp names included.

  Args:
    name: checkpoint directory name. Tried, in order, as a pure digit string,
      as `ALLOWED_STEP_NAME_PATTERN` (full match) and as
      `TMP_DIR_STEP_PATTERN` (match from the start).

  Returns:
    The step number.

  Raises:
    ValueError: If `name` matches none of the recognized forms.
  """
  if name.isdigit():
    return int(name)

  prefixed = re.fullmatch(ALLOWED_STEP_NAME_PATTERN, name)
  if prefixed is not None:
    return int(prefixed.group(2))

  temporary = re.match(TMP_DIR_STEP_PATTERN, name)
  if temporary is not None:
    return int(temporary.group(1))

  raise ValueError(f'Unrecognized name format: {name}.')
[docs]
def checkpoint_steps_paths(
    checkpoint_dir: epath.PathLike,
) -> List[epath.Path]:
  """Lists the finalized checkpoint paths directly under `checkpoint_dir`.

  Args:
    checkpoint_dir: Checkpoint directory.

  Returns:
    The entries of `checkpoint_dir` for which
    `_is_legacy_finalized_step_checkpoint` is true, in directory listing order.
    An empty list if `checkpoint_dir` does not exist.
  """
  root = epath.Path(checkpoint_dir)
  if not root.exists():
    return []

  finalized = []
  with futures.ThreadPoolExecutor() as pool:
    # Submit every check up front so they run concurrently, then collect the
    # results in listing order.
    pending = [
        (entry, pool.submit(_is_legacy_finalized_step_checkpoint, entry))
        for entry in root.iterdir()
    ]
    for entry, check in pending:
      if check.result():
        finalized.append(entry)
  return finalized
def is_standard_name_format(name_format: NameFormat[Metadata]) -> bool:
  """Tells whether `name_format` is an instance of `_StandardNameFormat`."""
  is_standard = isinstance(name_format, _StandardNameFormat)
  return is_standard
def single_host_load_and_broadcast_name_format(
    name_format: NameFormat[Metadata],
) -> NameFormat[Metadata]:
  """Returns a copy of `name_format` with single-host load enabled.

  Args:
    name_format: A standard name format.

  Returns:
    A copy of `name_format` made with `dataclasses.replace`, with
    `single_host_load_and_broadcast=True`.

  Raises:
    ValueError: If `name_format` is not a standard name format.
  """
  if not is_standard_name_format(name_format):
    raise ValueError(
        'single_host_load_and_broadcast is only supported for standard name'
        f' formats. Got {name_format}.'
    )
  return dataclasses.replace(name_format, single_host_load_and_broadcast=True)  # pytype: disable=wrong-arg-types
[docs]
def checkpoint_steps(checkpoint_dir: epath.PathLike) -> List[int]:
  """Lists the step numbers of finalized checkpoints in `checkpoint_dir`.

  Each path returned by `checkpoint_steps_paths` is converted with
  `step_from_checkpoint_name` applied to its name.
  """
  steps = []
  for step_path in checkpoint_steps_paths(checkpoint_dir):
    steps.append(step_from_checkpoint_name(step_path.name))
  return steps
[docs]
def any_checkpoint_step(checkpoint_dir: epath.PathLike) -> Optional[int]:
  """Finds one finalized checkpoint step in the directory, if there is any.

  Stops at the first matching entry, which avoids iterating over the entire
  directory.

  Args:
    checkpoint_dir: Checkpoint directory.

  Returns:
    Any finalized checkpoint step in the directory or None.
  """
  root = epath.Path(checkpoint_dir)
  # Lazy generator: entries after the first finalized step are never checked.
  finalized_steps = (
      step_from_checkpoint_name(entry.name)
      for entry in root.iterdir()
      if _is_legacy_finalized_step_checkpoint(entry)
  )
  return next(finalized_steps, None)