CheckpointHandlers

Defines exported CheckpointHandlers and their arguments.

CheckpointHandler subclasses define logic used to save and restore an object to and from a checkpoint. Each CheckpointHandler has corresponding SaveArgs and RestoreArgs classes that define the arguments used to call the handler. Prefer to use ocp.args to reference these objects.

CheckpointHandler

class orbax.checkpoint.CheckpointHandler

An interface providing save/restore methods used on a savable item.

Item may be a PyTree, Dataset, or any other supported object.

NOTE: Users should avoid using CheckpointHandler independently. Use Checkpointer or CheckpointManager instead.

abstractmethod save(directory, *args, **kwargs)

Saves the provided item synchronously.

Parameters:
  • directory (Path) – the directory to save to.

  • *args – additional arguments for save.

  • **kwargs – additional arguments for save.

abstractmethod restore(directory, *args, **kwargs)

Restores the provided item synchronously.

Parameters:
  • directory (Path) – the directory to restore from.

  • *args – additional arguments for restore.

  • **kwargs – additional arguments for restore.

Return type:

Any

Returns:

The restored item.

metadata(directory)

Returns metadata about the saved item.

Ideally, this is a cheap way to collect information about the checkpoint without requiring a full restoration.

Parameters:

directory (Path) – the directory where the checkpoint is located.

Return type:

Optional[Any, None]

Returns:

item metadata

finalize(directory)

Optional, custom checkpoint finalization callback.

Will be called on a single worker after all workers have finished writing.

Parameters:

directory (Path) – the directory where the checkpoint is located.

Return type:

None

close()

Closes the CheckpointHandler.

classmethod typestr()

A unique identifier for the CheckpointHandler type.

Return type:

str

AsyncCheckpointHandler

class orbax.checkpoint.AsyncCheckpointHandler

An interface providing async methods used with AsyncCheckpointer.

abstractmethod async async_save(directory, *args, **kwargs)

Saves the given item to the provided directory.

Parameters:
  • directory (Path) – the directory to save to.

  • *args – additional arguments for save.

  • **kwargs – additional arguments for save.

Return type:

Optional[List[Future], None]

Returns:

A list of commit futures which can be awaited upon to complete the save operation.

StandardCheckpointHandler

class orbax.checkpoint.StandardCheckpointHandler(*, save_concurrent_gb=96, restore_concurrent_gb=96, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), use_ocdbt=True)

A CheckpointHandler implementation for any PyTree structure.

See JAX documentation for more information on what constitutes a “PyTree”. This handler is capable of saving and restoring PyTrees with leaves of type Python scalar, np.ndarray, and jax.Array.

As with all CheckpointHandler subclasses, StandardCheckpointHandler should only be used in conjunction with a Checkpointer (or subclass). By itself, the CheckpointHandler is non-atomic.

Example:

ckptr = Checkpointer(StandardCheckpointHandler())
# OR
ckptr = StandardCheckpointer()

If you find that your use case is not covered by StandardCheckpointHandler, consider using the parent class directly, or explore a custom implementation of CheckpointHandler.

__init__(*, save_concurrent_gb=96, restore_concurrent_gb=96, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), use_ocdbt=True)

Creates StandardCheckpointHandler.

Parameters:
  • save_concurrent_gb (int) – max concurrent GB that are allowed to be written to disk at any given time. This limits the amount of data currently being written to disk, which can help to reduce the possibility of OOMs when large checkpoints are saved. Note that this does NOT limit device-to-host transfer, meaning that the limit specified here may still be exceeded by the total memory usage of the process.

  • restore_concurrent_gb (int) – max concurrent GB that are allowed to be restored. Can help to reduce the possibility of OOMs when large checkpoints are restored.

  • multiprocessing_options (MultiprocessingOptions) – See orbax.checkpoint.options.

  • pytree_metadata_options (PyTreeMetadataOptions) – Options to control types like tuple and namedtuple in pytree metadata.

  • use_ocdbt (bool) – Whether to enable Tensorstore OCDBT driver.

async async_save(directory, item=None, save_args=None, args=None)

Saves a PyTree of array-like objects.

See PyTreeCheckpointHandler.

Parameters:
  • directory (UnionType[Path, PathAwaitingCreation]) – path to the directory where the checkpoint will be saved.

  • item (Optional[Any, None]) – Deprecated, use args.

  • save_args (Optional[Any, None]) – Deprecated, use args.

  • args (Optional[StandardSaveArgs, None]) – StandardSaveArgs (see below).

Return type:

Optional[List[Future], None]

Returns:

A list of futures that will be completed when the save is complete.

save(directory, *args, **kwargs)

Saves the provided item synchronously.

restore(directory, item=None, args=None)

Restores a PyTree. See PyTreeCheckpointHandler.

Example:

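import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp

ckptr = ocp.Checkpointer(ocp.StandardCheckpointHandler())
item = {
    'layer0': {
        'w': jnp.ones((8, 8)),
        'b': np.zeros((8,), dtype=np.float32),
    },
}
ckptr.save(path, args=ocp.args.StandardSave(item))

target = {
    'layer0': {
        # Abstract leaf: only shape and dtype are given.
        'w': jax.ShapeDtypeStruct((8, 8), jnp.float32),
        # Concrete leaf: restored with the same type and properties.
        'b': jnp.zeros((8,), dtype=jnp.float32),
    },
}
restored = ckptr.restore(path, args=ocp.args.StandardRestore(target))
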
Parameters:
  • directory (Path) – path from which to restore.

  • item (Optional[Any, None]) – Deprecated, use args.

  • args (Optional[StandardRestoreArgs, None]) – StandardRestoreArgs (see below).

Return type:

Any

Returns:

a restored PyTree.

metadata(directory)

Returns metadata about the saved item.

Return type:

TreeMetadata

finalize(directory)

Optional, custom checkpoint finalization callback.

Will be called on a single worker after all workers have finished writing.

Parameters:

directory (Path) – the directory where the checkpoint is located.

Return type:

None

close()

Closes the CheckpointHandler.

class orbax.checkpoint.handlers.StandardSaveArgs(item, save_args=None, custom_metadata=None)

Parameters for saving a standard PyTree.

Also see PyTreeSave for additional options.

item

a PyTree to be saved.

Type:

required

save_args

a PyTree with the same structure as item, which consists of ocp.SaveArgs objects as values. None can be used for values where no SaveArgs are specified.

Type:

Optional[PyTree]

custom_metadata

User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.

Type:

tree_types.JsonType | None

__eq__(other)

Return self==value.

__hash__ = None

__init__(item, save_args=None, custom_metadata=None)

class orbax.checkpoint.handlers.StandardRestoreArgs(item=None, strict=True, support_layout=False, fallback_sharding=None)

Parameters for restoring a standard PyTree.

Also see PyTreeRestore for additional options.

Attributes (all optional):
  • item – target PyTree. Currently non-optional. Values may be either real array or scalar values, jax.ShapeDtypeStruct, or ocp.metadata.value.Metadata objects (which come from calling the metadata method). If real values are provided, that value will be restored as the given type, with the given properties. If jax.ShapeDtypeStruct is provided, the value will be restored as np.ndarray, unless sharding is specified. If item is a custom PyTree class, the tree will be restored with the same structure as provided. If not provided, restores as a serialized nested dict representation of the custom class. TreeMetadata is also allowed as the tree used to define the restored structure.

  • strict – if False, restoration allows silent truncating/padding of arrays if the stored array shape does not match the target shape. Otherwise, raises an error.

  • support_layout – if True, restores with the layouts in item.

  • fallback_sharding – if provided, this sharding will be used as a fallback if the saved sharding fails to load from the checkpoint.

__eq__(other)

Return self==value.

__hash__ = None

__init__(item=None, strict=True, support_layout=False, fallback_sharding=None)

PyTreeCheckpointHandler

class orbax.checkpoint.PyTreeCheckpointHandler(aggregate_filename=None, *, save_concurrent_gb=None, restore_concurrent_gb=None, save_device_host_concurrent_gb=None, memory_limit_options=None, use_ocdbt=True, use_zarr3=False, use_compression=True, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), type_handler_registry=<orbax.checkpoint._src.serialization.type_handler_registry._TypeHandlerRegistryImpl object>, handler_impl=None, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_validator=<orbax.checkpoint._src.metadata.array_metadata_store.Validator object>, enable_pinned_host_transfer=None, is_prioritized_key_fn=None)

A CheckpointHandler implementation for any PyTree structure.

See JAX documentation for more information on what constitutes a “PyTree”. This handler is capable of saving and restoring any leaf object for which a TypeHandler (see documentation) is registered. By default, TypeHandlers for standard types like np.ndarray, jax.Array, Python scalars, and others are registered.

As with all CheckpointHandler subclasses, PyTreeCheckpointHandler should only be used in conjunction with a Checkpointer (or subclass). By itself, the CheckpointHandler is non-atomic.

Example:

ckptr = Checkpointer(PyTreeCheckpointHandler())

__init__(aggregate_filename=None, *, save_concurrent_gb=None, restore_concurrent_gb=None, save_device_host_concurrent_gb=None, memory_limit_options=None, use_ocdbt=True, use_zarr3=False, use_compression=True, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), type_handler_registry=<orbax.checkpoint._src.serialization.type_handler_registry._TypeHandlerRegistryImpl object>, handler_impl=None, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_validator=<orbax.checkpoint._src.metadata.array_metadata_store.Validator object>, enable_pinned_host_transfer=None, is_prioritized_key_fn=None)

Creates PyTreeCheckpointHandler.

Parameters:
  • aggregate_filename (Optional[str, None]) – name that the aggregated checkpoint should be saved as.

  • save_concurrent_gb (Optional[int, None]) – max concurrent GB that are allowed to be written to disk at any given time. This limits the amount of data currently being written to disk, which can help to reduce the possibility of OOMs when large checkpoints are saved. Note that this does NOT limit device-to-host transfer, meaning that the limit specified here may still be exceeded by the total memory usage of the process.

  • restore_concurrent_gb (Optional[int, None]) – max concurrent GB that are allowed to be restored. Can help to reduce the possibility of OOMs when large checkpoints are restored.

  • save_device_host_concurrent_gb (UnionType[int, str, None]) – max concurrent GB allowed to be transferred from device to host memory at once when saving, defined on a per-worker basis. When the limit is reached, arrays must be finished writing to the checkpoint before a new array can start being transferred. This option is a stronger version of save_concurrent_gb. Unlike save_concurrent_gb which only limits the amount of data currently being written to disk, this option limits the amount of data transferred from device to host. Note that asynchronous saves may not be truly asynchronous with this option enabled, as we have to block on some array writes before beginning others. Also see is_prioritized_key_fn. Can be set to “auto” to enable Memory Regulator.

  • memory_limit_options (UnionType[MemoryLimitOptions, None]) – Memory limit options for the checkpoint handler.

  • use_ocdbt (bool) – enables Tensorstore OCDBT driver. This option allows using a different checkpoint format which is faster to read and write, as well as more space efficient.

  • use_zarr3 (bool) – If True, use Zarr v3; otherwise, use Zarr v2.

  • use_compression (bool) – If True and zarr2 is used, use zstd compression.

  • multiprocessing_options (MultiprocessingOptions) – See orbax.checkpoint.options

  • type_handler_registry (TypeHandlerRegistry) – a type_handlers.TypeHandlerRegistry. If not specified, the global type handler registry will be used.

  • handler_impl (Optional[BasePyTreeCheckpointHandler, None]) – Allows overriding the internal implementation.

  • pytree_metadata_options (PyTreeMetadataOptions) – PyTreeMetadataOptions to manage metadata.

  • array_metadata_validator (Validator) – Validator for ArrayMetadata.

  • enable_pinned_host_transfer (Optional[bool, None]) – Whether to use pinned_host memory for the transfer from device to host memory. Passing None will enable pinned_host memory depending on the platform used (currently only enables it for the GPU backend).

  • is_prioritized_key_fn (Optional[IsPrioritizedKeyFn, None]) – A function that accepts a PyTree keypath (as obtained using jax.tree.map_with_path) and returns True if the corresponding value should be scheduled for D2H transfer before other keys. The transfer of prioritized keys is scheduled before returning to the caller, so their values will never be corrupted by a concurrent update. Keys that are not prioritized will not be scheduled for transfer until all prioritized keys have been fully written to the checkpoint, which means their values may be altered if they are updated concurrently. Callers should take care to call wait_until_finished before updating array values (e.g. apply_gradients) if some keys are not prioritized. Note that any “prioritized” keys are assumed to be lightweight, and save_device_host_concurrent_gb will be ignored for them.

async async_save(directory, item=None, save_args=None, args=None)

Saves a PyTree to a given directory.

This operation is compatible with a multi-host, multi-device setting. Tree leaf values must be supported by the type_handler_registry given in the constructor. Standard supported types include Python scalars, np.ndarray, jax.Array, and strings.

After saving, all files will be located in “directory/”. The exact files that are saved depend on the specific combination of options, including use_ocdbt. A JSON metadata file will be present to store the tree structure.

Example usage:

import jax
import numpy as np
import orbax.checkpoint as ocp

ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
item = {
    'layer0': {
        'w': np.ones((8, 8), dtype=np.float32),
        'b': np.zeros((8,), dtype=np.float32),
    },
    'layer1': {
        'w': np.ones((8, 8), dtype=np.float32),
        'b': np.zeros((8,), dtype=np.float32),
    },
}
# Note: save_args may be None if no customization is desired for saved
# parameters.
# Otherwise, settings can be used to customize save behavior, e.g.
# casting.
save_args = jax.tree.map(lambda x: ocp.SaveArgs(dtype=np.int32), item)
# Eventually calls through to `async_save`.
ckptr.save(path, args=ocp.args.PyTreeSave(item, save_args))

Parameters:
  • directory (UnionType[Path, PathAwaitingCreation]) – save location directory.

  • item (Optional[Any, None]) – Deprecated, use args.

  • save_args (Optional[PyTreeSaveArgs, None]) – Deprecated, use args.

  • args (Optional[PyTreeSaveArgs, None]) – PyTreeSaveArgs (see below).

Return type:

Optional[List[Future], None]

Returns:

A list of futures that will commit the data to directory when awaited. Copying the data from its source will be awaited in this function.

save(directory, item=None, save_args=None, args=None)

Saves the provided item. See async_save.

restore(directory, item=None, restore_args=None, transforms=None, transforms_default_to_original=True, legacy_transform_fn=None, args=None)

Restores a PyTree from the checkpoint directory at the given path.

In the most basic case, only directory is required. The tree will be restored exactly as saved, and all leaves will be restored as the correct types (assuming the tree metadata is present).

However, restore_args is often required as well. This PyTree gives a RestoreArgs object (or subclass) for every leaf in the tree. Many types, such as string or np.ndarray, do not require any special options for restoration. When restoring an individual leaf as jax.Array, however, some properties may be required.

One example is sharding, which defines how a jax.Array in the restored tree should be partitioned. mesh and mesh_axes can also be used to specify sharding, but sharding is the preferred way of specifying this partition since mesh and mesh_axes can only construct jax.sharding.NamedSharding. For more information, see ArrayTypeHandler documentation and JAX sharding documentation.

Example:

import jax
import numpy as np
import orbax.checkpoint as ocp

# A mesh over all devices; an empty PartitionSpec replicates each array.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('data',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
restore_args = {
    'layer0': {
        'w': ocp.RestoreArgs(),
        'b': ocp.RestoreArgs(),
    },
    'layer1': {
        'w': ocp.ArrayRestoreArgs(
            # Restores as jax.Array, regardless of how it was saved.
            restore_type=jax.Array,
            sharding=sharding,
            # Warning: may truncate or pad!
            global_shape=(8, 8),
        ),
        'b': ocp.ArrayRestoreArgs(
            restore_type=jax.Array,
            sharding=sharding,
            global_shape=(8,),
        ),
    },
}
ckptr.restore(path, args=ocp.args.PyTreeRestore(restore_args=restore_args))

Providing item is typically only necessary when restoring a custom PyTree class (or when using transformations). In this case, the restored object will take on the same structure as item.

Example:

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax import struct

@struct.dataclass
class TrainState:
  layer0: dict[str, jax.Array]
  layer1: dict[str, jax.Array]

ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
train_state = TrainState(
    layer0={
        'w': jnp.zeros((8, 8)),
        'b': jnp.zeros((8,)),
    },
    layer1={
        'w': jnp.zeros((8, 8)),
        'b': jnp.zeros((8,)),
    },
)
# One restore-args object per leaf, derived from each array in train_state.
restore_args = ocp.checkpoint_utils.construct_restore_args(train_state)
restored = ckptr.restore(
    path,
    args=ocp.args.PyTreeRestore(item=train_state, restore_args=restore_args),
)
# restored tree is of type `TrainState`.

Parameters:
  • directory (Path) – saved checkpoint location directory.

  • item (Optional[Any, None]) – Deprecated, use args.

  • restore_args (Optional[Any, None]) – Deprecated, use args.

  • transforms (Optional[Any, None]) – Deprecated, use args.

  • transforms_default_to_original (bool) – See transform_utils.apply_transformations.

  • legacy_transform_fn (Optional[Callable[[Any, Any, Any], Tuple[Any, Any]], None]) – Deprecated, use args.

  • args (Optional[PyTreeRestoreArgs, None]) – PyTreeRestoreArgs (see below).

Return type:

Any

Returns:

A PyTree matching the structure of item.

Raises:
  • FileNotFoundError – directory does not exist or is missing required files.

  • ValueError – transforms is provided without item.

  • ValueError – transforms contains elements with multi_value_fn.

metadata(directory)

Returns tree metadata.

The result will be a PyTree matching the structure of the saved checkpoint. Note that if the item saved was a custom class, the restored metadata will be returned as a nested dictionary representation.

Example:

{
  'layer0': {
      'w': ArrayMetadata(dtype=jnp.float32, shape=(8, 8), shards=(1, 2)),
      'b': ArrayMetadata(dtype=jnp.float32, shape=(8,), shards=(1,)),
  },
  'step': ScalarMetadata(dtype=jnp.int64),
}

If the required metadata file is not present, this method will raise an error.

Parameters:

directory (Path) – checkpoint location.

Return type:

TreeMetadata

Returns:

tree containing metadata.

finalize(directory)

Finalization step.

Called automatically by the Checkpointer/AsyncCheckpointer just before the checkpoint is considered “finalized” in the sense of ensuring atomicity. See documentation for type_handlers.merge_ocdbt_per_process_files.

Parameters:

directory (Path) – Path where the checkpoint is located.

Return type:

None

close()

Closes the handler. Called automatically by Checkpointer.

class orbax.checkpoint.handlers.PyTreeSaveArgs(item, save_args=None, ocdbt_target_data_file_size=None, custom_metadata=None)

Parameters for saving a PyTree.

item

a PyTree to be saved.

Type:

required

save_args

a PyTree with the same structure as item, which consists of ocp.SaveArgs objects as values. None can be used for values where no SaveArgs are specified.

Type:

Optional[PyTree]

ocdbt_target_data_file_size

Specifies the target size (in bytes) of each OCDBT data file. It only applies when OCDBT is enabled, and Zarr3 must also be turned on. If left unspecified, the default size is 2 GB. A value of 0 indicates no maximum file size limit. For best results, ensure chunk_byte_size is smaller than this value. For more details, refer to https://google.github.io/tensorstore/kvstore/ocdbt/index.html#json-kvstore/ocdbt.target_data_file_size

Type:

Optional[int]

custom_metadata

User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.

Type:

tree_types.JsonType | None

__eq__(other)

Return self==value.

__hash__ = None

__init__(item, save_args=None, ocdbt_target_data_file_size=None, custom_metadata=None)

class orbax.checkpoint.handlers.PyTreeRestoreArgs(item=None, restore_args=None, transforms=None, transforms_default_to_original=True, legacy_transform_fn=None, partial_restore=False)

Parameters for restoring a PyTree.

Attributes (all optional):
  • item – provides the tree structure for the restored item. If not provided, will infer the structure from the saved checkpoint. Transformations will not be run in this case. Necessary particularly in the case where the caller needs to restore the tree as a custom object. TreeMetadata is also allowed as the tree used to define the restored structure.

  • restore_args – optional object containing additional arguments for restoration. It should be a PyTree matching the structure of item or, if item is not provided, the structure of the checkpoint. Each value in the tree should be a RestoreArgs object (or a subclass of RestoreArgs). Importantly, note that when restoring a leaf as a certain type, a specific subclass of RestoreArgs may be required. RestoreArgs also provides the option to customize the restore type of an individual leaf. TreeMetadata is also allowed as the restore_args tree.

  • transforms – a PyTree of transformations that should be applied to the saved tree in order to obtain a final structure. The transforms tree structure should conceptually match that of item, but the use of regexes and implicit keys means that it does not need to match completely. See transform_utils for further information. TreeMetadata is also allowed as the transforms tree.

  • transforms_default_to_original – see transform_utils.apply_transformations.

  • legacy_transform_fn – WARNING: NOT GENERALLY SUPPORTED. A function which accepts the item argument, a PyTree checkpoint structure, and a PyTree of ParamInfos based on the checkpoint. Returns a transformed PyTree matching the desired return tree structure, and a matching ParamInfo tree.

  • partial_restore – if True, only restore the parameters that are specified in PyTreeRestoreArgs.

__eq__(other)

Return self==value.

__hash__ = None

__init__(item=None, restore_args=None, transforms=None, transforms_default_to_original=True, legacy_transform_fn=None, partial_restore=False)

CompositeCheckpointHandler

class orbax.checkpoint.CompositeCheckpointHandler(*item_names, composite_options=CompositeOptions(multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), temporary_path_class=None, file_options=None, async_options=None), handler_registry=None, **items_and_handlers)

CheckpointHandler for saving multiple items.

As with all CheckpointHandler implementations, use only in conjunction with an instance of AbstractCheckpointer.

The CompositeCheckpointHandler allows dealing with multiple items of different types or logical distinctness, such as training state (PyTree), dataset iterator, JSON metadata, or anything else. To specify the handler and args for each checkpointable item, use the handler_registry option at initialization (see CheckpointHandlerRegistry).

If a handler registry is provided, the CompositeCheckpointHandler will use this registry to determine the CheckpointHandler for each checkpointable item. If no registry is provided, an empty registry will be constructed by default internally. Additional items not registered in the handler registry can be specified when calling save or restore. The handler will be determined from the provided CheckpointArgs during the first call to save or restore. Subsequent calls to save or restore will use the same handler and will require the same CheckpointArgs to be provided.

item_names and items_and_handlers are deprecated arguments for specifying the items and handlers. Please use handler_registry instead.

Usage:

# The simplest use-case, with no handler registry provided on construction.
checkpointer = ocp.Checkpointer(
    ocp.CompositeCheckpointHandler()
)

checkpointer.save(directory,
    ocp.args.Composite(
        # The handler will be determined from the global registry using the
        # provided args.
        state=ocp.args.StandardSave(pytree),
    )
)

restored: ocp.args.Composite = checkpointer.restore(directory,
    ocp.args.Composite(
        state=ocp.args.StandardRestore(abstract_pytree),
    )
)
restored.state ...

# The "state" item's args was determined on the first call to `save`.
# Trying to save "state" with a different set of args will raise an error.
checkpointer.save(directory,
  ocp.args.Composite(
    # Will raise `ValueError: Item "state" and args "JsonSave [...]" does
    # not match with any registered handler! ...`.
    state=ocp.args.JsonSave(pytree),
    )
)

# You can also register specific items and handlers using a handler
# registry.
handler_registry = (
    ocp.handler_registration.DefaultCheckpointHandlerRegistry()
)
# Some subclass of `CheckpointHandler` that handles the "state" item.
handler = state_handler()
# Multiple save args and handlers can be registered for the same item.
handler_registry.add(
    'state',
    StateSaveArgs,
    handler,
)
handler_registry.add(
    'state',
    StateRestoreArgs,
    handler,
)

# The `CompositeCheckpointHandler` will use the handlers registered in the
# registry when it encounters the item name `state`.
checkpointer_with_registry = ocp.Checkpointer(
    ocp.CompositeCheckpointHandler(handler_registry=handler_registry)
)

checkpointer_with_registry.save(directory,
    ocp.args.Composite(
        # The handler knows that the "state" item should use the
        # `state_handler` specified in the registry. When an item has been
        # added to the handler registry, only save args that are registered
        # for that item are allowed.
        state=StateSaveArgs(pytree),
        # You can also provide other arguments that are not registered in
        # the handler registry as long as a handler has been globally
        # registered for the args.
        other_param=ocp.args.StandardSave(other_pytree),
    )
)

restored: ocp.args.Composite = checkpointer_with_registry.restore(directory,
    ocp.args.Composite(
        state=StateRestoreArgs(abstract_pytree),
        other_param=ocp.args.StandardRestore(abstract_other_pytree),
    )
)
restored.state ...
restored.other_param...

# Skip restoring `other_param` (you can save a subset of items too, in a
# similar fashion).
restored_state_only: ocp.args.Composite = (
  checkpointer_with_registry.restore(directory,
    ocp.args.Composite(
        state=StateRestoreArgs(abstract_pytree),
    )
  )
)
restored_state_only.state ...
restored_state_only.other_param ... # None.

__init__(*item_names, composite_options=CompositeOptions(multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), temporary_path_class=None, file_options=None, async_options=None), handler_registry=None, **items_and_handlers)

Constructor.

If you are using item_names and/or items_and_handlers, all items must be provided up front, at initialization. If you are using a handler_registry, you can register items at any time, even after the first call to save or restore.

Parameters:
  • *item_names – A list of string item names that this handler will manage. item_names is deprecated. Please use handler_registry instead.

  • composite_options (CompositeOptions) – Options.

  • handler_registry (Optional[CheckpointHandlerRegistry, None]) – A CheckpointHandlerRegistry instance. If provided, the CompositeCheckpointHandler will use this registry to determine the CheckpointHandler for each item. This option is mutually exclusive with items_and_handlers and item_names.

  • **items_and_handlers – A mapping of item name to CheckpointHandler instance, which will be used as the handler for objects of the corresponding name. items_and_handlers is deprecated. Please use handler_registry instead.

async async_save(directory, args)

Saves multiple items to individual subdirectories.

Return type:

Optional[List[Future], None]

save(*args, **kwargs)

Saves synchronously.

restore(directory, args=None)

Restores the provided item synchronously.

Restoration can happen in a variety of modes, depending on which args is passed:

  • args is not None and not empty. Item names present in args will be restored as long as they exist in the checkpoint and have a registered handler.

  • args is None or empty. All items in the checkpoint that have a registered handler will be restored.

Parameters:
  • directory (Path) – Path to restore from.

  • args (Optional[CompositeArgs, None]) – CompositeArgs object used to restore individual sub-items. May be None or “empty” (CompositeArgs()), which is used to indicate that the handler should restore everything. If an individual item was not specified during the save, None will be returned for that item’s entry.

Return type:

CompositeArgs

Returns:

A CompositeResults object with keys matching CompositeArgs, or with keys for all known items as specified at creation.

Raises:

KeyError – If an item could not be restored.

metadata_from_temporary_paths(directory)

Metadata for each item in the temporary checkpoint.

Return type:

StepMetadata

metadata(directory)

Metadata for each item in the checkpoint.

This has much the same logic as restore, in the sense that it tries to restore the metadata for each item present in the checkpoint. However, for any items that do not have a registered handler, the metadata for that item will simply be returned as None.

Parameters:

directory (Path) – Path to the checkpoint.

Return type:

StepMetadata

Returns:

StepMetadata

Raises:

FileNotFoundError – If the directory does not exist.

finalize(directory)

Optional, custom checkpoint finalization callback.

Will be called on a single worker after all workers have finished writing.

Parameters:

directory (Path) – the directory where the checkpoint is located.

close()

Closes the CheckpointHandler.

class orbax.checkpoint.handlers.CompositeArgs(**items)

Args for wrapping multiple checkpoint items together.

orbax.checkpoint.handlers.CompositeResults

alias of CompositeArgs

JsonCheckpointHandler

class orbax.checkpoint.JsonCheckpointHandler(filename=None, *, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None))

Saves a nested dictionary as JSON.

__init__(filename=None, *, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None))

Initializes JsonCheckpointHandler.

Parameters:
  • filename (Optional[str, None]) – optional file name given to the written file; defaults to ‘metadata’

  • multiprocessing_options (MultiprocessingOptions) – See orbax.checkpoint.options.

async async_save(directory, item=None, args=None)

Saves the given item.

Parameters:
  • directory (Path) – save location directory.

  • item (Optional[Mapping[str, Any], None]) – Deprecated, use args instead.

  • args (Optional[JsonSaveArgs, None]) – JsonSaveArgs (see below).

Return type:

Optional[List[Future], None]

Returns:

A list of commit futures.

save(directory, item=None, args=None)

Saves the provided item synchronously.

Parameters:
  • directory (Path) – save location directory.

  • item (Optional[Mapping[str, Any], None]) – Deprecated, use args instead.

  • args (Optional[JsonSaveArgs, None]) – JsonSaveArgs (see below).

restore(directory, item=None, args=None)

Restores a JSON mapping from directory.

Both item and args are unused.

Parameters:
  • directory (Path) – restore location directory.

  • item (Optional[Mapping[str, Any], None]) – unused

  • args (Optional[JsonRestoreArgs, None]) – unused

Return type:

Mapping[str, Any]

Returns:

JSON dict.

Raises:

FileNotFoundError – if the file does not exist.

class orbax.checkpoint.handlers.JsonSaveArgs(item)

Parameters for saving to json.

item

a nested dictionary.

Type:

required

__eq__(other)

Return self==value.

__hash__ = None

__init__(item)

class orbax.checkpoint.handlers.JsonRestoreArgs(item=None)

Json restore args.

item

unused, but included for legacy-compatibility reasons. New code should not set this attribute.

Type:

bytes | None

__eq__(other)

Return self==value.

__hash__ = None

__init__(item=None)

ArrayCheckpointHandler

class orbax.checkpoint.ArrayCheckpointHandler(checkpoint_name=None)

Handles saving and restoring individual arrays and scalars.

__init__(checkpoint_name=None)

Initializes the handler.

Parameters:

checkpoint_name (Optional[str, None]) – Provides a name for the directory under which TensorStore files will be saved. Defaults to ‘checkpoint’.

async async_save(directory, item=None, save_args=None, args=None)

Saves an object asynchronously.

Parameters:
  • directory (Path) – Folder in which to save.

  • item (Union[int, float, number, ndarray, Array, None]) – Deprecated, use args.

  • save_args (Optional[SaveArgs, None]) – Deprecated, use args.

  • args (Optional[ArraySaveArgs, None]) – An ocp.array_checkpoint_handler.ArraySaveArgs (see below).

Return type:

Optional[List[Future], None]

Returns:

A list of commit futures which can be run to complete the save.

save(directory, *args, **kwargs)

Saves an array synchronously.

restore(directory, item=None, restore_args=None, args=None)

Restores an object.

Parameters:
  • directory (Path) – folder from which to read.

  • item (Union[int, float, number, ndarray, Array, None]) – Deprecated, use args.

  • restore_args (Optional[RestoreArgs, None]) – Deprecated, use args.

  • args (Optional[ArrayRestoreArgs, None]) – An ocp.array_checkpoint_handler.ArrayRestoreArgs object (see below).

Return type:

Union[int, float, number, ndarray, Array]

Returns:

The restored object.

finalize(directory)

Optional, custom checkpoint finalization callback.

Will be called on a single worker after all workers have finished writing.

Parameters:

directory (Path) – the directory where the checkpoint is located.

close()

See superclass documentation.

class orbax.checkpoint.handlers.ArraySaveArgs(item, save_args=None)

Parameters for saving an array or scalar.

item

an array or scalar object.

Type:

required

save_args

an ocp.SaveArgs object specifying save options.

Type:

orbax.checkpoint._src.serialization.types.SaveArgs | None

__eq__(other)

Return self==value.

__hash__ = None

__init__(item, save_args=None)

class orbax.checkpoint.handlers.ArrayRestoreArgs(item=None, restore_args=None)

Array restore args.

item

unused, but provided as an option for legacy-compatibility reasons.

Type:

int | float | numpy.number | numpy.ndarray | jax.Array | None

restore_args

an ocp.RestoreArgs object specifying restore options.

Type:

orbax.checkpoint._src.serialization.types.RestoreArgs | None

__eq__(other)

Return self==value.

__hash__ = None

__init__(item=None, restore_args=None)

ProtoCheckpointHandler

class orbax.checkpoint.ProtoCheckpointHandler(filename='proto.pbtxt', *, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None))

Serializes/deserializes protocol buffers.

__init__(filename='proto.pbtxt', *, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None))

Initializes ProtoCheckpointHandler.

Parameters:
  • filename (str) – file name given to the written file; defaults to ‘proto.pbtxt’.

  • multiprocessing_options (MultiprocessingOptions) – See orbax.checkpoint.options.

async async_save(directory, item=None, args=None)

Saves the given proto.

Parameters:
  • directory (Path) – save location directory.

  • item (Optional[Message, None]) – Deprecated, use args.

  • args (Optional[ProtoSaveArgs, None]) – ProtoSaveArgs (see below).

Return type:

Optional[List[Future], None]

Returns:

A commit future.

save(*args, **kwargs)

Saves the provided item.

restore(directory, item=None, args=None)

Restores the proto from directory.

Parameters:
  • directory (Path) – restore location directory.

  • item (Optional[Type[Message], None]) – Deprecated, use args.

  • args (Optional[ProtoRestoreArgs, None]) – ProtoRestoreArgs (see below).

Returns:

The deserialized proto read from directory if item is not None.

class orbax.checkpoint.handlers.ProtoSaveArgs(item)

Parameters for saving a proto.

item

the proto to serialize.

Type:

required

__eq__(other)

Return self==value.

__hash__ = None

__init__(item)

class orbax.checkpoint.handlers.ProtoRestoreArgs(item)

Proto restore args.

item

the proto class

Type:

required

__eq__(other)

Return self==value.

__hash__ = None

__init__(item)

JaxRandomKeyCheckpointHandler

class orbax.checkpoint.JaxRandomKeyCheckpointHandler(key_name=None)

Handles saving and restoring an individual JAX random key in both typed and untyped formats.

save(directory, *args, **kwargs)

Saves a random key synchronously.

async async_save(directory, args)

Saves a random key asynchronously.

Parameters:
  • directory (Path) – Folder in which to save.

  • args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.

Return type:

Optional[List[Future], None]

Returns:

A list of commit futures which can be run to complete the save.

restore(directory, args)

Restores a random key.

Parameters:
  • directory (Path) – folder from which to read.

  • args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.

Return type:

Any

Returns:

The restored object.

class orbax.checkpoint.handlers.JaxRandomKeySaveArgs(item, save_args=None)

Parameters for saving a JAX random key.

item

a JAX random key.

Type:

required

save_args

an ocp.SaveArgs object specifying save options.

Type:

orbax.checkpoint._src.serialization.types.SaveArgs | None

__eq__(other)

Return self==value.

__hash__ = None

__init__(item, save_args=None)

class orbax.checkpoint.handlers.JaxRandomKeyRestoreArgs(restore_args=None)

JAX random key restore args.

restore_args

an ocp.RestoreArgs object specifying restore options for jax.Array.

Type:

orbax.checkpoint._src.serialization.types.RestoreArgs | None

__eq__(other)

Return self==value.

__hash__ = None

__init__(restore_args=None)

NumpyRandomKeyCheckpointHandler

class orbax.checkpoint.NumpyRandomKeyCheckpointHandler(key_name=None)

Saves a NumPy random key in legacy or non-legacy format.

save(directory, *args, **kwargs)

Saves a random key synchronously.

async async_save(directory, args)

Saves a random key asynchronously.

Parameters:
  • directory (Path) – Folder in which to save.

  • args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.

Return type:

Optional[List[Future], None]

Returns:

A list of commit futures which can be run to complete the save.

restore(directory, args)

Restores a random key.

Parameters:
  • directory (Path) – folder from which to read.

  • args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.

Return type:

Any

Returns:

The restored object.

class orbax.checkpoint.handlers.NumpyRandomKeySaveArgs(item)

Parameters for saving a NumPy random key.

item

a NumPy random key in legacy or non-legacy format.

Type:

required

__eq__(other)

Return self==value.

__hash__ = None

__init__(item)

class orbax.checkpoint.handlers.NumpyRandomKeyRestoreArgs

NumPy random key restore args.

__eq__(other)

Return self==value.

__hash__ = None

__init__()