# Source code for orbax.checkpoint.experimental.v1._src.layout.registry
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Registry for checkpoint layouts."""
from __future__ import annotations
import asyncio
from absl import logging
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.layout import orbax_layout
from orbax.checkpoint.experimental.v1._src.layout import orbax_v0_layout
from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout
from orbax.checkpoint.experimental.v1._src.path import types as path_types
# Module-level aliases for names defined in the `checkpoint_layout` and
# `options` modules.
InvalidLayoutError = checkpoint_layout.InvalidLayoutError
CheckpointLayout = checkpoint_layout.CheckpointLayout
# The options-level layout selector gets a distinct name so it does not clash
# with the `CheckpointLayout` alias above.
CheckpointLayoutEnum = options_lib.CheckpointLayout
# Layout classes that `_is_orbax_checkpoint_async` instantiates and validates
# against when deciding whether a path is an Orbax checkpoint.
ORBAX_LAYOUT_CLASSES = [
orbax_layout.OrbaxLayout,
orbax_v0_layout.OrbaxV0Layout,
]
async def _is_orbax_checkpoint_async(path: path_types.PathLike) -> bool:
  """Checks whether `path` validates against any known Orbax layout.

  Each class in `ORBAX_LAYOUT_CLASSES` is instantiated and its `validate`
  coroutine is awaited concurrently. Failures are collected instead of being
  propagated, so this function reports a boolean rather than raising on an
  invalid checkpoint.

  Args:
    path: Location of the candidate checkpoint. It is converted with the
      `path_class` taken from the current context's file options.

  Returns:
    True if at least one layout's validation finished without raising,
    otherwise False.
  """
  ctx = context_lib.get_context()
  path = ctx.file_options.path_class(path)
  outcomes = await asyncio.gather(
      *(layout_cls().validate(path) for layout_cls in ORBAX_LAYOUT_CLASSES),
      return_exceptions=True,
  )
  # With `return_exceptions=True`, gather may also hand back
  # `asyncio.CancelledError`, which derives from `BaseException` rather than
  # `Exception`. Test against `BaseException` so that a cancelled validation
  # is not mistaken for a successful one.
  return any(not isinstance(outcome, BaseException) for outcome in outcomes)
# [docs]
def is_orbax_checkpoint(path: path_types.PathLike) -> bool:
  """Synchronously reports whether `path` is an Orbax checkpoint.

  This is a blocking wrapper: it builds the `_is_orbax_checkpoint_async`
  coroutine and hands it to `asyncio_utils.run_sync`.

  Args:
    path: Location of the candidate checkpoint.

  Returns:
    The boolean produced by `_is_orbax_checkpoint_async` for `path`.
  """
  pending_check = _is_orbax_checkpoint_async(path)
  return asyncio_utils.run_sync(pending_check)
async def get_layout_class(
    layout_enum: CheckpointLayoutEnum, path: path_types.PathLike | None = None
) -> type[CheckpointLayout]:
  """Maps a layout enum value to the layout class that implements it.

  Args:
    layout_enum: The requested checkpoint layout.
    path: Optional checkpoint path. It is only consulted for the ORBAX layout,
      where the result of `orbax_layout.checkpoint_version(path)` decides
      between the current and the V0 layout class.

  Returns:
    `OrbaxLayout` for ORBAX when no path is given or the checkpoint version is
    V1, `OrbaxV0Layout` for ORBAX otherwise, and `SafetensorsLayout` for
    SAFETENSORS.

  Raises:
    ValueError: If `layout_enum` is not one of the supported layouts.
  """
  if layout_enum == CheckpointLayoutEnum.ORBAX:
    # Without a path there is nothing to inspect, so use the current layout.
    if path is None:
      return orbax_layout.OrbaxLayout
    version = await orbax_layout.checkpoint_version(path)
    if version == orbax_layout.CheckpointVersion.V1:
      return orbax_layout.OrbaxLayout
    return orbax_v0_layout.OrbaxV0Layout

  if layout_enum == CheckpointLayoutEnum.SAFETENSORS:
    return safetensors_layout.SafetensorsLayout

  raise ValueError(f"Unsupported checkpoint layout: {layout_enum}")
async def get_checkpoint_layout(
    path: path_types.PathLike, layout_enum: CheckpointLayoutEnum
) -> CheckpointLayout:
  """Builds the layout selected by `layout_enum` and validates `path` with it.

  The layout class is chosen by `get_layout_class`, instantiated with no
  arguments, and asked to validate the checkpoint path.

  Args:
    path: The path to the checkpoint. It is converted with the `path_class`
      taken from the current context's file options.
    layout_enum: The checkpoint layout to use.

  Returns:
    The validated :py:class:`.CheckpointLayout` instance.

  Raises:
    InvalidLayoutError: If the selected layout rejects the path. The layout's
      own error is attached as the cause.
  """
  path = context_lib.get_context().file_options.path_class(path)
  layout_class = await get_layout_class(layout_enum, path)
  try:
    candidate = layout_class()
    await candidate.validate(path)
  except InvalidLayoutError as e:
    # Re-raise with a user-facing hint while keeping the original as cause.
    raise InvalidLayoutError(
        f"Could not recognize the checkpoint at {path} as a valid"
        f" {layout_enum.value} checkpoint. If you are trying to load a"
        " checkpoint that does not conform to the standard Orbax format, use"
        " `ctx.checkpoint_layout = ...` to specify the expected checkpoint"
        " layout."
    ) from e
  else:
    return candidate
class CheckpointLayoutResolver:
  """Holds a validated checkpoint layout together with its pytree name.

  Instances are normally created through :py:meth:`resolve`, which validates
  the checkpoint and, when asked to, picks the pytree name automatically.
  """

  def __init__(
      self,
      path: path_types.PathLike,
      layout_enum: CheckpointLayoutEnum,
      layout: CheckpointLayout,
      resolved_pytree_name: str | None,
  ):
    """Stores the given values as-is; no validation happens here.

    Args:
      path: The path to the checkpoint.
      layout_enum: The checkpoint layout that was requested.
      layout: The layout instance to expose via :py:attr:`layout`.
      resolved_pytree_name: The name to expose via :py:attr:`pytree_name`.
    """
    self._resolved_pytree_name = resolved_pytree_name
    self._layout = layout
    self._layout_enum = layout_enum
    self._path = path

  @classmethod
  async def resolve(
      cls,
      path: path_types.PathLike,
      layout_enum: CheckpointLayoutEnum,
      *,
      pytree_name: str | None = None,
  ) -> CheckpointLayoutResolver:
    """Validates the checkpoint and determines which pytree name to use.

    Args:
      path: The path to the checkpoint.
      layout_enum: The checkpoint layout to use.
      pytree_name: The name of the pytree to load. When it equals
        `checkpoint_layout.AUTO_CHECKPOINTABLE_KEY`, the checkpointable names
        reported by the layout are tried in order and the first one that
        passes `validate_pytree` is used. For a V0 Orbax layout, a name of
        `None` is tried as a last resort.

    Returns:
      A CheckpointLayoutResolver instance.

    Raises:
      InvalidLayoutError: If the checkpoint layout is invalid or if a valid
        PyTree checkpointable cannot be found.
    """
    path = context_lib.get_context().file_options.path_class(path)
    layout = await get_checkpoint_layout(path, layout_enum)

    auto_mode = pytree_name == checkpoint_layout.AUTO_CHECKPOINTABLE_KEY
    if not auto_mode:
      # Explicit name (or None): validate exactly what the caller asked for.
      await layout.validate_pytree(path, pytree_name)
      return cls(path, layout_enum, layout, pytree_name)

    # AUTO mode, first pass: take the first checkpointable that validates.
    for candidate in await layout.get_checkpointable_names(path):
      try:
        await layout.validate_pytree(path, candidate)
      except InvalidLayoutError:
        continue
      logging.info(
          "AUTO resolution mode successfully identified a pytree with"
          " checkpointable name '%s' at path '%s'. Attempting to load with"
          " this name. If this is not the desired checkpointable, please"
          " specify the name explicitly.",
          candidate,
          path,
      )
      return cls(path, layout_enum, layout, candidate)

    # AUTO mode, second pass: a V0 layout may also validate with no name.
    if isinstance(layout, orbax_v0_layout.OrbaxV0Layout):
      try:
        await layout.validate_pytree(path, None)
      except InvalidLayoutError:
        pass
      else:
        logging.info(
            "AUTO resolution mode successfully identified a pytree at path"
            " '%s'. Attempting to load as a flat layout V0 Orbax checkpoint."
            " with checkpointable_name=None.",
            path,
        )
        return cls(path, layout_enum, layout, None)

    raise InvalidLayoutError(
        "Failed to load checkpoint using AUTO resolution mode on"
        f" path='{path}'. No valid PyTree checkpointable found."
    ) from None

  @property
  def layout(self) -> CheckpointLayout:
    """The checkpoint layout held by this resolver."""
    return self._layout

  @property
  def pytree_name(self) -> str | None:
    """The resolved pytree name, which may be None."""
    return self._resolved_pytree_name