# Source code for orbax.checkpoint.experimental.v1._src.metadata.loading

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for loading metadata from a checkpoint."""

from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint.experimental.v1 import errors
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
import orbax.checkpoint.experimental.v1._src.handlers.global_registration  # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.layout import registry as layout_registry
from orbax.checkpoint.experimental.v1._src.loading import validation
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from typing_extensions import deprecated  # pytype: disable=not-supported-yet


# Metadata container types, re-exported from the metadata types module.
CheckpointMetadata = metadata_types.CheckpointMetadata
PyTreeMetadata = metadata_types.PyTreeMetadata

# Alias for `errors.InvalidLayoutError`.
InvalidLayoutError = errors.InvalidLayoutError

# Checkpointable key constants, re-exported from `checkpoint_layout`.
STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY
EMPTY_CHECKPOINTABLE_KEY = checkpoint_layout.EMPTY_CHECKPOINTABLE_KEY

# Handler-level abstract checkpointable type, used in return annotations below.
AbstractCheckpointable = handler_types.AbstractCheckpointable


def metadata(
    path: path_types.PathLike,
    checkpointable_name: str | None = checkpoint_layout.AUTO_CHECKPOINTABLE_KEY,
) -> CheckpointMetadata[PyTreeMetadata]:
  """Loads the PyTree metadata from a checkpoint.

  The result is a `CheckpointMetadata[PyTreeMetadata]`. Besides
  checkpoint-level fields such as timestamps, it carries a `metadata`
  attribute holding the `PyTreeMetadata`. This is a tree that mirrors the
  structure of the checkpointed PyTree, and each of its leaves is a metadata
  object describing the corresponding checkpointed leaf. See the
  `CheckpointMetadata` documentation for the full set of fields.

  For example::

    metadata = ocp.metadata(path)  # CheckpointMetadata[PyTreeMetadata]
    metadata.metadata  # PyTreeMetadata, i.e. the PyTree structure.
    metadata.init_timestamp_nsecs  # Checkpoint creation timestamp.

  The metadata can then be used to inform checkpoint loading. For example::

    metadata = ocp.metadata(path)
    restored = ocp.load(path, metadata)

    # Load with altered properties.
    def _get_abstract_array(arr):
      # Assumes all checkpoint leaves are array types.
      new_dtype = ...
      new_sharding = ...
      return jax.ShapeDtypeStruct(arr.shape, new_dtype, sharding=new_sharding)

    metadata = dataclasses.replace(
        metadata,
        metadata=jax.tree.map(_get_abstract_array, metadata.metadata),
    )
    ocp.load(path, metadata)

  Args:
    path: The path to the checkpoint.
    checkpointable_name: The name of the checkpointable to load. A subdirectory
      with this name must exist in `path`.
      - If None, `path` itself is expected to contain all files needed to
        load the PyTree (for example `manifest.ocdbt`, `_METADATA`,
        `ocp.process_X`), rather than any subdirectory.
      - Defaults to `AUTO`, which discovers the PyTree checkpointable
        dynamically. The standard 'pytree' name is preferred if present.
        Otherwise, the remaining valid PyTree checkpointable names are sorted
        alphabetically and the first valid one is used. If no standard PyTree
        exists, `path` is interpreted as a flat V0 root layout.

  Returns:
    A `CheckpointMetadata[PyTreeMetadata]` object.
  """
  validation.validate_pytree_checkpointable_name(checkpointable_name)
  ctx = context_lib.get_context()
  checkpoint_path = ctx.file_options.path_class(path)

  # Resolve both the layout and the concrete PyTree checkpointable name, which
  # may differ from `checkpointable_name` (e.g. when it is `AUTO`).
  resolution = asyncio_utils.run_sync(
      layout_registry.CheckpointLayoutResolver.resolve(
          checkpoint_path,
          ctx.checkpoint_layout,
          pytree_name=checkpointable_name,
      )
  )
  layout, pytree_name = resolution.layout, resolution.pytree_name

  # TODO(b/477603241): This logic currently accounts for the V0
  # metadata function returning a pytree for direct pytree checkpoints, while
  # V1 returns a dictionary. This logic should be cleaned up once we roll up
  # the composite handler into the layout themselves.
  all_metadata = _checkpointables_metadata_impl(layout, checkpoint_path)
  tree_metadata = (
      all_metadata.metadata
      if pytree_name is None
      else all_metadata.metadata[pytree_name]
  )

  return CheckpointMetadata[PyTreeMetadata](
      path=checkpoint_path,
      metadata=tree_metadata,
      init_timestamp_nsecs=all_metadata.init_timestamp_nsecs,
      commit_timestamp_nsecs=all_metadata.commit_timestamp_nsecs,
      custom_metadata=all_metadata.custom_metadata,
  )
def checkpointables_metadata(
    path: path_types.PathLike,
) -> CheckpointMetadata[dict[str, AbstractCheckpointable]]:
  """Loads all checkpointables metadata from a checkpoint.

  This function is a more general version of `metadata`. It returns the same
  kind of `CheckpointMetadata` object, with checkpoint-level properties such
  as `init_timestamp_nsecs`. However, its core `metadata` property is a
  dictionary mapping each checkpointable name to that checkpointable's
  metadata. This mirrors the return value of `load_checkpointables`, which
  similarly returns a dictionary mapping checkpointable names to their loaded
  values.

  For example::

    ocp.save_checkpointables(path, {
        'foo': Foo(),
        'bar': Bar(),
    })
    metadata = ocp.checkpointables_metadata(path)
    metadata.metadata  # {'foo': AbstractFoo(), 'bar': AbstractBar()}

  Args:
    path: The path to the checkpoint.

  Returns:
    A `CheckpointMetadata[dict[str, AbstractCheckpointable]]` object.
  """
  ctx = context_lib.get_context()
  path = ctx.file_options.path_class(path)
  # Unlike `metadata`, no PyTree checkpointable name has to be resolved here;
  # only the layout is needed to read every checkpointable's metadata.
  layout = asyncio_utils.run_sync(
      layout_registry.get_checkpoint_layout(path, ctx.checkpoint_layout)
  )
  return _checkpointables_metadata_impl(layout, path)
def _checkpointables_metadata_impl(
    layout: checkpoint_layout.CheckpointLayout,
    path: path_types.Path,
) -> CheckpointMetadata[dict[str, AbstractCheckpointable]]:
  """Synchronously reads checkpointables metadata for `path` using `layout`.

  Shared by `metadata` and `checkpointables_metadata`.

  Args:
    layout: The checkpoint layout whose `metadata` coroutine is awaited.
    path: The path to the checkpoint.

  Returns:
    The value produced by awaiting `layout.metadata(path)`.
  """

  async def _read() -> (
      metadata_types.CheckpointMetadata[dict[str, AbstractCheckpointable]]
  ):
    # Awaiting inside a local coroutine means `run_sync` is always handed a
    # coroutine object, whatever awaitable `layout.metadata` returns.
    result = await layout.metadata(path)
    return result

  return asyncio_utils.run_sync(_read())


@deprecated('Use `metadata` instead.')
def pytree_metadata(*args, **kwargs):
  """Deprecated alias; forwards all arguments unchanged to `metadata`."""
  return metadata(*args, **kwargs)