Checkpoint Loading
Defines exported symbols for orbax.checkpoint.experimental.v1.
Prefer to use the style:
import orbax.checkpoint.experimental.v1 as ocp
Loading functions
- orbax.checkpoint.experimental.v1.load(path, abstract_state=None, *, checkpointable_name='AUTO')
Loads a PyTree.
Loads from a PyTree checkpoint. A PyTree checkpoint must be a path containing a subdirectory with the name provided by checkpointable_name, with default value AUTO. See checkpointable_name for more details.
This function must be called on all available controller processes.
The operation blocks until complete. For improved performance, consider using load_async() instead.
If abstract_state is not provided, the PyTree will be loaded exactly as saved.
IMPORTANT: Loading is more brittle and error-prone when not providing abstract_state. Always provide abstract_state if possible. Note that you can always obtain the tree structure from a saved checkpoint using metadata().
Providing the abstract_state guarantees two things:
1. The restored tree will exactly match the structure of abstract_state (or raise an error if it is impossible to guarantee this). For example, if abstract_state is a custom object registered as a PyTree, the checkpoint will be restored as the same object, if possible.
2. The leaves of the restored tree will be restored with the properties indicated by the abstract leaves. For example, if a leaf in abstract_state is a jax.ShapeDtypeStruct, the restored leaf will be a jax.Array with the same shape and dtype. Each AbstractLeaf has a corresponding Leaf that is restored. See orbax.checkpoint.v1.tree for a table of standard supported leaf types.
Example Usage:
Load a saved PyTree with and without providing its abstract structure:
    import jax
    import jax.numpy as jnp
    import orbax.checkpoint.experimental.v1 as ocp

    path = '/tmp/my_checkpoint'

    # Save a checkpoint
    state = {'a': jnp.arange(8), 'b': jnp.zeros(4)}
    ocp.save(path, state)

    # Load the checkpoint
    # Highly recommended to provide the abstract pytree (structure/shapes)
    abstract_state = jax.eval_shape(lambda: state)

    # Method A: Load using the abstract structure.
    # This automatically looks for the 'pytree' subdirectory inside 'path'.
    restored = ocp.load(path, abstract_state)

    # Method B: Infer the structure from the checkpoint (not recommended for
    # production use cases or for complex trees).
    restored_inferred = ocp.load(path)
- Parameters:
path (UnionType[Path,str]) – The path to load the checkpoint from. This path must contain a subdirectory with the name provided by checkpointable_name. See checkpointable_name for more details.
abstract_state (Union[PyTreeOf[UnionType[AbstractArray,AbstractShardedArray,int,float,number,bytes,bool,str]],CheckpointMetadata[PyTreeOf[UnionType[AbstractArray,AbstractShardedArray,int,float,number,bytes,bool,str]]],None]) – Provides a tree structure for the checkpoint to be restored into. May be omitted to load exactly as saved, but this is much more brittle than providing the tree.
checkpointable_name (UnionType[str,None]) – The name of the checkpointable to load. A subdirectory with this name must exist in path. If None, then path itself is expected to contain all files relevant for loading the PyTree, rather than any subdirectory. Such files include, for example, manifest.ocdbt, _METADATA, ocp.process_X. Defaults to AUTO. AUTO mode dynamically discovers and resolves a pytree checkpointable. It prioritizes the standard ‘pytree’ checkpointable name if present, then sorts any other valid pytree checkpointable names alphabetically and returns the first valid one, and ultimately falls back to interpreting the path as a flat V0 root layout if no standard pytree exists.
- Return type:
PyTreeOf[UnionType[Array,ndarray,int,float,number,bytes,bool,str]]
- Returns:
The restored PyTree.
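The AUTO resolution order described for checkpointable_name can be sketched in plain Python. This is only an illustration and not Orbax's implementation: the helper name resolve_auto_checkpointable is hypothetical, and treating a subdirectory as a valid PyTree checkpointable when it contains a _METADATA file is a simplifying assumption.

```python
import tempfile
from pathlib import Path
from typing import Optional


def resolve_auto_checkpointable(path: Path) -> Optional[str]:
    """Hypothetical sketch of the AUTO lookup order.

    Returns the chosen subdirectory name, or None to signal that `path`
    itself should be read as a flat root layout.
    """
    # Assumption: a subdirectory counts as a PyTree checkpointable if it
    # holds a _METADATA file. The real validity check may differ.
    candidates = sorted(
        child.name
        for child in path.iterdir()
        if child.is_dir() and (child / '_METADATA').exists()
    )
    if 'pytree' in candidates:
        return 'pytree'  # 1. The standard name wins when present.
    if candidates:
        return candidates[0]  # 2. Otherwise, the first name alphabetically.
    return None  # 3. Fall back to the root layout.


def demo_resolution() -> list:
    """Builds throwaway directory layouts and resolves each one."""
    results = []
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        results.append(resolve_auto_checkpointable(root))  # Empty root.
        for name in ('state', 'model'):
            (root / name).mkdir()
            (root / name / '_METADATA').touch()
        results.append(resolve_auto_checkpointable(root))
        (root / 'pytree').mkdir()
        (root / 'pytree' / '_METADATA').touch()
        results.append(resolve_auto_checkpointable(root))
    return results


print(demo_resolution())  # [None, 'model', 'pytree']
```

Passing an explicit checkpointable_name avoids depending on this discovery order when a checkpoint holds several PyTree checkpointables.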
- orbax.checkpoint.experimental.v1.load_checkpointables(path, abstract_checkpointables=None)
Loads checkpointables.
See documentation for save_checkpointables() for more context on what a checkpointable is.
This function can be used to load any checkpoint saved by save_checkpointables() (or save()). The path should contain a number of subdirectories - each of these represents the name of a checkpointable.
This function must be called on all available controller processes.
The operation blocks until complete. For improved performance, consider using load_checkpointables_async() instead.
If abstract_checkpointables is not provided, the checkpointables will be loaded exactly as saved.
IMPORTANT: Loading is more brittle and error-prone when not providing abstract_checkpointables. Always provide abstract_checkpointables if possible. Note that you can always obtain the information about the checkpointables using checkpointables_metadata().
If abstract_checkpointables is provided, the value provided for each key is treated as the abstract type for the given checkpointable. For example, for a PyTree of jax.Array, the corresponding abstract checkpointable is a PyTree of jax.ShapeDtypeStruct. None is always a valid abstract checkpointable, which just indicates that the checkpointable should be loaded exactly as saved.
The keys provided in abstract_checkpointables may be any subset of the checkpointables in the checkpoint. Any checkpointable names not provided in abstract_checkpointables will not be loaded.
Example Usage:
Load checkpointables from a saved checkpoint:
    import grain
    import jax
    import jax.numpy as jnp
    import orbax.checkpoint.experimental.v1 as ocp

    path = '/tmp/my_checkpoint_step_100'

    # Save multiple components (checkpointables)
    params = {'w': jnp.ones((8, 8)), 'b': jnp.zeros(8)}
    opt_state = {'count': jnp.array(100)}

    # Setup Grain (Stateful Checkpointable)
    dataset_iter = iter(
        grain.MapDataset.range(30)
        .batch(3)
        .map(lambda x: x.tolist())
    )

    ocp.save_checkpointables(path, {
        'model': params,
        'optimizer': opt_state,
        'dataset': dataset_iter,
    })

    # Load the checkpointables
    abstract_params = jax.eval_shape(lambda: params)
    abstract_opt = jax.eval_shape(lambda: opt_state)
    abstract_checkpointables = {
        'model': abstract_params,
        'optimizer': abstract_opt,
        # Dataset is restored statefully. An initialized object must be
        # passed, but its position will be set to the position recorded in the
        # checkpoint after restoring.
        'dataset': dataset_iter,
    }

    # Load all components
    restored = ocp.load_checkpointables(path, abstract_checkpointables)

    # Load only a subset
    restored_subset = ocp.load_checkpointables(
        path, {'model': abstract_params}
    )
- Parameters:
path (UnionType[Path,str]) – The path to load the checkpoint from. This path must contain a subdirectory for each checkpointable.
abstract_checkpointables (Union[dict[str,AbstractCheckpointable],CheckpointMetadata[dict[str,AbstractCheckpointable]],None]) – A dictionary of abstract checkpointables. Dictionary keys represent the names of the checkpointables, while the values are the abstract checkpointable objects themselves.
- Return type:
dict[str,Checkpointable]
- Returns:
A dictionary of checkpointables. Dictionary keys represent the names of the checkpointables, while the values are the checkpointable objects themselves.
- Raises:
FileNotFoundError – If the checkpoint path does not exist.
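Because each checkpointable lives in its own subdirectory, the names a checkpoint offers can be inspected before requesting a subset. A minimal sketch using only the standard library follows. The helper names list_checkpointable_names and check_requested_names are hypothetical, the sketch assumes every subdirectory corresponds to a checkpointable, and raising KeyError for an unknown name is an assumption made for illustration rather than documented Orbax behavior.

```python
import tempfile
from pathlib import Path


def list_checkpointable_names(path: Path) -> list:
    """Returns the sorted names of the subdirectories under `path`."""
    if not path.exists():
        # Mirrors the documented FileNotFoundError for a missing path.
        raise FileNotFoundError(f'Checkpoint path does not exist: {path}')
    return sorted(child.name for child in path.iterdir() if child.is_dir())


def check_requested_names(path: Path, requested: set) -> dict:
    """Splits the saved names into those to load and those to skip."""
    available = set(list_checkpointable_names(path))
    unknown = requested - available
    if unknown:
        # Assumption: fail early on names the checkpoint does not contain.
        raise KeyError(f'Not present in checkpoint: {sorted(unknown)}')
    return {
        'load': sorted(requested),
        'skip': sorted(available - requested),
    }


def demo_subset() -> dict:
    """Creates a throwaway layout with three checkpointables."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for name in ('model', 'optimizer', 'dataset'):
            (root / name).mkdir()
        return check_requested_names(root, {'model'})


print(demo_subset())  # {'load': ['model'], 'skip': ['dataset', 'optimizer']}
```

The 'load' names correspond to the keys that would be passed in abstract_checkpointables; the 'skip' names are the ones that would not be loaded.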
- orbax.checkpoint.experimental.v1.load_async(path, abstract_state=None, *, checkpointable_name='state')
Loads a PyTree asynchronously. Currently has limited support.
- Return type:
AsyncResponse[PyTreeOf[UnionType[Array,ndarray,int,float,number,bytes,bool,str]]]
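The difference between the blocking load() and load_async() is the point at which the caller waits. The sketch below illustrates that pattern with concurrent.futures from the standard library as a stand-in; it does not call Orbax. fake_load is a hypothetical placeholder for the I/O, and the assumption that AsyncResponse exposes a blocking result() method, much like Future.result(), should be checked against the AsyncResponse documentation.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor


def fake_load(path: str) -> dict:
    """Hypothetical placeholder for slow checkpoint I/O."""
    time.sleep(0.05)
    return {'path': path, 'a': [0, 1, 2]}


def load_in_background(executor: ThreadPoolExecutor, path: str) -> Future:
    """Starts the load and returns immediately with a handle."""
    return executor.submit(fake_load, path)


def demo_overlap() -> dict:
    """Overlaps unrelated work with an in-flight load."""
    with ThreadPoolExecutor(max_workers=1) as executor:
        response = load_in_background(executor, '/tmp/my_checkpoint')
        # Other setup work can run here while the load is in flight.
        other_work = sum(range(1000))
        # Block only at the point where the restored value is needed.
        restored = response.result()
    return {'other_work': other_work, 'restored': restored}


print(demo_overlap()['restored']['a'])  # [0, 1, 2]
```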
- orbax.checkpoint.experimental.v1.load_checkpointables_async(path, abstract_checkpointables=None)
Loads checkpointables asynchronously. Not yet implemented.
- Return type:
AsyncResponse[dict[str,Checkpointable]]