CheckpointHandlers
Defines exported CheckpointHandlers and their arguments.
CheckpointHandler subclasses define logic used to save and restore
an object to and from a checkpoint. Each CheckpointHandler has
corresponding SaveArgs and RestoreArgs classes that
define the arguments used to call the handler.
Prefer to use ocp.args to reference these objects.
CheckpointHandler
- class orbax.checkpoint.CheckpointHandler
An interface providing save/restore methods used on a savable item.
Item may be a PyTree, Dataset, or any other supported object.
NOTE: Users should avoid using CheckpointHandler independently. Use Checkpointer or CheckpointManager instead.
- abstractmethod save(directory, *args, **kwargs)
Saves the provided item synchronously.
- Parameters:
directory (Path) – the directory to save to.
*args – additional arguments for save.
**kwargs – additional arguments for save.
- abstractmethod restore(directory, *args, **kwargs)
Restores the provided item synchronously.
- Parameters:
directory (Path) – the directory to restore from.
*args – additional arguments for restore.
**kwargs – additional arguments for restore.
- Return type:
Any
- Returns:
The restored item.
- metadata(directory)
Returns metadata about the saved item.
Ideally, this is a cheap way to collect information about the checkpoint without requiring a full restoration.
- Parameters:
directory (Path) – the directory where the checkpoint is located.
- Return type:
Optional[Any]
- Returns:
item metadata
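The three-method shape described above can be illustrated without Orbax at all. The following is a minimal standard-library sketch, not the library's implementation: `ToyHandler`, `IntListHandler`, and `roundtrip` are hypothetical names, and the sidecar file used for `metadata` is an assumption chosen to show how metadata can be read without a full restore.

```python
import abc
import json
import tempfile
from pathlib import Path
from typing import Any, List, Optional


class ToyHandler(abc.ABC):
    """Hypothetical stand-in with the same three-method shape (not an Orbax class)."""

    @abc.abstractmethod
    def save(self, directory: Path, *args, **kwargs) -> None:
        """Write the item under `directory`."""

    @abc.abstractmethod
    def restore(self, directory: Path, *args, **kwargs) -> Any:
        """Read the item back from `directory`."""

    def metadata(self, directory: Path) -> Optional[Any]:
        """Cheap description of the saved item; None if unavailable."""
        return None


class IntListHandler(ToyHandler):
    """Saves a list of ints plus a small sidecar file describing it."""

    def save(self, directory: Path, item: List[int]) -> None:
        directory.mkdir(parents=True, exist_ok=True)
        (directory / "data.txt").write_text("\n".join(str(v) for v in item))
        (directory / "meta.json").write_text(json.dumps({"length": len(item)}))

    def restore(self, directory: Path) -> List[int]:
        text = (directory / "data.txt").read_text()
        return [int(line) for line in text.splitlines()]

    def metadata(self, directory: Path) -> Optional[Any]:
        # Reads only the sidecar, never the (potentially large) data file.
        return json.loads((directory / "meta.json").read_text())


def roundtrip(values: List[int]):
    """Saves `values` to a temporary directory, then reads metadata and data."""
    with tempfile.TemporaryDirectory() as tmp:
        ckpt_dir = Path(tmp) / "ckpt"
        handler = IntListHandler()
        handler.save(ckpt_dir, values)
        return handler.metadata(ckpt_dir), handler.restore(ckpt_dir)
```

Calling `roundtrip([3, 1, 4])` returns the sidecar metadata together with the restored list. In real use, the NOTE above still applies: a handler is meant to be driven by a Checkpointer or CheckpointManager rather than called directly.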
AsyncCheckpointHandler
- class orbax.checkpoint.AsyncCheckpointHandler
An interface providing async methods used with AsyncCheckpointer.
- abstractmethod async async_save(directory, *args, **kwargs)
Saves the given item to the provided directory.
- Parameters:
directory (Path) – the directory to save to.
*args – additional arguments for save.
**kwargs – additional arguments for save.
- Return type:
Optional[List[Future]]
- Returns:
A list of commit futures which can be awaited to complete the save operation.
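The "return commit futures" pattern can be sketched with the standard library alone. This is an illustration under stated assumptions, not Orbax code: `ToyAsyncHandler` and `save_and_wait` are hypothetical names, and `concurrent.futures.Future` stands in for whatever future type the library actually returns.

```python
import asyncio
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import List

# Background pool that performs the slow "commit" step (assumption for this sketch).
_EXECUTOR = ThreadPoolExecutor(max_workers=2)


class ToyAsyncHandler:
    """Hypothetical handler: snapshots data up front, commits in the background."""

    async def async_save(self, directory: Path, item: bytes) -> List[Future]:
        snapshot = bytes(item)  # copy now, so later caller-side changes do not matter
        directory.mkdir(parents=True, exist_ok=True)

        def _commit() -> None:
            (directory / "data.bin").write_bytes(snapshot)

        # The caller decides when to block on the returned futures.
        return [_EXECUTOR.submit(_commit)]


def save_and_wait(directory: Path, item: bytes) -> bytes:
    """Starts the save, waits for every commit future, then reads the file back."""
    futures = asyncio.run(ToyAsyncHandler().async_save(directory, item))
    for future in futures:
        future.result()  # blocks until the background write has finished
    return (directory / "data.bin").read_bytes()
```

The split matters because the save is only complete once every returned future has finished; until then the directory may hold partial data.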
StandardCheckpointHandler
- class orbax.checkpoint.StandardCheckpointHandler(*, save_concurrent_gb=96, restore_concurrent_gb=96, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), use_ocdbt=True)
A CheckpointHandler implementation for any PyTree structure.
See JAX documentation for more information on what constitutes a “PyTree”. This handler is capable of saving and restoring PyTrees with leaves of type Python scalar, np.ndarray, and jax.Array.
As with all CheckpointHandler subclasses, StandardCheckpointHandler should only be used in conjunction with a Checkpointer (or subclass). By itself, the CheckpointHandler is non-atomic.
Example:
    ckptr = Checkpointer(StandardCheckpointHandler())
    # OR
    ckptr = StandardCheckpointer()
If you find that your use case is not covered by StandardCheckpointHandler, consider using the parent class directly, or explore a custom implementation of CheckpointHandler.
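To make the "non-atomic" remark concrete, here is one common way a wrapping layer can make a save appear all-or-nothing: write into a temporary directory and rename it only on success. This is a generic technique sketched with the standard library, offered as an assumption about the general idea rather than a description of Orbax internals; `atomic_save` and `demo` are hypothetical names.

```python
import os
import tempfile
from pathlib import Path
from typing import Callable


def atomic_save(final_dir: Path, write_fn: Callable[[Path], None]) -> None:
    """Runs `write_fn` in a temporary sibling directory, then renames it."""
    tmp_dir = final_dir.with_name(final_dir.name + ".tmp")
    tmp_dir.mkdir(parents=True)
    write_fn(tmp_dir)              # if this raises, final_dir is never created
    os.rename(tmp_dir, final_dir)  # single step that publishes the checkpoint


def demo() -> tuple:
    """Returns whether a successful and a failed save produced a final directory."""

    def good(d: Path) -> None:
        (d / "data.txt").write_text("ok")

    def bad(d: Path) -> None:
        (d / "data.txt").write_text("partial")
        raise RuntimeError("simulated crash mid-save")

    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        atomic_save(root / "step_1", good)
        try:
            atomic_save(root / "step_2", bad)
        except RuntimeError:
            pass  # the partial write stays in step_2.tmp only
        return (root / "step_1").exists(), (root / "step_2").exists()
```

A handler called on its own has no such wrapper, so a crash mid-save can leave a directory that looks like a checkpoint but is incomplete.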
- __init__(*, save_concurrent_gb=96, restore_concurrent_gb=96, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), use_ocdbt=True)
Creates StandardCheckpointHandler.
- Parameters:
save_concurrent_gb (int) – max concurrent GB that are allowed to be written to disk at any given time. This limits the amount of data currently being written to disk, which can help to reduce the possibility of OOMs when large checkpoints are saved. Note that this does NOT limit device-to-host transfer, meaning that the limit specified here may still be exceeded by the total memory usage of the process.
restore_concurrent_gb (int) – max concurrent GB that are allowed to be restored. Can help to reduce the possibility of OOMs when large checkpoints are restored.
multiprocessing_options (MultiprocessingOptions) – See orbax.checkpoint.options.
pytree_metadata_options (PyTreeMetadataOptions) – Options to control types like tuple and namedtuple in pytree metadata.
use_ocdbt (bool) – Whether to enable Tensorstore OCDBT driver.
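The idea behind a concurrent-bytes limit such as save_concurrent_gb can be sketched as a shared byte budget that each write must reserve before it starts. The code below is a standard-library illustration only: `ByteBudget` and `write_all` are hypothetical names, sizes are plain integers rather than gigabytes, and the rule that an oversized request takes the whole budget is an assumption of this sketch.

```python
import asyncio
from typing import List


class ByteBudget:
    """Hypothetical limiter: at most `max_bytes` may be in flight at once."""

    def __init__(self, max_bytes: int) -> None:
        self._max = max_bytes
        self._used = 0
        self._cond = asyncio.Condition()
        self.peak = 0  # highest number of bytes ever in flight

    async def acquire(self, n: int) -> int:
        n = min(n, self._max)  # assumption: an oversized item takes the whole budget
        async with self._cond:
            await self._cond.wait_for(lambda: self._used + n <= self._max)
            self._used += n
            self.peak = max(self.peak, self._used)
        return n

    async def release(self, n: int) -> None:
        async with self._cond:
            self._used -= n
            self._cond.notify_all()  # wake writers waiting for free budget


async def write_all(sizes: List[int], max_bytes: int) -> int:
    """Simulates writing items of the given sizes; returns the peak bytes in flight."""
    budget = ByteBudget(max_bytes)

    async def write_one(size: int) -> None:
        held = await budget.acquire(size)
        try:
            await asyncio.sleep(0)  # stand-in for the actual disk write
        finally:
            await budget.release(held)

    await asyncio.gather(*(write_one(s) for s in sizes))
    return budget.peak
```

As the parameter description notes, a budget like this bounds only the bytes being written; memory consumed by other stages, such as device-to-host transfer, is outside its control.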
- async async_save(directory, item=None, save_args=None, args=None)
Saves a PyTree of array-like objects.
- Parameters:
directory (Path | PathAwaitingCreation) – path to the directory where the checkpoint will be saved.
item (Optional[Any]) – Deprecated, use args.
save_args (Optional[Any]) – Deprecated, use args.
args (Optional[StandardSaveArgs]) – StandardSaveArgs (see below).
- Return type:
Optional[List[Future]]
- Returns:
A list of futures that will be completed when the save is complete.
- restore(directory, item=None, args=None)
Restores a PyTree. See PyTreeCheckpointHandler.
Example:
    ckptr = StandardCheckpointer()
    item = {
        'layer0': {
            'w': jax.Array(...),
            'b': np.ndarray(...),
        },
    }
    ckptr.save(dir, StandardSaveArgs(item))

    target = {
        'layer0': {
            'w': jax.ShapeDtypeStruct(...),
            'b': jax.Array(...),
        },
    }
    ckptr.restore(dir, StandardRestoreArgs(target))
- Parameters:
directory (Path) – path from which to restore.
item (Optional[Any]) – Deprecated, use args.
args (Optional[StandardRestoreArgs]) – StandardRestoreArgs (see below).
- Return type:
Any
- Returns:
a restored PyTree.
- metadata(directory)
Returns metadata about the saved item.
- Return type:
TreeMetadata
- class orbax.checkpoint.handlers.StandardSaveArgs(item, save_args=None, custom_metadata=None)
Parameters for saving a standard PyTree.
Also see PyTreeSave for additional options.
- item
a PyTree to be saved.
- Type:
required
- save_args
a PyTree with the same structure as item, which consists of ocp.SaveArgs objects as values. None can be used for values where no SaveArgs are specified.
- Type:
Optional[PyTree]
- custom_metadata
User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.
- Type:
tree_types.JsonType | None
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(item, save_args=None, custom_metadata=None)
- class orbax.checkpoint.handlers.StandardRestoreArgs(item=None, strict=True, support_layout=False, fallback_sharding=None)
Parameters for restoring a standard PyTree.
Also see PyTreeRestore for additional options.
- Attributes (all optional):
- item: target PyTree. Currently non-optional. Values may be either real array or scalar values, or they may be jax.ShapeDtypeStruct, or ocp.metadata.value.Metadata objects (which come from calling the metadata method). If real values are provided, that value will be restored as the given type, with the given properties. If jax.ShapeDtypeStruct is provided, the value will be restored as np.ndarray, unless sharding is specified. If item is a custom PyTree class, the tree will be restored with the same structure as provided. If not provided, restores as a serialized nested dict representation of the custom class. TreeMetadata is also allowed as the tree used to define the restored structure.
- strict: if False, restoration allows silent truncating/padding of arrays if the stored array shape does not match the target shape. Otherwise, raises an error.
- support_layout: if True, restores with the layouts in item.
- fallback_sharding: if provided, this sharding will be used as a fallback if the saved sharding fails to load from the checkpoint.
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(item=None, strict=True, support_layout=False, fallback_sharding=None)
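The difference between strict=True and strict=False can be pictured on a one-dimensional list. This is a simplified standard-library sketch, not the library's logic: `fit_to_length` is a hypothetical helper, and padding with zeros is an assumption made for illustration.

```python
from typing import List


def fit_to_length(stored: List[float], target_len: int, strict: bool = True,
                  fill: float = 0.0) -> List[float]:
    """Returns `stored` adjusted to `target_len`, or raises when strict."""
    if len(stored) == target_len:
        return list(stored)
    if strict:
        raise ValueError(
            f"stored length {len(stored)} does not match target length {target_len}"
        )
    if len(stored) > target_len:
        return list(stored[:target_len])  # silent truncation
    # Silent padding; the fill value is an assumption of this sketch.
    return list(stored) + [fill] * (target_len - len(stored))
```

With strict=False, a shape mismatch goes unreported, which is why the default raises an error instead.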
PyTreeCheckpointHandler
- class orbax.checkpoint.PyTreeCheckpointHandler(aggregate_filename=None, *, save_concurrent_gb=None, restore_concurrent_gb=None, save_device_host_concurrent_gb=None, memory_limit_options=None, use_ocdbt=True, use_zarr3=False, use_compression=True, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), type_handler_registry=<orbax.checkpoint._src.serialization.type_handler_registry._TypeHandlerRegistryImpl object>, handler_impl=None, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_validator=<orbax.checkpoint._src.metadata.array_metadata_store.Validator object>, enable_pinned_host_transfer=None, is_prioritized_key_fn=None)
A CheckpointHandler implementation for any PyTree structure.
See JAX documentation for more information on what constitutes a “PyTree”. This handler is capable of saving and restoring any leaf object for which a TypeHandler (see documentation) is registered. By default, TypeHandlers for standard types like np.ndarray, jax.Array, Python scalars, and others are registered.
As with all CheckpointHandler subclasses, PyTreeCheckpointHandler should only be used in conjunction with a Checkpointer (or subclass). By itself, the CheckpointHandler is non-atomic.
Example:
    ckptr = Checkpointer(PyTreeCheckpointHandler())
- __init__(aggregate_filename=None, *, save_concurrent_gb=None, restore_concurrent_gb=None, save_device_host_concurrent_gb=None, memory_limit_options=None, use_ocdbt=True, use_zarr3=False, use_compression=True, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), type_handler_registry=<orbax.checkpoint._src.serialization.type_handler_registry._TypeHandlerRegistryImpl object>, handler_impl=None, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_validator=<orbax.checkpoint._src.metadata.array_metadata_store.Validator object>, enable_pinned_host_transfer=None, is_prioritized_key_fn=None)
Creates PyTreeCheckpointHandler.
- Parameters:
aggregate_filename (Optional[str]) – name that the aggregated checkpoint should be saved as.
save_concurrent_gb (Optional[int]) – max concurrent GB that are allowed to be written to disk at any given time. This limits the amount of data currently being written to disk, which can help to reduce the possibility of OOMs when large checkpoints are saved. Note that this does NOT limit device-to-host transfer, meaning that the limit specified here may still be exceeded by the total memory usage of the process.
restore_concurrent_gb (Optional[int]) – max concurrent GB that are allowed to be restored. Can help to reduce the possibility of OOMs when large checkpoints are restored.
save_device_host_concurrent_gb (int | str | None) – max concurrent GB allowed to be transferred from device to host memory at once when saving, defined on a per-worker basis. When the limit is reached, arrays must be finished writing to the checkpoint before a new array can start being transferred. This option is a stronger version of save_concurrent_gb. Unlike save_concurrent_gb, which only limits the amount of data currently being written to disk, this option limits the amount of data transferred from device to host. Note that asynchronous saves may not be truly asynchronous with this option enabled, as we have to block on some array writes before beginning others. Also see is_prioritized_key_fn. Can be set to “auto” to enable Memory Regulator.
memory_limit_options (MemoryLimitOptions | None) – Memory limit options for the checkpoint handler.
use_ocdbt (bool) – enables Tensorstore OCDBT driver. This option allows using a different checkpoint format which is faster to read and write, as well as more space efficient.
use_zarr3 (bool) – If True, use Zarr ver3; otherwise, use Zarr ver2.
use_compression (bool) – If True and zarr2 is used, use zstd compression.
multiprocessing_options (MultiprocessingOptions) – See orbax.checkpoint.options.
type_handler_registry (TypeHandlerRegistry) – a type_handlers.TypeHandlerRegistry. If not specified, the global type handler registry will be used.
handler_impl (Optional[BasePyTreeCheckpointHandler]) – Allows overriding the internal implementation.
pytree_metadata_options (PyTreeMetadataOptions) – PyTreeMetadataOptions to manage metadata.
array_metadata_validator (Validator) – Validator for ArrayMetadata.
enable_pinned_host_transfer (Optional[bool]) – Whether to use pinned_host memory for the transfer from device to host memory. Passing None will enable pinned_host memory depending on the platform used (currently only enables it for the GPU backend).
is_prioritized_key_fn (Optional[IsPrioritizedKeyFn]) – A function that accepts a PyTree keypath (obtained using jax.tree.map_with_path) and indicates whether that key should be scheduled for D2H transfer before other keys. The transfer of prioritized keys is scheduled before returning to the caller, so the values will never be corrupted by a concurrent update. Keys that are not prioritized will not be scheduled for transfer until all prioritized keys have been fully written to the checkpoint. This means that these values may be altered if the values are updated concurrently. Callers should take care to call wait_until_finished before updating array values (e.g. apply_gradients) if some keys are not prioritized. Note that any “prioritized” keys are assumed to be lightweight, and save_device_host_concurrent_gb will be ignored for them.
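The effect of is_prioritized_key_fn can be pictured as splitting the tree's leaves into two groups: those transferred first and those deferred. The sketch below uses only the standard library and makes several assumptions: keypaths are plain tuples of strings rather than JAX key objects, the tree is a nested dict, and `flatten_with_path` and `split_by_priority` are hypothetical helpers.

```python
from typing import Any, Callable, List, Tuple

KeyPath = Tuple[str, ...]


def flatten_with_path(tree: Any, prefix: KeyPath = ()) -> List[Tuple[KeyPath, Any]]:
    """Lists (keypath, leaf) pairs of a nested dict in insertion order."""
    if isinstance(tree, dict):
        leaves: List[Tuple[KeyPath, Any]] = []
        for key, value in tree.items():
            leaves.extend(flatten_with_path(value, prefix + (key,)))
        return leaves
    return [(prefix, tree)]


def split_by_priority(
    tree: Any, is_prioritized_key_fn: Callable[[KeyPath], bool]
) -> Tuple[List[KeyPath], List[KeyPath]]:
    """Returns (prioritized keypaths, deferred keypaths)."""
    prioritized: List[KeyPath] = []
    deferred: List[KeyPath] = []
    for path, _ in flatten_with_path(tree):
        (prioritized if is_prioritized_key_fn(path) else deferred).append(path)
    return prioritized, deferred
```

For example, a predicate such as `lambda path: path[0] == "step"` would put a small step counter in the first group and leave large parameter arrays in the deferred group, which is the group the caller must not mutate until the save has finished.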
- async async_save(directory, item=None, save_args=None, args=None)
Saves a PyTree to a given directory.
This operation is compatible with a multi-host, multi-device setting. Tree leaf values must be supported by the type_handler_registry given in the constructor. Standard supported types include Python scalars, np.ndarray, jax.Array, and strings.
After saving, all files will be located in “directory/”. The exact files that are saved depend on the specific combination of options, including use_ocdbt. A JSON metadata file will be present to store the tree structure.
Example usage:
    ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
    item = {
        'layer0': {
            'w': np.ndarray(...),
            'b': np.ndarray(...),
        },
        'layer1': {
            'w': np.ndarray(...),
            'b': np.ndarray(...),
        },
    }
    # Note: save_args may be None if no customization is desired for saved
    # parameters.
    # Otherwise, settings can be used to customize save behavior, e.g.
    # casting.
    save_args = jax.tree.map(lambda x: ocp.SaveArgs(dtype=np.int32), item)
    # Eventually calls through to `async_save`.
    ckptr.save(path, args=ocp.PyTreeSave(item, save_args))
- Parameters:
directory (Path | PathAwaitingCreation) – save location directory.
item (Optional[Any]) – Deprecated, use args.
save_args (Optional[PyTreeSaveArgs]) – Deprecated, use args.
args (Optional[PyTreeSaveArgs]) – PyTreeSaveArgs (see below).
- Return type:
Optional[List[Future]]
- Returns:
A Future that will commit the data to directory when awaited. Copying the data from its source will be awaited in this function.
- save(directory, item=None, save_args=None, args=None)
Saves the provided item. See async_save.
- restore(directory, item=None, restore_args=None, transforms=None, transforms_default_to_original=True, legacy_transform_fn=None, args=None)
Restores a PyTree from the checkpoint directory at the given path.
In the most basic case, only directory is required. The tree will be restored exactly as saved, and all leaves will be restored as the correct types (assuming the tree metadata is present).
However, restore_args is often required as well. This PyTree gives a RestoreArgs object (or subclass) for every leaf in the tree. Many types, such as string or np.ndarray do not require any special options for restoration. When restoring an individual leaf as jax.Array, however, some properties may be required.
One example is sharding, which defines how a jax.Array in the restored tree should be partitioned. mesh and mesh_axes can also be used to specify sharding, but sharding is the preferred way of specifying this partition, since mesh and mesh_axes can only construct a jax.sharding.NamedSharding. For more information, see ArrayTypeHandler documentation and JAX sharding documentation.
Example:
    ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
    restore_args = {
        'layer0': {
            'w': ocp.RestoreArgs(),
            'b': ocp.RestoreArgs(),
        },
        'layer1': {
            'w': ocp.ArrayRestoreArgs(
                # Restores as jax.Array, regardless of how it was saved.
                restore_type=jax.Array,
                sharding=jax.sharding.Sharding(...),
                # Warning: may truncate or pad!
                global_shape=(x, y),
            ),
            'b': ocp.ArrayRestoreArgs(
                restore_type=jax.Array,
                sharding=jax.sharding.Sharding(...),
                global_shape=(x, y),
            ),
        },
    }
    ckptr.restore(path, args=ocp.PyTreeRestore(restore_args=restore_args))
Providing item is typically only necessary when restoring a custom PyTree class (or when using transformations). In this case, the restored object will take on the same structure as item.
Example:
    @flax.struct.dataclass
    class TrainState:
        layer0: dict[str, jax.Array]
        layer1: dict[str, jax.Array]

    ckptr = Checkpointer(PyTreeCheckpointHandler())
    train_state = TrainState(
        layer0={
            'w': jax.Array(...),  # zeros
            'b': jax.Array(...),  # zeros
        },
        layer1={
            'w': jax.Array(...),  # zeros
            'b': jax.Array(...),  # zeros
        },
    )
    restore_args = jax.tree.map(_make_restore_args, train_state)
    ckptr.restore(path, item=train_state, restore_args=restore_args)
    # restored tree is of type `TrainState`.
- Parameters:
directory (Path) – saved checkpoint location directory.
item (Optional[Any]) – Deprecated, use args.
restore_args (Optional[Any]) – Deprecated, use args.
transforms (Optional[Any]) – Deprecated, use args.
transforms_default_to_original (bool) – See transform_utils.apply_transformations.
legacy_transform_fn (Optional[Callable[[Any, Any, Any], Tuple[Any, Any]]]) – Deprecated, use args.
args (Optional[PyTreeRestoreArgs]) – PyTreeRestoreArgs (see below).
- Return type:
Any
- Returns:
A PyTree matching the structure of item.
- Raises:
FileNotFoundError – directory does not exist or is missing required files
ValueError – transforms is provided without item.
ValueError – transforms contains elements with multi_value_fn.
- metadata(directory)
Returns tree metadata.
The result will be a PyTree matching the structure of the saved checkpoint. Note that if the item saved was a custom class, the restored metadata will be returned as a nested dictionary representation.
Example:
    {
        'layer0': {
            'w': ArrayMetadata(dtype=jnp.float32, shape=(8, 8), shards=(1, 2)),
            'b': ArrayMetadata(dtype=jnp.float32, shape=(8,), shards=(1,)),
        },
        'step': ScalarMetadata(dtype=jnp.int64),
    }
If the required metadata file is not present, this method will raise an error.
- Parameters:
directory (Path) – checkpoint location.
- Return type:
TreeMetadata
- Returns:
tree containing metadata.
- finalize(directory)
Finalization step.
Called automatically by the Checkpointer/AsyncCheckpointer just before the checkpoint is considered “finalized” in the sense of ensuring atomicity. See documentation for type_handlers.merge_ocdbt_per_process_files.
- Parameters:
directory (Path) – Path where the checkpoint is located.
- Return type:
None
- class orbax.checkpoint.handlers.PyTreeSaveArgs(item, save_args=None, ocdbt_target_data_file_size=None, custom_metadata=None)
Parameters for saving a PyTree.
- item
a PyTree to be saved.
- Type:
required
- save_args
a PyTree with the same structure as item, which consists of ocp.SaveArgs objects as values. None can be used for values where no SaveArgs are specified.
- Type:
Optional[PyTree]
- ocdbt_target_data_file_size
Specifies the target size (in bytes) of each OCDBT data file. It applies only when OCDBT is enabled, and Zarr3 must also be turned on. If left unspecified, the default size is 2GB. A value of 0 indicates no maximum file size limit. For best results, ensure chunk_byte_size is smaller than this value. For more details, refer to https://google.github.io/tensorstore/kvstore/ocdbt/index.html#json-kvstore/ocdbt.target_data_file_size
- Type:
Optional[int]
- custom_metadata
User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.
- Type:
tree_types.JsonType | None
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(item, save_args=None, ocdbt_target_data_file_size=None, custom_metadata=None)
- class orbax.checkpoint.handlers.PyTreeRestoreArgs(item=None, restore_args=None, transforms=None, transforms_default_to_original=True, legacy_transform_fn=None, partial_restore=False)
Parameters for restoring a PyTree.
- Attributes (all optional):
- item: provides the tree structure for the restored item. If not provided, will infer the structure from the saved checkpoint. Transformations will not be run in this case. Necessary particularly in the case where the caller needs to restore the tree as a custom object. TreeMetadata is also allowed as the tree used to define the restored structure.
- restore_args: optional object containing additional arguments for restoration. It should be a PyTree matching the structure of item, or if item is not provided, then it should match the structure of the checkpoint. Each value in the tree should be a RestoreArgs object (OR a subclass of RestoreArgs). Importantly, note that when restoring a leaf as a certain type, a specific subclass of RestoreArgs may be required. RestoreArgs also provides the option to customize the restore type of an individual leaf. TreeMetadata is also allowed as the restore_args tree.
- transforms: a PyTree of transformations that should be applied to the saved tree in order to obtain a final structure. The transforms tree structure should conceptually match that of item, but the use of regexes and implicit keys means that it does not need to match completely. See transform_utils for further information. TreeMetadata is also allowed as the transforms tree.
- transforms_default_to_original: See transform_utils.apply_transformations.
- legacy_transform_fn: WARNING: NOT GENERALLY SUPPORTED. A function which accepts the item argument, a PyTree checkpoint structure and a PyTree of ParamInfos based on the checkpoint. Returns a transformed PyTree matching the desired return tree structure, and a matching ParamInfo tree.
- partial_restore: If True, only restore the parameters that are specified in PyTreeRestoreArgs.
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(item=None, restore_args=None, transforms=None, transforms_default_to_original=True, legacy_transform_fn=None, partial_restore=False)
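The intent of partial_restore, restoring only the parameters named in the request, can be sketched on nested dicts. This is a standard-library illustration of one plausible reading, not the library's implementation: `partial_restore_tree` is a hypothetical helper, and raising KeyError for a requested key that was never saved is an assumption of this sketch.

```python
from typing import Any, Dict


def partial_restore_tree(saved: Dict[str, Any], requested: Dict[str, Any]) -> Dict[str, Any]:
    """Returns only the entries of `saved` whose keys appear in `requested`."""
    restored: Dict[str, Any] = {}
    for key, sub_request in requested.items():
        if key not in saved:
            # Assumption: asking for something that was never saved is an error.
            raise KeyError(key)
        if isinstance(sub_request, dict):
            restored[key] = partial_restore_tree(saved[key], sub_request)
        else:
            restored[key] = saved[key]  # leaf: take the saved value
    return restored
```

Here the leaf values in `requested` act only as placeholders; anything present in the checkpoint but absent from the request is simply not loaded.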
CompositeCheckpointHandler
- class orbax.checkpoint.CompositeCheckpointHandler(*item_names, composite_options=CompositeOptions(multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), temporary_path_class=None, file_options=None, async_options=None), handler_registry=None, **items_and_handlers)
CheckpointHandler for saving multiple items.
As with all CheckpointHandler implementations, use only in conjunction with an instance of AbstractCheckpointer.
The CompositeCheckpointHandler allows dealing with multiple items of different types or logical distinctness, such as training state (PyTree), dataset iterator, JSON metadata, or anything else. To specify the handler and args for each checkpointable item, use the handler_registry option at initialization (see CheckpointHandlerRegistry).
If a handler registry is provided, the CompositeCheckpointHandler will use this registry to determine the CheckpointHandler for each checkpointable item. If no registry is provided, an empty registry will be constructed by default internally. Additional items not registered in the handler registry can be specified when calling save or restore. The handler will be determined from the provided CheckpointArgs during the first call to save or restore. Subsequent calls to save or restore will use the same handler and will require the same CheckpointArgs to be provided.
item_names and items_and_handlers are deprecated arguments for specifying the items and handlers. Please use handler_registry instead.
Usage:
    # The simplest use-case, with no handler registry provided on construction.
    checkpointer = ocp.Checkpointer(
        ocp.CompositeCheckpointHandler()
    )

    checkpointer.save(directory,
        ocp.args.Composite(
            # The handler will be determined from the global registry using the
            # provided args.
            state=ocp.args.StandardSave(pytree),
        )
    )

    restored: ocp.args.Composite = checkpointer.restore(directory,
        ocp.args.Composite(
            state=ocp.args.StandardRestore(abstract_pytree),
        )
    )
    restored.state ...

    # The "state" item's args was determined on the first call to `save`.
    # Trying to save "state" with a different set of args will raise an error.
    checkpointer.save(directory,
        ocp.args.Composite(
            # Will raise `ValueError: Item "state" and args "JsonSave [...]" does
            # not match with any registered handler! ...`.
            state=ocp.args.JsonSave(pytree),
        )
    )

    # You can also register specific items and handlers using a handler
    # registry.
    handler_registry = (
        ocp.handler_registration.DefaultCheckpointHandlerRegistry()
    )

    # Some subclass of `CheckpointHandler` that handles the "state" item.
    handler = state_handler()

    # Multiple save args and handlers can be registered for the same item.
    handler_registry.add(
        'state',
        StateSaveArgs,
        handler,
    )
    handler_registry.add(
        'state',
        StateRestoreArgs,
        handler,
    )

    # The `CompositeCheckpointHandler` will use the handlers registered in the
    # registry when it encounters the item name `state`.
    checkpointer_with_registry = ocp.Checkpointer(
        ocp.CompositeCheckpointHandler(handler_registry=handler_registry)
    )

    checkpointer_with_registry.save(directory,
        ocp.args.Composite(
            # The handler knows that the "state" item should use the
            # `state_handler` specified in the registry. When an item has been
            # added to the handler registry, only save args that are registered
            # for that item are allowed.
            state=StateSaveArgs(pytree),
            # You can also provide other arguments that are not registered in
            # the handler registry as long as a handler has been globally
            # registered for the args.
            other_param=ocp.args.StandardSave(other_pytree),
        )
    )

    restored: ocp.args.Composite = checkpointer_with_registry.restore(directory,
        ocp.args.Composite(
            state=StateRestoreArgs(abstract_pytree),
            other_param=ocp.args.StandardRestore(abstract_other_pytree),
        )
    )
    restored.state ...
    restored.other_param ...

    # Skip restoring `other_param` (you can save a subset of items too, in a
    # similar fashion).
    restored_state_only: ocp.args.Composite = (
        checkpointer_with_registry.restore(directory,
            ocp.args.Composite(
                state=StateRestoreArgs(abstract_pytree),
            )
        )
    )
    restored_state_only.state ...
    restored_state_only.other_param ...  # None.
- __init__(*item_names, composite_options=CompositeOptions(multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), temporary_path_class=None, file_options=None, async_options=None), handler_registry=None, **items_and_handlers)
Constructor.
If you are using item_names and/or items_and_handlers, all items must be provided up-front, at initialization. If you are using a handler_registry, you can register items at any time, even after the first call to save or restore.
- Parameters:
*item_names – A list of string item names that this handler will manage. item_names is deprecated. Please use handler_registry instead.
composite_options (CompositeOptions) – Options.
handler_registry (Optional[CheckpointHandlerRegistry]) – A CheckpointHandlerRegistry instance. If provided, the CompositeCheckpointHandler will use this registry to determine the CheckpointHandler for each item. This option is mutually exclusive with items_and_handlers and item_names.
**items_and_handlers – A mapping of item name to CheckpointHandler instance, which will be used as the handler for objects of the corresponding name. items_and_handlers is deprecated. Please use handler_registry instead.
- async async_save(directory, args)
Saves multiple items to individual subdirectories.
- Return type:
Optional[List[Future],None]
- restore(directory, args=None)
Restores the provided item synchronously.
Restoration can happen in a variety of modes, depending on which args is passed:
- args is not None and not empty. Item names present in args will be restored as long as they exist in the checkpoint and have a registered handler.
- args is None or empty. All items in the checkpoint that have a registered handler will be restored.
- Parameters:
directory (Path) – Path to restore from.
args (Optional[CompositeArgs]) – CompositeArgs object used to restore individual sub-items. May be None or “empty” (CompositeArgs()), which is used to indicate that the handler should restore everything. If an individual item was not specified during the save, None will be returned for that item’s entry.
- Return type:
- Returns:
A CompositeResults object with keys matching CompositeArgs, or with keys for all known items as specified at creation.
- Raises:
KeyError – If an item could not be restored.
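The two restore modes can be sketched with plain dictionaries. The code below is one possible reading of the description above, written with the standard library only: `restore_composite` is a hypothetical function, handlers are modelled as simple callables keyed by item name, and the choice to return None for a requested item that was never saved while raising KeyError for a saved item without a handler is an assumption.

```python
from typing import Any, Callable, Dict, Optional


def restore_composite(
    saved: Dict[str, Any],
    handlers: Dict[str, Callable[[Any], Any]],
    args: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Restores items by name, following the two modes described above."""
    if not args:
        # Mode 2: args is None or empty, so restore every item with a handler.
        return {name: handlers[name](raw)
                for name, raw in saved.items() if name in handlers}

    # Mode 1: restore only the requested names.
    results: Dict[str, Any] = {}
    for name in args:
        if name not in saved:
            results[name] = None  # item was not part of the save
        elif name not in handlers:
            raise KeyError(f"no handler registered for item {name!r}")
        else:
            results[name] = handlers[name](saved[name])
    return results
```

In the real handler the values of args carry per-item restore arguments; here they are ignored to keep the sketch focused on which items get restored.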
- metadata_from_temporary_paths(directory)
Metadata for each item in the temporary checkpoint.
- Return type:
StepMetadata
- metadata(directory)
Metadata for each item in the checkpoint.
This has much the same logic as restore, in the sense that it tries to restore the metadata for each item present in the checkpoint. However, for any items that do not have a registered handler, the metadata for that item will simply be returned as None.
- Parameters:
directory (Path) – Path to the checkpoint.
- Return type:
StepMetadata
- Returns:
StepMetadata
- Raises:
FileNotFoundError – If the directory does not exist.
- class orbax.checkpoint.handlers.CompositeArgs(**items)
Args for wrapping multiple checkpoint items together.
- orbax.checkpoint.handlers.CompositeResults
alias of CompositeArgs
JsonCheckpointHandler
- class orbax.checkpoint.JsonCheckpointHandler(filename=None, *, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None))
Saves nested dictionary using json.
- __init__(filename=None, *, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None))
Initializes JsonCheckpointHandler.
- Parameters:
filename (Optional[str]) – optional file name given to the written file; defaults to ‘metadata’.
multiprocessing_options (MultiprocessingOptions) – See orbax.checkpoint.options.
- async async_save(directory, item=None, args=None)
Saves the given item.
- Parameters:
directory (Path) – save location directory.
item (Optional[Mapping[str, Any]]) – Deprecated, use args instead.
args (Optional[JsonSaveArgs]) – JsonSaveArgs (see below).
- Return type:
Optional[List[Future]]
- Returns:
A list of commit futures.
- save(directory, item=None, args=None)
Saves the provided item synchronously.
- Parameters:
directory (Path) – the directory to save to.
*args – additional arguments for save.
**kwargs – additional arguments for save.
- restore(directory, item=None, args=None)
Restores json mapping from directory.
item is unused.
- Parameters:
directory (Path) – restore location directory.
item (Optional[Mapping[str, Any]]) – unused
args (Optional[JsonRestoreArgs]) – unused
- Return type:
Mapping[str, Any]
- Returns:
JSON dict.
- Raises:
FileNotFoundError – if the file does not exist.
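The behaviour described for this handler, a nested dictionary written as JSON under a file name that defaults to ‘metadata’ with FileNotFoundError on a missing file, can be mimicked with the standard library. This is a stand-alone sketch, not the handler itself: `save_json` and `restore_json` are hypothetical functions, and the exact on-disk layout used by Orbax is not assumed.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Mapping


def save_json(directory: Path, item: Mapping[str, Any], filename: str = "metadata") -> None:
    """Writes `item` as JSON to `directory/filename`."""
    directory.mkdir(parents=True, exist_ok=True)
    (directory / filename).write_text(json.dumps(item))


def restore_json(directory: Path, filename: str = "metadata") -> Mapping[str, Any]:
    """Reads the JSON mapping back; raises FileNotFoundError if it is absent."""
    path = directory / filename
    if not path.exists():
        raise FileNotFoundError(path)
    return json.loads(path.read_text())
```

Only JSON-serializable values survive the round trip, so arrays and other rich types belong in one of the PyTree handlers instead.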
ArrayCheckpointHandler
- class orbax.checkpoint.ArrayCheckpointHandler(checkpoint_name=None)
Handles saving and restoring individual arrays and scalars.
- __init__(checkpoint_name=None)
Initializes the handler.
- Parameters:
checkpoint_name (Optional[str]) – Provides a name for the directory under which Tensorstore files will be saved. Defaults to ‘checkpoint’.
- async async_save(directory, item=None, save_args=None, args=None)[source][source]#
Saves an object asynchronously.
- Parameters:
directory (Path) – Folder in which to save.
item (Union[int, float, number, ndarray, Array, None]) – Deprecated, use args.
save_args (Optional[SaveArgs]) – Deprecated, use args.
args (Optional[ArraySaveArgs]) – An ocp.array_checkpoint_handler.ArraySaveArgs (see below).
- Return type:
Optional[List[Future]]
- Returns:
A list of commit futures which can be run to complete the save.
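The commit-future pattern mentioned in the return value can be sketched with the standard library alone. This is an illustration under stated assumptions, not Orbax code: `async_save_sketch` and `save_and_wait` are hypothetical names, and `concurrent.futures.Future` stands in for Orbax's own future type. The point is that the coroutine returns quickly, and the save only counts as finished once the caller has waited on every returned future.

```python
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import List

# Single background worker used for the (hypothetical) slow write.
_executor = ThreadPoolExecutor(max_workers=1)


async def async_save_sketch(directory: Path, value: float) -> List[Future]:
    """Hypothetical stand-in: schedule the write and return commit futures."""
    directory.mkdir(parents=True, exist_ok=True)

    def commit() -> None:
        # The actual file write runs on the background thread.
        (directory / "value.txt").write_text(repr(value))

    return [_executor.submit(commit)]


def save_and_wait(directory: Path, value: float) -> str:
    """Run the async save, then block until every commit future is done."""
    futures = asyncio.run(async_save_sketch(directory, value))
    for future in futures:
        future.result()  # re-raises any exception from the background write
    return (directory / "value.txt").read_text()
```

Reading the file before calling `result()` on each future would race with the background write, which is why the caller waits first.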
- restore(directory, item=None, restore_args=None, args=None)[source][source]#
Restores an object.
- Parameters:
directory (Path) – folder from which to read.
item (Union[int, float, number, ndarray, Array, None]) – Deprecated, use args.
restore_args (Optional[RestoreArgs]) – Deprecated, use args.
args (Optional[ArrayRestoreArgs]) – An ocp.array_checkpoint_handler.ArrayRestoreArgs object (see below).
- Return type:
- Returns:
The restored object.
- class orbax.checkpoint.handlers.ArraySaveArgs(item, save_args=None)[source][source]#
Parameters for saving an array or scalar.
- item#
an array or scalar object.
- Type:
required
- save_args#
an ocp.SaveArgs object specifying save options.
- Type:
- __eq__(other)#
Return self==value.
- __hash__ = None#
- __init__(item, save_args=None)#
- class orbax.checkpoint.handlers.ArrayRestoreArgs(item=None, restore_args=None)[source][source]#
Array restore args.
- item#
unused, but provided as an option for legacy-compatibility reasons.
- Type:
int | float | numpy.number | numpy.ndarray | jax.Array | None
- restore_args#
an ocp.RestoreArgs object specifying restore options.
- __eq__(other)#
Return self==value.
- __hash__ = None#
- __init__(item=None, restore_args=None)#
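The `__eq__` and `__hash__ = None` entries listed for these argument classes match what Python's `dataclass` decorator generates for a mutable dataclass. The sketch below assumes the args classes are ordinary dataclasses; `DemoSaveArgs` is a hypothetical stand-in, not the Orbax class. It shows the practical consequence: instances compare by field value and cannot be used as dictionary keys or set members.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass  # defaults: eq=True, frozen=False
class DemoSaveArgs:
    """Hypothetical stand-in with the same two fields as ArraySaveArgs."""
    item: Any
    save_args: Optional[Any] = None


a = DemoSaveArgs(item=1.0)
b = DemoSaveArgs(item=1.0)

# The generated __eq__ compares the instances field by field.
fields_equal = (a == b)

# Defining __eq__ without __hash__ makes Python set __hash__ to None.
hash_is_disabled = DemoSaveArgs.__hash__ is None

try:
    hash(a)
    hashable = True
except TypeError:
    # Unhashable: cannot be a dict key or a set member.
    hashable = False
```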
ProtoCheckpointHandler#
- class orbax.checkpoint.ProtoCheckpointHandler(filename='proto.pbtxt', *, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None))[source][source]#
Serializes/deserializes protocol buffers.
- __init__(filename='proto.pbtxt', *, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None))[source][source]#
Initializes ProtoCheckpointHandler.
- Parameters:
filename (str) – file name given to the written file.
multiprocessing_options (MultiprocessingOptions) – See orbax.checkpoint.options.
- async async_save(directory, item=None, args=None)[source][source]#
Saves the given proto.
- Parameters:
directory (Path) – save location directory.
item (Optional[Message]) – Deprecated, use args.
args (Optional[ProtoSaveArgs]) – ProtoSaveArgs (see below).
- Return type:
Optional[List[Future]]
- Returns:
A commit future.
- restore(directory, item=None, args=None)[source][source]#
Restores the proto from directory.
- Parameters:
directory (Path) – restore location directory.
item (Optional[Type[Message]]) – Deprecated, use args.
args (Optional[ProtoRestoreArgs]) – ProtoRestoreArgs (see below).
- Returns:
The deserialized proto read from directory if item is not None
JaxRandomKeyCheckpointHandler#
- class orbax.checkpoint.JaxRandomKeyCheckpointHandler(key_name=None)[source][source]#
Handles saving and restoring an individual JAX random key, in either typed or untyped format.
- async async_save(directory, args)[source]#
Saves a random key asynchronously.
- Parameters:
directory (Path) – Folder in which to save.
args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.
- Return type:
Optional[List[Future]]
- Returns:
A list of commit futures which can be run to complete the save.
- restore(directory, args)[source]#
Restores a random key.
- Parameters:
directory (Path) – folder from which to read.
args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.
- Return type:
Any
- Returns:
The restored object.
- class orbax.checkpoint.handlers.JaxRandomKeySaveArgs(item, save_args=None)[source][source]#
Parameters for saving a JAX random key.
- item#
a JAX random key.
- Type:
required
- save_args#
an ocp.SaveArgs object specifying save options.
- Type:
- __eq__(other)#
Return self==value.
- __hash__ = None#
- __init__(item, save_args=None)#
NumpyRandomKeyCheckpointHandler#
- class orbax.checkpoint.NumpyRandomKeyCheckpointHandler(key_name=None)[source][source]#
Saves a NumPy random key in legacy or non-legacy format.
- async async_save(directory, args)[source]#
Saves a random key asynchronously.
- Parameters:
directory (Path) – Folder in which to save.
args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.
- Return type:
Optional[List[Future]]
- Returns:
A list of commit futures which can be run to complete the save.
- restore(directory, args)[source]#
Restores a random key.
- Parameters:
directory (Path) – folder from which to read.
args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.
- Return type:
Any
- Returns:
The restored object.