# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with paths constructed from steps."""
import abc
import dataclasses
from typing import Generic, Iterator, Protocol, Sequence, TypeVar
from absl import logging
from etils import epath
from orbax.checkpoint._src.path import step as step_lib
from orbax.checkpoint.experimental.v1._src.training.metadata import types as training_metadata_types
# Short module-level alias for the metadata type used by the name formats
# defined in this file.
CheckpointMetadata = training_metadata_types.CheckpointMetadata
# Type variable whose upper bound is `CheckpointMetadata`; the bound is given
# as a string forward reference.
MetadataT = TypeVar('MetadataT', bound='CheckpointMetadata')
class _StandardNameFormat(NameFormat[CheckpointMetadata[None]]):
  """Adapter exposing the 'standard' step name format as `CheckpointMetadata`.

  Naming and lookup are delegated to the object returned by
  `step_lib.standard_name_format`. This class only re-wraps the delegate's
  results as `CheckpointMetadata` instances whose `metadata` field is `None`.
  """

  def __init__(
      self,
      step_prefix: str | None = None,
      step_format_fixed_length: int | None = None,
      single_host_load_and_broadcast: bool = False,
  ):
    """Creates the underlying standard name format.

    Args:
      step_prefix: Forwarded unchanged to `step_lib.standard_name_format`.
      step_format_fixed_length: Forwarded unchanged to
        `step_lib.standard_name_format`.
      single_host_load_and_broadcast: Forwarded unchanged to
        `step_lib.standard_name_format`.
    """
    self._delegate = step_lib.standard_name_format(
        step_prefix=step_prefix,
        step_format_fixed_length=step_format_fixed_length,
        single_host_load_and_broadcast=single_host_load_and_broadcast,
    )

  @staticmethod
  def _wrap(step_metadata) -> CheckpointMetadata[None]:
    """Copies a step record from the delegate into a `CheckpointMetadata`."""
    return CheckpointMetadata(
        step=step_metadata.step,
        path=step_metadata.path,
        metadata=None,  # This format never attaches a metadata payload.
        init_timestamp_nsecs=step_metadata.init_timestamp_nsecs,
        commit_timestamp_nsecs=step_metadata.commit_timestamp_nsecs,
    )

  def build_name(self, step: int) -> str:
    """Returns the name the delegate format builds for `step`."""
    return self._delegate.build_name(step)

  def find_all(
      self, base_path: epath.PathLike
  ) -> Iterator[CheckpointMetadata[None]]:
    """Yields one `CheckpointMetadata` per step the delegate finds.

    Args:
      base_path: Location passed through to the delegate's `find_all`.

    Yields:
      The delegate's results, each re-wrapped with `metadata=None`.
    """
    for step_metadata in self._delegate.find_all(base_path):
      yield self._wrap(step_metadata)

  def find_step(
      self, base_path: epath.PathLike, step: int
  ) -> CheckpointMetadata[None]:
    """Returns the delegate's result for `step`, re-wrapped.

    Args:
      base_path: Location passed through to the delegate's `find_step`.
      step: Step number to look up.

    Returns:
      A `CheckpointMetadata` with `metadata=None` built from the delegate's
      result. Any error raised by the delegate propagates unchanged.
    """
    return self._wrap(self._delegate.find_step(base_path, step))
@dataclasses.dataclass(frozen=True)
class _CompositeNameFormat(NameFormat[CheckpointMetadata[None]]):
  """NameFormat that reads several step namings but writes exactly one.

  Attributes:
    write_name_format: NameFormat used to build step names for writing
      checkpoints. It must also appear in `read_name_formats`; construction
      fails with ValueError otherwise.
    read_name_formats: Ordered sequence of NameFormats consulted when looking
      for steps to read. Lookup of a single step behaves like an *or*: the
      first format that succeeds provides the result.
  """

  write_name_format: NameFormat[CheckpointMetadata[None]]
  read_name_formats: Sequence[NameFormat[CheckpointMetadata[None]]]

  def _joined_read_formats(self) -> str:
    """Returns the comma-separated `str()` of every read format."""
    return ','.join(str(fmt) for fmt in self.read_name_formats)

  def __post_init__(self):
    """Validates that the write format is one of the read formats."""
    if self.write_name_format in self.read_name_formats:
      return
    joined = self._joined_read_formats()
    raise ValueError(
        f'write_name_format: {self.write_name_format} must be present in'
        f' read_name_formats: [{joined}].'
    )

  def __str__(self):
    """Returns a short description listing the read formats."""
    joined = self._joined_read_formats()
    return f'Composite([{joined}])'

  def build_name(self, step: int) -> str:
    """Returns the name of `step` as built by `write_name_format`."""
    return self.write_name_format.build_name(step)

  def find_all(
      self, base_path: epath.PathLike
  ) -> Iterator[CheckpointMetadata[None]]:
    """Yields step metadata found by every read format, without repeats.

    The read formats are queried in order. A result whose path was already
    yielded by an earlier result is skipped.

    Args:
      base_path: Location passed to each read format's `find_all`.

    Yields:
      Metadata for each distinct step path.
    """
    seen_paths = set()
    for name_format in self.read_name_formats:
      for candidate in name_format.find_all(base_path):
        path = candidate.path
        if path in seen_paths:
          continue
        seen_paths.add(path)
        yield candidate

  def find_step(
      self, base_path: epath.PathLike, step: int
  ) -> CheckpointMetadata[None]:
    """Returns metadata for `step` from the first read format that finds it.

    Args:
      base_path: Location passed to each read format's `find_step`.
      step: Step number to look up.

    Returns:
      The result of the first read format whose `find_step` does not raise.

    Raises:
      ValueError: If every read format failed. The message is the
        newline-joined text of the individual errors.
    """
    failures = []  # Kept so the final error can report every attempt.
    for name_format in self.read_name_formats:
      try:
        found = name_format.find_step(base_path, step)
      except Exception as err:  # pylint: disable=broad-exception-caught
        # A failure here only means this naming did not match; log it and
        # move on to the next format.
        logging.info(
            'Failed to find step=%s with NameFormat=%s under %s. Error: %s',
            step,
            name_format,
            base_path,
            err,
        )
        failures.append(err)
      else:
        return found
    # No format matched: surface all collected errors in one ValueError.
    raise ValueError('\n'.join(f'{failure}' for failure in failures))