# Source code for orbax.checkpoint._src.path.step

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Orbax step storage entities.

Please reserve this file for entities related to `NameFormat`. Other existing
functions should be moved out over time.
"""

import abc
from collections.abc import Iterable
from concurrent import futures
import dataclasses
import datetime
import functools
import os
import re
import time
from typing import Callable, Generic, Iterator, List, Optional, Protocol, Sequence, TypeVar
from absl import logging
from etils import epath
import numpy as np
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.path import temporary_paths

# Allowed checkpoint step naming using any non-empty `step_prefix`: the name
# must look like `<prefix>_<digits>`. Group 1 captures the prefix and group 2
# captures the step number.
ALLOWED_STEP_NAME_PATTERN = r'(.+)_(\d+)'

# Alias of `temporary_paths.TMP_DIR_SUFFIX`.
TMP_DIR_SUFFIX = temporary_paths.TMP_DIR_SUFFIX
# Matches temporary (uncommitted) step directory names such as:
# prefix_1000.orbax-checkpoint-tmp-1010101
# OR
# 1000.orbax-checkpoint-tmp-1010101
# Group 1 captures the step number. Intended for `re.match`, so only the
# beginning of the name has to match.
TMP_DIR_STEP_PATTERN = r'.*?_*?(\d+)\.orbax-checkpoint-tmp'

# Type of the step metadata produced by a `NameFormat`.
MetadataT = TypeVar('MetadataT', bound='Metadata')


def is_path_finalized(*args, **kwargs):
  """Synchronous wrapper for `temporary_paths.is_path_finalized`.

  All positional and keyword arguments are forwarded unchanged; the coroutine
  is driven to completion with `asyncio_utils.run_sync`.
  """
  return asyncio_utils.run_sync(
      temporary_paths.is_path_finalized(*args, **kwargs)
  )


def is_path_temporary(*args, **kwargs):
  """Synchronous wrapper for `temporary_paths.is_path_temporary`.

  All positional and keyword arguments are forwarded unchanged; the coroutine
  is driven to completion with `asyncio_utils.run_sync`.
  """
  return asyncio_utils.run_sync(
      temporary_paths.is_path_temporary(*args, **kwargs)
  )


def all_temporary_paths(*args, **kwargs):
  """Synchronous wrapper for `temporary_paths.all_temporary_paths`.

  All positional and keyword arguments are forwarded unchanged; the coroutine
  is driven to completion with `asyncio_utils.run_sync`.
  """
  return asyncio_utils.run_sync(
      temporary_paths.all_temporary_paths(*args, **kwargs)
  )


# Deprecated aliases, use the above functions, or use the temporary_paths module
# directly instead.
is_checkpoint_finalized = is_path_finalized
is_tmp_checkpoint = is_path_temporary


def tmp_checkpoints(*args, **kwargs):
  """Deprecated: returns the names of all temporary paths.

  Arguments are forwarded to `all_temporary_paths`; for each returned entry the
  name of its `get()` path is collected.
  """
  return [p.get().name for p in all_temporary_paths(*args, **kwargs)]


def _is_valid_base_path(base_path: epath.PathLike) -> bool:
  """Returns whether `base_path` exists and is a directory.

  Args:
    base_path: Path to check; it is converted to `epath.Path` first.

  Returns:
    True if `base_path` exists and is a directory, False otherwise.
  """
  base_path = epath.Path(base_path)
  return base_path.exists() and base_path.is_dir()


@dataclasses.dataclass(frozen=True)
class Metadata:
  """Describes a single checkpoint step.

  Attributes:
    step: The step number of the checkpoint.
    path: Location of the checkpoint directory.
  """

  step: int
  path: epath.Path

  @functools.cached_property
  def _checkpoint_metadata(self) -> Optional[checkpoint.StepMetadata]:
    """Reads (once) and deserializes this step's metadata file, or None."""
    store = checkpoint.metadata_store(enable_write=False)
    raw = store.read(file_path=checkpoint.step_metadata_file_path(self.path))
    return (
        None if raw is None else step_metadata_serialization.deserialize(raw)
    )

  @property
  def init_timestamp_nsecs(self) -> Optional[int]:
    """Init timestamp of the uncommitted checkpoint, in ns since epoch.

    None when no step metadata could be read.
    """
    step_meta = self._checkpoint_metadata
    return None if step_meta is None else step_meta.init_timestamp_nsecs

  @property
  def commit_timestamp_nsecs(self) -> Optional[int]:
    """Commit timestamp of the checkpoint, in ns since epoch.

    None when no step metadata could be read.
    """
    step_meta = self._checkpoint_metadata
    return None if step_meta is None else step_meta.commit_timestamp_nsecs

  @property
  def commit_timestamp(self) -> datetime.datetime:
    """Commit timestamp of the checkpoint as a UTC `datetime`.

    When no commit timestamp is recorded in the step metadata, the
    modification time of `path` is used instead.
    """
    nsecs = self.commit_timestamp_nsecs
    seconds = self.path.stat().mtime if nsecs is None else nsecs / 1e9
    return datetime.datetime.fromtimestamp(seconds, tz=datetime.timezone.utc)
class NameFormat(Protocol, Generic[MetadataT]):
  """Protocol for naming step directories and discovering existing ones."""

  def build_name(self, step: int) -> str:
    """Returns the directory name for `step`.

    *Implementation hint:* combine `step` with this format's own formatting
    attributes. Since the name is mostly needed when saving checkpoints, a
    format that is only meant for locating already existing step paths may
    raise here.

    Args:
      step: Step number.
    """
    ...

  @abc.abstractmethod
  def find_all(self, base_path: epath.PathLike) -> Iterator[MetadataT]:
    """Returns metadata for every committed step.

    NOTE: Uncommitted checkpoints are skipped.

    *Implementation hint:* locate all step folders under `base_path`,
    performing IO as needed, and turn the found paths into `MetadataT`
    instances with the `build_step_metadatas(...)` helper.

    Args:
      base_path: *root* Path under which Step folders are placed.
    """
    ...

  @abc.abstractmethod
  def find_step(self, base_path: epath.PathLike, step: int) -> MetadataT:
    """Returns the metadata for `step`, raising ValueError if it is missing.

    NOTE: Uncommitted checkpoints are skipped.

    *Implementation hint:* locate the folder for `step` under `base_path`,
    performing IO as needed.

    Args:
      base_path: *root* Path under which Step folders are placed.
      step: Step number.

    Raises:
      ValueError: if no committed path is found for the requested step.
    """
    ...
def build_step_path(
    base_path: epath.PathLike, name_format: NameFormat[Metadata], step: int
) -> epath.Path:
  """Joins `base_path` with the name that `name_format` builds for `step`."""
  root = epath.Path(base_path)
  return root / name_format.build_name(step)
def build_step_metadatas(
    step_paths: Iterable[epath.Path],
    build_metadata: Callable[[epath.Path], Optional[MetadataT]],
) -> Iterator[MetadataT]:
  """Maps `step_paths` through `build_metadata`, yielding non-None results.

  Each path is processed on a thread pool; results are yielded in completion
  order, not input order.

  Args:
    step_paths: Iterable of candidate step paths.
    build_metadata: Callable that builds the step metadata for a path, or
      returns None if the path does not match.

  Yields:
    Step metadata for every matching path.
  """
  with futures.ThreadPoolExecutor() as pool:
    pending = [pool.submit(build_metadata, path) for path in step_paths]
    for done in futures.as_completed(pending):
      built = done.result()
      if built is None:
        continue
      yield built
def step_prefix_with_underscore(step_prefix: Optional[str]) -> str:
  """Returns `step_prefix` followed by `_`, or an empty string for None."""
  if step_prefix is None:
    return ''
  return f'{step_prefix}_'
def latest_step_metadata(
    root_path: epath.PathLike, name_format: NameFormat[MetadataT]
) -> Optional[MetadataT]:
  """Returns the metadata of the highest step under `root_path`, or None.

  When several entries share the highest step, the one whose directory name
  sorts last lexicographically is returned.
  """
  return max(
      name_format.find_all(root_path),
      key=lambda candidate: (candidate.step, candidate.path.name),
      default=None,
  )
def step_metadata_of_checkpoint_path(
    checkpoint_path: epath.PathLike, name_format: NameFormat[MetadataT]
) -> MetadataT:
  """Returns the step metadata whose directory is `checkpoint_path`.

  All steps in the parent directory are listed with `name_format` and the one
  whose directory name equals the name of `checkpoint_path` is returned.

  Raises:
    ValueError: if no listed step has the same name as `checkpoint_path`.
  """
  checkpoint_path = epath.Path(checkpoint_path)
  candidates = list(name_format.find_all(checkpoint_path.parent))
  found = next(
      (c for c in candidates if c.path.name == checkpoint_path.name), None
  )
  if found is not None:
    return found
  available_names = [c.path.name for c in candidates]
  raise ValueError(
      'Failed to resolve step metadata of checkpoint path with'
      f' NameFormat={name_format}, checkpoint path={checkpoint_path}, path'
      f' name({checkpoint_path.name}) did not match with available step names:'
      f' {available_names}.'
      ' Please check if the given path is really a checkpoint path.'
  )
# TODO(b/337858698): Works with CompositeNameFormat.write_name_format only. Also
# support read_name_formats.
def find_step_path(
    base_path: epath.PathLike,
    name_format: NameFormat[Metadata],
    *,
    step: int,
    include_uncommitted: bool = False,
) -> epath.Path:
  """Returns `step` path under `base_path` for step `name_format`.

  NOTE: Experimental function, subject to change.

  Args:
    base_path: directory path containing step subdirs.
    name_format: NameFormat of the target `step`.
    step: target step number.
    include_uncommitted: if True then uncommitted steps are considered in
      search too, otherwise only committed steps are looked up.

  Raises:
    ValueError if the target step path does not exist.
  """
  base_path = epath.Path(base_path)
  if include_uncommitted:
    # Prefer a temporary (in-progress) directory whose final destination is
    # the path of the requested step.
    uncommitted = next(
        (
            tmp.get()
            for tmp in all_temporary_paths(base_path)
            if tmp.get_final() == build_step_path(base_path, name_format, step)
        ),
        None,
    )
    if uncommitted and uncommitted.exists():
      return uncommitted
  # Committed lookup; `find_step` raises ValueError when the step is absent.
  return name_format.find_step(base_path, step).path
def maybe_find_step_metadata(
    base_path: epath.PathLike,
    name_format: NameFormat[Metadata],
    *,
    step: int,
) -> Metadata | None:
  """Returns `Metadata` for `step` with `name_format` or None.

  A `ValueError` from `name_format.find_step` is turned into a None result.
  """
  try:
    return name_format.find_step(base_path, step)
  except ValueError:
    return None


@dataclasses.dataclass(frozen=True)
class _StandardNameFormat(NameFormat[Metadata]):
  """NameFormat for 'standard' steps for common Orbax use cases.

  NOTE: Ignores uncommitted checkpoints.

  Naming examples:
   * step_prefix=None    step_format_fixed_length=None  ->  23
   * step_prefix=None    step_format_fixed_length=4     ->  0023
   * step_prefix=step    step_format_fixed_length=None  ->  step_23
   * step_prefix=step    step_format_fixed_length=4     ->  step_0023

  Attributes:
    step_prefix: Optional fixed string prefixed to step. Note an *underscore*
      is appended before applying it. It is matched literally.
    step_format_fixed_length: Optional length of the zero padded step. e.g. 6
      for 000123.
    single_host_load_and_broadcast: If True, the jax process=0 will list all
      steps and broadcast them to all other processes. NOTE: Ignored if jax
      backend is not multi controller.
  """

  step_prefix: Optional[str] = None
  step_format_fixed_length: Optional[int] = None
  single_host_load_and_broadcast: bool = False

  def __str__(self):
    return f'StandardNameFormat("{self.build_name(1234)}")'

  def build_name(self, step: int) -> str:
    """Returns `(prefix_)?(zero padding)?step` name."""
    if self.step_format_fixed_length is not None:
      step_str = f'{step:0{self.step_format_fixed_length}d}'
    else:
      step_str = f'{step}'

    # [prefix]step
    return f'{step_prefix_with_underscore(self.step_prefix)}{step_str}'

  def _build_metadata(
      self, step_path: epath.Path, step: Optional[int] = None
  ) -> Optional[Metadata]:
    """Returns metadata for given `step_path` if it is valid or None.

    Args:
      step_path: Candidate step directory.
      step: If known, the step number `step_path` was built for. Then only
        existence is checked and the name is not parsed.
    """
    if not is_path_finalized(step_path):
      return None

    if step is not None:
      # Step already known, just check existence. A missing path must not fall
      # through to the name regex below, which would otherwise produce
      # Metadata for a directory that does not exist.
      if step_path.exists():
        return Metadata(step=step, path=step_path)
      return None

    # Regex: [prefix]*(step)
    if self.step_format_fixed_length and self.step_format_fixed_length > 0:
      zero_present = rf'0\d{{{self.step_format_fixed_length-1}}}'
      zero_not_present = rf'[1-9]\d{{{self.step_format_fixed_length-1}}}\d*'
      zero_padded_step_group = rf'({zero_present}|{zero_not_present})'
    else:
      zero_padded_step_group = r'(0|[1-9]\d*)'

    # The prefix is a literal string (listing filters it with `startswith`),
    # so escape any regex metacharacters it may contain.
    prefix_regex = re.escape(step_prefix_with_underscore(self.step_prefix))
    match = re.fullmatch(f'{prefix_regex}{zero_padded_step_group}',
                         step_path.name)
    if match is None:
      return None
    (step_,) = match.groups()
    return Metadata(step=int(step_), path=step_path)

  def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]:
    """Returns step paths under `base_path`."""
    base_path = epath.Path(base_path)
    # <step_prefix>_?<0 padding>?*
    if gcs_utils.is_hierarchical_namespace_enabled(base_path):
      logging.vlog(
          1,
          'HNS enabled. Using GCS API to list step paths at %s',
          base_path.as_posix(),
      )
      bucket_name, path_prefix = gcs_utils.parse_gcs_path(base_path)
      bucket = gcs_utils.get_bucket(bucket_name)
      result = bucket.list_blobs(
          prefix=path_prefix,
          delimiter='/',
          include_folders_as_prefixes=True,
      )
      # Iterate over pages to force a fetch from the server, after which
      # `result.prefixes` will be populated.
      for _ in result.pages:
        pass

      folder_prefix = os.path.join(path_prefix, self.step_prefix or '')
      return [
          epath.Path(f'gs://{bucket_name}/{folder}')
          for folder in result.prefixes
          if folder.startswith(folder_prefix)
      ]

    prefix = step_prefix_with_underscore(self.step_prefix)
    return [x for x in base_path.iterdir() if x.name.startswith(prefix)]

  def _get_step_paths_and_total_steps(
      self, base_path: epath.PathLike, is_primary_host: bool
  ) -> tuple[list[epath.Path], int]:
    """Broadcasts total steps and get step paths from the primary host.

    Non-primary hosts return an empty path list together with the broadcast
    count.
    """
    step_paths = []
    total_steps = -1
    if is_primary_host:
      # <step_prefix>_?<0 padding>?*
      step_paths = self._glob_step_paths(base_path)
      logging.info(
          '[process=%s][single_host_load_and_broadcast] step_paths=%s,'
          ' root_dir=%s',
          multihost.process_index(),
          step_paths,
          base_path,
      )
      total_steps = len(step_paths)
    total_steps = int(
        multihost.broadcast_one_to_all(total_steps, is_source=is_primary_host)
    )
    return step_paths, total_steps

  def _get_broadcasted_step_list(
      self,
      step_paths: list[epath.Path],
      total_steps: int,
      is_primary_host: bool,
  ) -> np.ndarray:
    """Builds the step list on the primary host and broadcasts it.

    The list has `total_steps` entries: valid steps in ascending order,
    followed by -1 padding.
    """
    padded_step_list = np.array([-1] * total_steps)
    if is_primary_host:
      steps = []
      with futures.ThreadPoolExecutor() as executor:
        metadata_futures = [
            executor.submit(self._build_metadata, step_path)  # File IO
            for step_path in step_paths
        ]
        for future in futures.as_completed(metadata_futures):
          metadata = future.result()
          if metadata is not None:
            steps.append(metadata.step)
      steps.sort()
      steps = np.array(steps)
      assert (
          len(steps) <= total_steps
      ), f'len(steps)={len(steps)} > total_steps={total_steps}'
      padded_step_list[0 : len(steps)] = steps
    return multihost.broadcast_one_to_all(
        padded_step_list, is_source=is_primary_host
    )

  def _find_all_with_single_host_load_and_broadcast(
      self, base_path: epath.PathLike
  ) -> Iterator[Metadata]:
    """Returns metadata by validating and broadcasting from process_index=0.

    Step paths are listed, and the step numbers of complete checkpoints are
    extracted, only on process_index=0 (this involves multiple file I/O
    operations and arguably improves performance by reducing QPS to underlying
    storages like GCS). The number of candidate paths and then the padded
    step numbers are broadcast to all hosts. In the end, all hosts build the
    metadata from step numbers and return.

    Args:
      base_path: Root path under which step folders are placed.
    """
    process_index = multihost.process_index()
    is_primary_host = process_index == 0
    time_start = time.time()
    step_paths, total_steps = self._get_step_paths_and_total_steps(
        base_path, is_primary_host
    )
    logging.info(
        '[process=%s][single_host_load_and_broadcast] total_steps=%s,'
        ' step_paths=%s, root_dir=%s',
        process_index,
        total_steps,
        step_paths,
        base_path,
    )
    if total_steps <= 0:
      logging.info(
          '[process=%s][single_host_load_and_broadcast] No steps found,'
          ' root_dir=%s',
          process_index,
          base_path,
      )
      return iter([])
    time_glob_list = time.time()

    padded_step_list = self._get_broadcasted_step_list(
        step_paths, total_steps, is_primary_host
    )

    base_path = epath.Path(base_path)
    # Convert broadcast values to Python ints so `Metadata.step` has the same
    # type as on the non-broadcast path (not a NumPy/array scalar).
    steps = [int(step) for step in padded_step_list]
    paths_to_step_dict: dict[epath.Path, int] = {
        base_path / self.build_name(step): step for step in steps if step >= 0
    }
    metadatas = build_step_metadatas(
        paths_to_step_dict.keys(),
        lambda step_path: Metadata(
            step=paths_to_step_dict[step_path], path=step_path
        ),
    )
    time_build_metadata = time.time()
    logging.info(
        '[process=%s][single_host_load_and_broadcast] Found #steps=%s,'
        ' root_dir=%s, time_taken_to_glob=%ss,'
        ' time_taken_to_build_metadata=%ss, total_time=%ss',
        process_index,
        len(paths_to_step_dict),
        base_path,
        time_glob_list - time_start,
        time_build_metadata - time_glob_list,
        time_build_metadata - time_start,
    )
    return metadatas

  def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
    """Returns metadata of all steps matching with name_format attributes."""
    if not _is_valid_base_path(base_path):
      return iter([])
    # Note: the order of conjuncts is important here; we should not call
    # `multihost.process_count()` when `single_host_load_and_broadcast` is False
    # as this has the possible side effect of initializing the jax backend. See
    # b/454565916 for details.
    if self.single_host_load_and_broadcast and multihost.process_count() > 1:
      return self._find_all_with_single_host_load_and_broadcast(base_path)
    # <step_prefix>_?<0 padding>?*
    step_paths = self._glob_step_paths(base_path)
    return build_step_metadatas(step_paths, self._build_metadata)

  def find_step(self, base_path: epath.PathLike, step: int) -> Metadata:
    """Returns the metadata for `step` or raises ValueError."""
    if not _is_valid_base_path(base_path):
      raise ValueError(
          f'Invalid base_path: {base_path} does not exist or is not a'
          ' directory.'
      )
    step_path = build_step_path(base_path, self, step)
    metadata = self._build_metadata(step_path, step=step)
    if metadata is not None:
      return metadata

    # Raise detailed error message.
    raise ValueError(
        f'No step path found with name={self.build_name(step)},'
        f' NameFormat={self} for step={step} under {base_path}.'
    )
def standard_name_format(
    *,
    step_prefix: Optional[str] = None,
    step_format_fixed_length: Optional[int] = None,
    single_host_load_and_broadcast: bool = False,
) -> NameFormat[Metadata]:
  """Builds the NameFormat for 'standard' steps used by common Orbax setups.

  NOTE: Ignores uncommitted checkpoints.

  Naming examples:
   * step_prefix=None    step_format_fixed_length=None  ->  23
   * step_prefix=None    step_format_fixed_length=4     ->  0023
   * step_prefix=step    step_format_fixed_length=None  ->  step_23
   * step_prefix=step    step_format_fixed_length=4     ->  step_0023

  Args:
    step_prefix: Optional fixed string placed before the step; an *underscore*
      is appended to it.
    step_format_fixed_length: Optional width of the zero padded step, e.g. 6
      for 000123.
    single_host_load_and_broadcast: If True, the jax process=0 lists all steps
      and broadcasts them to the other processes. NOTE: Ignored if jax backend
      is not multi controller.
  """
  options = dict(
      step_prefix=step_prefix,
      step_format_fixed_length=step_format_fixed_length,
      single_host_load_and_broadcast=single_host_load_and_broadcast,
  )
  return _StandardNameFormat(**options)
@dataclasses.dataclass(frozen=True)
class _CompositeNameFormat(NameFormat[Metadata]):
  """Reads steps using several namings but writes with exactly one.

  Attributes:
    write_name_format: NameFormat used to build step names when writing
      checkpoints. Must be present in `read_name_formats` at a preferred
      priority position.
    read_name_formats: Ordered sequence of NameFormats used to find steps for
      reading checkpoints. They are tried in order (an *or*); the first match
      wins.
  """

  write_name_format: NameFormat[Metadata]
  read_name_formats: Sequence[NameFormat[Metadata]]

  def __post_init__(self):
    if self.write_name_format in self.read_name_formats:
      return
    joined = ','.join(map(str, self.read_name_formats))
    raise ValueError(
        f'write_name_format: {self.write_name_format} must be present in'
        f' read_name_formats: [{joined}].'
    )

  def __str__(self):
    joined = ','.join(map(str, self.read_name_formats))
    return f'Composite([{joined}])'

  def build_name(self, step: int) -> str:
    """Delegates naming of `step` to `write_name_format`."""
    return self.write_name_format.build_name(step)

  def find_all(self, base_path: epath.PathLike) -> Iterator[Metadata]:
    """Yields metadata of all steps, de-duplicated by path across formats."""
    seen_paths = set()
    for fmt in self.read_name_formats:
      for candidate in fmt.find_all(base_path):
        if candidate.path in seen_paths:
          continue
        seen_paths.add(candidate.path)
        yield candidate

  def find_step(self, base_path: epath.PathLike, step: int) -> Metadata:
    """Returns the first format's match for `step`, else raises ValueError.

    The raised ValueError message joins every per-format error, one per line.
    """
    failures = []
    for fmt in self.read_name_formats:
      try:
        return fmt.find_step(base_path, step)
      except Exception as err:  # pylint: disable=broad-exception-caught
        logging.info(
            'Failed to find step=%s with NameFormat=%s under %s. Error: %s',
            step,
            fmt,
            base_path,
            err,
        )
        failures.append(err)
    raise ValueError('\n'.join(str(err) for err in failures))
def composite_name_format(
    write_name_format: NameFormat[Metadata],
    read_name_formats: Sequence[NameFormat[Metadata]],
) -> NameFormat[Metadata]:
  """Builds a *composite* NameFormat: many formats to read, one to write.

  Args:
    write_name_format: NameFormat used to build step names when writing
      checkpoints. Must be present in `read_name_formats` at a preferred
      priority position.
    read_name_formats: Ordered sequence of NameFormats used to find steps for
      reading checkpoints. When several formats match a given step, the one
      appearing earlier is preferred, so list them from highest to lowest
      priority to resolve conflicts without errors.
  """
  return _CompositeNameFormat(
      write_name_format=write_name_format,
      read_name_formats=read_name_formats,
  )
# TODO(b/337137764) Can't move it to path/utils due to cyclic dependency.
# Explore other options.
def get_save_directory(
    step: int,
    directory: epath.PathLike,
    name: Optional[str] = None,
    step_prefix: Optional[str] = None,
    override_directory: Optional[epath.PathLike] = None,
    step_format_fixed_length: Optional[int] = None,
    step_name_format: Optional[NameFormat[Metadata]] = None,
) -> epath.Path:
  """Returns the standardized path to a save directory for a single item.

  Args:
    step: Step number.
    directory: Top level checkpoint directory; must not be None.
    name: Item name ('params', 'state', 'dataset', etc.), appended last.
    step_prefix: Prefix applied to `step` (e.g. 'checkpoint').
    override_directory: If provided, it replaces the step directory derived
      from `step`, `directory` and the naming options.
    step_format_fixed_length: Uses a fixed number of digits with leading zeros
      to represent the step number. If None, there are no leading zeros.
    step_name_format: NameFormat used to name the step directory under the
      root directory. If provided, `step_prefix` and `step_format_fixed_length`
      are ignored.

  Returns:
    A directory.
  """
  if directory is None:
    raise ValueError('Directory cannot be None.')
  directory = epath.Path(directory)

  if override_directory is None:
    naming = step_name_format or standard_name_format(
        step_prefix=step_prefix,
        step_format_fixed_length=step_format_fixed_length,
    )
    save_dir = build_step_path(directory, naming, step)
  else:
    save_dir = epath.Path(override_directory)

  return save_dir if name is None else save_dir / name
def _is_legacy_step_checkpoint(path: epath.Path) -> bool:
  """Heuristically decides whether `path` looks like an Orbax step directory.

  This is not foolproof: users should not add extra files to the checkpoint
  directory beyond what CheckpointManager writes.

  Args:
    path: path to check.

  Returns:
    True if `path` is a directory whose name is all digits, or whose last
    `_`-separated piece is all digits.
  """
  name = path.name
  if not path.is_dir():
    return False
  trailing_piece = name.rpartition('_')[2]
  return name.isdigit() or trailing_piece.isdigit()


def _is_legacy_finalized_step_checkpoint(path: epath.Path) -> bool:
  """Returns True if `path` looks like a legacy step dir and is finalized."""
  if not _is_legacy_step_checkpoint(path):
    return False
  return is_path_finalized(path)
def step_from_checkpoint_name(name: str) -> int:
  """Extracts the step number from a checkpoint name (tmp names included).

  Accepted forms: a bare number, `<prefix>_<digits>`, or a temporary
  directory name matching `TMP_DIR_STEP_PATTERN`.

  Raises:
    ValueError: if `name` matches none of the accepted forms.
  """
  if name.isdigit():
    return int(name)
  standard = re.fullmatch(ALLOWED_STEP_NAME_PATTERN, name)
  if standard is not None:
    return int(standard.group(2))
  temporary = re.match(TMP_DIR_STEP_PATTERN, name)
  if temporary is not None:
    return int(temporary.group(1))
  raise ValueError(f'Unrecognized name format: {name}.')
def checkpoint_steps_paths(
    checkpoint_dir: epath.PathLike,
) -> List[epath.Path]:
  """Lists finalized (legacy-named) step paths in `checkpoint_dir`.

  Returns an empty list when the directory does not exist. Entries are checked
  concurrently on a thread pool.
  """
  checkpoint_dir = epath.Path(checkpoint_dir)
  if not checkpoint_dir.exists():
    return []
  with futures.ThreadPoolExecutor() as pool:
    checks = [
        (entry, pool.submit(_is_legacy_finalized_step_checkpoint, entry))
        for entry in checkpoint_dir.iterdir()
    ]
  return [entry for entry, check in checks if check.result()]
def is_standard_name_format(name_format: NameFormat[Metadata]) -> bool:
  """Returns True if `name_format` is a `_StandardNameFormat` instance."""
  return isinstance(name_format, _StandardNameFormat)


def single_host_load_and_broadcast_name_format(
    name_format: NameFormat[Metadata],
) -> NameFormat[Metadata]:
  """Returns a copy of a standard `name_format` with broadcasting enabled.

  Raises:
    ValueError: if `name_format` is not a standard name format.
  """
  if not is_standard_name_format(name_format):
    raise ValueError(
        'single_host_load_and_broadcast is only supported for standard name'
        f' formats. Got {name_format}.'
    )
  return dataclasses.replace(name_format, single_host_load_and_broadcast=True)  # pytype: disable=wrong-arg-types
def checkpoint_steps(checkpoint_dir: epath.PathLike) -> List[int]:
  """Returns the step numbers of finalized checkpoints in `checkpoint_dir`."""
  finalized_paths = checkpoint_steps_paths(checkpoint_dir)
  return [step_from_checkpoint_name(path.name) for path in finalized_paths]
def any_checkpoint_step(checkpoint_dir: epath.PathLike) -> Optional[int]:
  """Returns any finalized checkpoint step in the directory or None.

  This avoids iterating over the entire directory: it stops at the first
  entry that looks like a finalized legacy step checkpoint.

  Args:
    checkpoint_dir: Checkpoint directory.

  Returns:
    Any finalized checkpoint step in the directory, or None if there is none
    or the directory does not exist.
  """
  checkpoint_dir = epath.Path(checkpoint_dir)
  # Mirror `checkpoint_steps_paths`: a missing directory simply has no steps,
  # so report None instead of letting `iterdir()` raise.
  if not checkpoint_dir.exists():
    return None
  for s in checkpoint_dir.iterdir():
    if _is_legacy_finalized_step_checkpoint(s):
      return step_from_checkpoint_name(s.name)
  return None