Checkpoint Metadata#

Defines exported symbols for orbax.checkpoint.experimental.v1.

The preferred import style is:

import orbax.checkpoint.experimental.v1 as ocp

PyTreeMetadata#

orbax.checkpoint.experimental.v1.PyTreeMetadata#

alias of PyTreeOf[AbstractArray | AbstractShardedArray | int | float | number | bytes | bool | str]

CheckpointMetadata#

final class orbax.checkpoint.experimental.v1.CheckpointMetadata(path, *, metadata, init_timestamp_nsecs=None, commit_timestamp_nsecs=None, custom_metadata=None)#

Represents complete metadata describing a checkpoint.

This class is generic over CheckpointableMetadataT, which is typically either PyTreeMetadata (see above) or dict[str, AbstractCheckpointable].

A CheckpointMetadata object is returned by one of two functions, metadata() and checkpointables_metadata(). See each function for more information and usage instructions.

If the checkpoint contains a PyTree, this metadata can be accessed via:

metadata = ocp.metadata(path)

# Inspect various properties
metadata.init_timestamp_nsecs

# Inspect the tree structure
metadata.metadata.pytree
metadata.metadata.pytree['layer0']['bias'].shape
metadata.metadata.pytree['layer0']['bias'].dtype
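
To list every leaf's shape and dtype instead of indexing one path at a time, the nested structure can be walked recursively. The sketch below is a minimal illustration that assumes the tree is built from nested dicts whose array leaves expose shape and dtype. LeafInfo and iter_leaves are hypothetical stand-ins, not Orbax types, so the example runs without a checkpoint on disk.

```python
from dataclasses import dataclass
from typing import Any, Iterator


@dataclass(frozen=True)
class LeafInfo:
    """Hypothetical stand-in for an array leaf's metadata (not an Orbax type)."""
    shape: tuple
    dtype: str


def iter_leaves(tree: Any, prefix: tuple = ()) -> Iterator[tuple]:
    """Yield (key_path, leaf) pairs from a tree of nested dicts."""
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from iter_leaves(tree[key], prefix + (key,))
    else:
        yield prefix, tree


# Assumed layout, mirroring the 'layer0'/'bias' example above.
tree = {
    'layer0': {
        'bias': LeafInfo((4,), 'float32'),
        'kernel': LeafInfo((8, 4), 'float32'),
    },
    'step': 100,
}

# Flatten the tree into {'layer0/bias': LeafInfo(...), ...}.
summary = {'/'.join(path): leaf for path, leaf in iter_leaves(tree)}
for name, leaf in summary.items():
    print(name, leaf)
```

With a real checkpoint, the same traversal would start from the tree obtained via ocp.metadata(path), provided its containers are dicts as assumed here.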

The checkpoint metadata can also be accessed more generically via:

metadata = ocp.checkpointables_metadata(path)

metadata.metadata.keys()  # == ['pytree', 'dataset', etc.]
metadata.metadata['pytree']  # instance of PyTreeMetadata

metadata#

Metadata for the checkpointable.

init_timestamp_nsecs#

The timestamp when the uncommitted checkpoint was initialized, specified in nanoseconds since the epoch. Defaults to None.

commit_timestamp_nsecs#

The commit timestamp of a checkpoint, specified in nanoseconds since the epoch. Defaults to None.

custom_metadata#

User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.
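
Since custom_metadata must be JSON-serializable, it can help to validate a dictionary before attaching it to a checkpoint. The helper below, is_json_serializable, is a hypothetical name and relies only on the standard json module. It is a sketch of one possible check, not a description of how Orbax validates the field.

```python
import json
from typing import Any


def is_json_serializable(value: Any) -> bool:
    """Return True if `value` can be encoded by the standard json module."""
    try:
        json.dumps(value)
    except (TypeError, ValueError):
        return False
    return True


# Strings, numbers, lists, and nested dicts are all valid JSON values.
good = {'experiment': 'baseline', 'lr': 1e-3, 'tags': ['a', 'b']}
# Functions and sets have no JSON representation.
bad = {'callback': print, 'ids': {1, 2, 3}}

print(is_json_serializable(good))
print(is_json_serializable(bad))
```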

__init__(path, *, metadata, init_timestamp_nsecs=None, commit_timestamp_nsecs=None, custom_metadata=None)#

property path: Path#

The path to the checkpoint.

Return type:

Path

property init_timestamp: datetime | None#

Timestamp when the checkpoint began to be written.

Return type:

UnionType[datetime, None]

property commit_timestamp: datetime | None#

Timestamp when the checkpoint finished being written.

Return type:

UnionType[datetime, None]
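
The *_nsecs attributes store nanoseconds since the epoch, while init_timestamp and commit_timestamp expose datetime values. The relationship between the two can be sketched as follows. The helper name nsecs_to_datetime and the choice of UTC are assumptions made for illustration, and Orbax's own conversion may differ (for example in time zone handling).

```python
import datetime
from typing import Optional


def nsecs_to_datetime(nsecs: Optional[int]) -> Optional[datetime.datetime]:
    """Convert nanoseconds since the epoch to a UTC datetime (None passes through)."""
    if nsecs is None:
        return None
    seconds, remainder_ns = divmod(nsecs, 1_000_000_000)
    base = datetime.datetime.fromtimestamp(seconds, tz=datetime.timezone.utc)
    # datetime has microsecond resolution, so sub-microsecond digits are dropped.
    return base + datetime.timedelta(microseconds=remainder_ns // 1000)


# Made-up example values, 2.5 seconds apart.
init_ns = 1_700_000_000_000_000_000
commit_ns = 1_700_000_002_500_000_000

started = nsecs_to_datetime(init_ns)
finished = nsecs_to_datetime(commit_ns)
print(started.isoformat())
print((finished - started).total_seconds())  # time spent writing the checkpoint
```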

Loading functions#

orbax.checkpoint.experimental.v1.metadata(path, checkpointable_name='AUTO')#

Loads the PyTree metadata from a checkpoint.

This function retrieves metadata for a PyTree checkpoint, returning an object of type CheckpointMetadata[PyTreeMetadata]. Please see documentation on this class for further details.

In short, the returned object has a metadata attribute (alongside other attributes such as timestamps) that is an instance of PyTreeMetadata and describes the PyTree itself. Its most important property is the tree structure, which matches the structure of the checkpointed PyTree and holds a leaf metadata object describing each leaf.

For example:

metadata = ocp.metadata(path)  # CheckpointMetadata[PyTreeMetadata]
metadata.metadata  # PyTreeMetadata: the PyTree structure.
metadata.init_timestamp_nsecs  # Time the checkpoint was initialized, in nanoseconds since the epoch.

The metadata can then be used to inform checkpoint loading. For example:

import dataclasses
import jax

metadata = ocp.metadata(path)
restored = ocp.load(path, metadata)

# Load with altered properties.
def _get_abstract_array(arr):
  # Assumes all checkpoint leaves are array types.
  new_dtype = ...
  new_sharding = ...
  return jax.ShapeDtypeStruct(arr.shape, new_dtype, sharding=new_sharding)

metadata = dataclasses.replace(
    metadata,
    metadata=jax.tree.map(_get_abstract_array, metadata.metadata),
)
ocp.load(path, metadata)
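
The dataclasses.replace call above returns a new object and leaves the original untouched, with every other field (such as the timestamps) carried over. The sketch below shows that behavior without JAX or a checkpoint on disk. FakeCheckpointMetadata and map_leaves are hypothetical stand-ins for CheckpointMetadata and jax.tree.map, and leaves are modeled as (dtype, shape) tuples purely for illustration.

```python
import dataclasses
from typing import Any, Callable, Optional


@dataclasses.dataclass(frozen=True)
class FakeCheckpointMetadata:
    """Hypothetical stand-in for CheckpointMetadata (not the Orbax class)."""
    metadata: Any
    init_timestamp_nsecs: Optional[int] = None


def map_leaves(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to each leaf of a tree of nested dicts (stand-in for jax.tree.map)."""
    if isinstance(tree, dict):
        return {key: map_leaves(fn, value) for key, value in tree.items()}
    return fn(tree)


original = FakeCheckpointMetadata(
    metadata={'layer0': {'bias': ('float32', (4,))}},
    init_timestamp_nsecs=123,
)

# Swap the dtype of every leaf while keeping its shape.
updated = dataclasses.replace(
    original,
    metadata=map_leaves(lambda leaf: ('bfloat16', leaf[1]), original.metadata),
)

print(updated.metadata['layer0']['bias'])   # altered leaf
print(original.metadata['layer0']['bias'])  # original is unchanged
print(updated.init_timestamp_nsecs)         # other fields are carried over
```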
Parameters:
  • path (UnionType[Path, str]) – The path to the checkpoint.

  • checkpointable_name (UnionType[str, None]) – The name of the checkpointable to load. Defaults to AUTO. If a name is given, a subdirectory with that name must exist in path. If None, path itself is expected to contain all files needed to load the PyTree (for example manifest.ocdbt, _METADATA, ocp.process_X) rather than a subdirectory. In AUTO mode, a pytree checkpointable is discovered and resolved dynamically: the standard 'pytree' checkpointable name is used if present; otherwise any other valid pytree checkpointable names are sorted alphabetically and the first one is used; as a final fallback, if no standard pytree exists, path is interpreted as a flat V0 root layout.

Return type:

CheckpointMetadata[PyTreeOf[UnionType[AbstractArray, AbstractShardedArray, int, float, number, bytes, bool, str]]]

Returns:

A CheckpointMetadata[PyTreeMetadata] object.
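
The AUTO resolution order described for checkpointable_name can be sketched in plain Python. This illustrates the documented priority only: resolve_auto is a hypothetical function, it receives an already-discovered list of valid pytree checkpointable names instead of inspecting the filesystem, and it assumes the V0 fallback applies when no valid name was found.

```python
from typing import Optional, Sequence


def resolve_auto(valid_pytree_names: Sequence[str]) -> Optional[str]:
    """Pick a checkpointable name following the documented AUTO priority.

    Returns None to signal the fallback, where the path itself is treated
    as a flat V0 root layout (assumed trigger: no valid names found).
    """
    # 1. The standard name wins whenever it is present.
    if 'pytree' in valid_pytree_names:
        return 'pytree'
    # 2. Otherwise take the alphabetically first valid name.
    if valid_pytree_names:
        return sorted(valid_pytree_names)[0]
    # 3. Nothing to choose from: fall back to the root layout.
    return None


print(resolve_auto(['state', 'pytree', 'ema']))
print(resolve_auto(['state', 'ema']))
print(resolve_auto([]))
```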

orbax.checkpoint.experimental.v1.checkpointables_metadata(path)#

Loads all checkpointables metadata from a checkpoint.

This function is a more general version of metadata(). It returns the same kind of CheckpointMetadata object (with properties such as init_timestamp_nsecs, as shown above), but its core metadata property is a dictionary mapping checkpointable names to their metadata. This mirrors load_checkpointables, which returns a dictionary mapping checkpointable names to their loaded values.

For example:

ocp.save_checkpointables(path, {
    'foo': Foo(),
    'bar': Bar(),
})
metadata = ocp.checkpointables_metadata(path)
metadata.metadata  # {'foo': AbstractFoo(), 'bar': AbstractBar()}

Parameters:

path (UnionType[Path, str]) – The path to the checkpoint.

Return type:

CheckpointMetadata[dict[str, AbstractCheckpointable]]

Returns:

A CheckpointMetadata[dict[str, AbstractCheckpointable]] object.
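
Because the returned metadata is a dictionary keyed by checkpointable name, it can be inspected before deciding what to load. The sketch below uses hypothetical stand-in classes AbstractFoo and AbstractBar (echoing the example above) and a hypothetical helper, summarize. It does not call Orbax.

```python
from typing import Any, Dict, Mapping


class AbstractFoo:
    """Hypothetical abstract description of a Foo checkpointable."""


class AbstractBar:
    """Hypothetical abstract description of a Bar checkpointable."""


def summarize(metadata_by_name: Mapping[str, Any]) -> Dict[str, str]:
    """Map each checkpointable name to the type name of its metadata."""
    return {
        name: type(meta).__name__
        for name, meta in sorted(metadata_by_name.items())
    }


# Stands in for the `metadata.metadata` dictionary shown above.
checkpointables = {'foo': AbstractFoo(), 'bar': AbstractBar()}

summary = summarize(checkpointables)
print(summary)
print('dataset' in checkpointables)  # check for a name before trying to load it
```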