CheckpointHandlers


Defines exported symbols for the namespace package orbax.checkpoint.

CheckpointHandler#

class orbax.checkpoint.CheckpointHandler[source][source]#

An interface providing save/restore methods used on a savable item.

Item may be a PyTree, Dataset, or any other supported object.

NOTE: Users should avoid using CheckpointHandler independently. Use Checkpointer or CheckpointManager.

abstract save(directory, *args, **kwargs)[source][source]#

Saves the provided item synchronously.

Parameters:
  • directory (Path) – the directory to save to.

  • *args – additional arguments for save.

  • **kwargs – additional arguments for save.

abstract restore(directory, *args, **kwargs)[source][source]#

Restores the provided item synchronously.

Parameters:
  • directory (Path) – the directory to restore from.

  • *args – additional arguments for restore.

  • **kwargs – additional arguments for restore.

Return type:

Any

Returns:

The restored item.

metadata(directory)[source][source]#

Returns metadata about the saved item.

Ideally, this is a cheap way to collect information about the checkpoint without requiring a full restoration.

Parameters:

directory (Path) – the directory where the checkpoint is located.

Return type:

Optional[Any]

Returns:

item metadata

finalize(directory)[source][source]#

Optional, custom checkpoint finalization callback.

Will be called on a single worker after all workers have finished writing.

Parameters:

directory (Path) – the directory where the checkpoint is located.

Return type:

None
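The "single worker after all workers have finished" ordering can be pictured with threads standing in for workers. This is only a sketch under assumptions: the `worker` function, the shard file names, and the `COMMITTED` marker file are hypothetical and do not describe what any Orbax handler writes.

```python
import tempfile
import threading
from pathlib import Path

NUM_WORKERS = 3


def worker(rank: int, directory: Path, barrier: threading.Barrier) -> None:
    # Every worker writes its own piece of the checkpoint.
    (directory / f"shard_{rank}.txt").write_text(f"data from worker {rank}")
    # Block until all workers have finished writing.
    barrier.wait()
    # Only one worker runs the finalization step.
    if rank == 0:
        shards = sorted(p.name for p in directory.glob("shard_*.txt"))
        # Hypothetical marker file listing the shards that were written.
        (directory / "COMMITTED").write_text("\n".join(shards))


with tempfile.TemporaryDirectory() as tmp:
    directory = Path(tmp)
    barrier = threading.Barrier(NUM_WORKERS)
    threads = [
        threading.Thread(target=worker, args=(rank, directory, barrier))
        for rank in range(NUM_WORKERS)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    committed = (directory / "COMMITTED").read_text().splitlines()
```

Because the marker is written only after the barrier, it can never list a shard that is still being written.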

close()[source][source]#

Closes the CheckpointHandler.
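To make the shape of the interface concrete, here is a minimal sketch that mirrors the method names documented above using only the standard library. `ToyCheckpointHandler`, `ToyDictHandler`, and the file names are hypothetical; the sketch does not subclass the real orbax.checkpoint.CheckpointHandler, and a real handler is meant to be driven by a Checkpointer rather than called directly.

```python
import abc
import json
import tempfile
from pathlib import Path
from typing import Any, Optional


class ToyCheckpointHandler(abc.ABC):
    """Hypothetical stand-in mirroring the documented method names."""

    @abc.abstractmethod
    def save(self, directory: Path, *args, **kwargs) -> None:
        """Saves the provided item synchronously."""

    @abc.abstractmethod
    def restore(self, directory: Path, *args, **kwargs) -> Any:
        """Restores the provided item synchronously."""

    def metadata(self, directory: Path) -> Optional[Any]:
        return None  # optional: cheap description of the checkpoint

    def finalize(self, directory: Path) -> None:
        pass  # optional: hook that runs after writing has finished

    def close(self) -> None:
        pass  # optional: release any held resources


class ToyDictHandler(ToyCheckpointHandler):
    """Stores a JSON-serializable dict plus a small summary file."""

    def save(self, directory: Path, item: dict) -> None:
        directory.mkdir(parents=True, exist_ok=True)
        (directory / "item.json").write_text(json.dumps(item))
        # Sidecar file so metadata() never has to read the full item.
        (directory / "summary.json").write_text(json.dumps(sorted(item)))

    def restore(self, directory: Path) -> dict:
        return json.loads((directory / "item.json").read_text())

    def metadata(self, directory: Path) -> Optional[Any]:
        return json.loads((directory / "summary.json").read_text())


with tempfile.TemporaryDirectory() as tmp:
    handler = ToyDictHandler()
    ckpt_dir = Path(tmp) / "ckpt"
    handler.save(ckpt_dir, {"step": 3, "lr": 0.1})
    restored = handler.restore(ckpt_dir)
    summary = handler.metadata(ckpt_dir)
    handler.close()
```

The sidecar summary illustrates the intent of metadata: it answers "what is in this checkpoint" without reading the item itself.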

AsyncCheckpointHandler#

class orbax.checkpoint.AsyncCheckpointHandler[source][source]#

An interface providing async methods that can be used with CheckpointHandler.

abstract async async_save(directory, *args, **kwargs)[source][source]#

Constructs a save operation.

Synchronously awaits a copy of the item, before returning commit futures necessary to save the item.

Parameters:
  • directory (Path) – the directory to save to.

  • *args – additional arguments for save.

  • **kwargs – additional arguments for save.

Return type:

Optional[List[Future]]
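The "copy first, commit later" contract of async_save can be sketched with the standard library alone. `toy_async_save`, the single-thread executor, and the `item.json` file name are assumptions made for illustration; an actual handler may copy and commit in a different way.

```python
import asyncio
import copy
import json
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import List, Optional

_executor = ThreadPoolExecutor(max_workers=1)


async def toy_async_save(directory: Path, item: dict) -> Optional[List[Future]]:
    """Hypothetical async_save: snapshot now, write in the background."""
    # Step 1: take a private copy before returning, so the caller is free
    # to mutate `item` afterwards.
    snapshot = copy.deepcopy(item)

    def _commit() -> None:
        directory.mkdir(parents=True, exist_ok=True)
        (directory / "item.json").write_text(json.dumps(snapshot))

    # Step 2: hand the write to a background thread and return its future.
    return [_executor.submit(_commit)]


def run_example() -> dict:
    with tempfile.TemporaryDirectory() as tmp:
        directory = Path(tmp) / "ckpt"
        item = {"step": 1}
        futures = asyncio.run(toy_async_save(directory, item))
        item["step"] = 2  # mutated after async_save returned
        for future in futures or []:
            future.result()  # wait for the commit to finish
        return json.loads((directory / "item.json").read_text())


saved = run_example()
```

The saved file holds the value the item had when async_save was called, not the later mutation.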

StandardCheckpointHandler#

class orbax.checkpoint.StandardCheckpointHandler(concurrent_gb=96, primary_host=0)[source][source]#

A CheckpointHandler implementation for any PyTree structure.

See JAX documentation for more information on what constitutes a “PyTree”. This handler is capable of saving and restoring PyTrees with leaves of type Python scalar, np.ndarray, and jax.Array.

As with all CheckpointHandler subclasses, StandardCheckpointHandler should only be used in conjunction with a Checkpointer (or subclass). By itself, the CheckpointHandler is non-atomic.

Example:

ckptr = Checkpointer(StandardCheckpointHandler())
# OR
ckptr = StandardCheckpointer()

If you find that your use case is not covered by StandardCheckpointHandler, consider using the parent class directly, or explore a custom implementation of CheckpointHandler.
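The non-atomicity note above means that a handler writing directly into its final directory can leave a half-written checkpoint behind if it fails midway. One common way a wrapping layer can hide that is to write into a temporary directory and rename it into place. The sketch below shows that general pattern only; `atomic_save` and the `.tmp` suffix are hypothetical and not a description of how Checkpointer is implemented.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Callable


def atomic_save(final_dir: Path, write_fn: Callable[[Path], None]) -> None:
    """Writes into a temporary sibling directory, then renames it into place."""
    tmp_dir = final_dir.with_name(final_dir.name + ".tmp")  # hypothetical naming
    tmp_dir.mkdir(parents=True)
    write_fn(tmp_dir)  # non-atomic step: may leave partial files on failure
    os.rename(tmp_dir, final_dir)  # the checkpoint becomes visible in one step


def _write(directory: Path) -> None:
    (directory / "item.json").write_text(json.dumps({"step": 7}))


def _failing_write(directory: Path) -> None:
    (directory / "item.json").write_text("{")  # partial write
    raise RuntimeError("simulated crash")


with tempfile.TemporaryDirectory() as root:
    good = Path(root) / "ckpt_good"
    atomic_save(good, _write)
    good_exists = good.exists()

    bad = Path(root) / "ckpt_bad"
    try:
        atomic_save(bad, _failing_write)
    except RuntimeError:
        pass
    # The failed attempt never appears under its final name.
    bad_exists = bad.exists()
```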

__init__(concurrent_gb=96, primary_host=0)[source][source]#

Creates StandardCheckpointHandler.

Parameters:
  • concurrent_gb (int) – maximum number of GB that may be read concurrently. Can help reduce the possibility of OOMs when large checkpoints are restored.

  • primary_host (Optional[int]) – the host id of the primary host. Defaults to 0. If set to None, all hosts are considered primary. This is useful when all hosts work only with local storage.
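The idea behind a concurrency cap such as concurrent_gb can be illustrated with a small byte budget shared by concurrent reads. This is a sketch under assumptions: `ByteBudget`, `read_leaf`, and `restore_all` are hypothetical names, the sizes are arbitrary units rather than GB, and Orbax's own limiter may work differently.

```python
import asyncio
from typing import List


class ByteBudget:
    """Hypothetical limiter: at most `max_bytes` may be in flight at once."""

    def __init__(self, max_bytes: int):
        self._max_bytes = max_bytes
        self._available = max_bytes
        self._cond = asyncio.Condition()
        self.peak_in_flight = 0  # recorded for inspection only

    async def acquire(self, nbytes: int) -> None:
        if nbytes > self._max_bytes:
            raise ValueError("a single read exceeds the whole budget")
        async with self._cond:
            # Wait until enough of the budget has been released.
            await self._cond.wait_for(lambda: self._available >= nbytes)
            self._available -= nbytes
            in_flight = self._max_bytes - self._available
            self.peak_in_flight = max(self.peak_in_flight, in_flight)

    async def release(self, nbytes: int) -> None:
        async with self._cond:
            self._available += nbytes
            self._cond.notify_all()


async def read_leaf(budget: ByteBudget, nbytes: int) -> int:
    await budget.acquire(nbytes)
    try:
        await asyncio.sleep(0.01)  # stands in for the actual read
        return nbytes
    finally:
        await budget.release(nbytes)


async def restore_all(sizes: List[int], max_bytes: int) -> ByteBudget:
    budget = ByteBudget(max_bytes)
    await asyncio.gather(*(read_leaf(budget, n) for n in sizes))
    return budget


budget = asyncio.run(restore_all([4, 4, 4, 4], max_bytes=8))
```

With four reads of size 4 and a budget of 8, no more than two reads are ever in flight, which is how a cap trades restore speed for peak memory.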

async async_save(directory, item=None, save_args=None, args=None)[source][source]#

Saves a PyTree of array-like objects. See PyTreeCheckpointHandler.

Return type:

Optional[List[Future]]

save(directory, *args, **kwargs)[source][source]#

Saves the provided item synchronously.

restore(directory, item=None, args=None)[source][source]#

Restores a PyTree. See PyTreeCheckpointHandler.

Example:

ckptr = StandardCheckpointer()
item = {
    'layer0': {
        'w': jax.Array(...),
        'b': np.ndarray(...),
    },
}
ckptr.save(dir, StandardSaveArgs(item))

target = {
    'layer0': {
        'w': jax.ShapeDtypeStruct(...),
        'b': jax.Array(...),
    },
}
ckptr.restore(dir, StandardRestoreArgs(target))

Parameters:
  • directory (Path) – path from which to restore.

  • item (Optional[Any]) – Deprecated, use args.

  • args (Optional[StandardRestoreArgs]) – StandardRestoreArgs (see below).

Return type:

Any

Returns:

a restored PyTree.
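The restore target in the example above uses jax.ShapeDtypeStruct leaves as shape and dtype placeholders. One way to build such a target from an existing tree is to map over its leaves, as in this minimal sketch. The array shapes here are made up for illustration, and whether a target also needs sharding information depends on the setup.

```python
import jax
import numpy as np

# A concrete tree with the same layout as the one that was saved.
item = {
    'layer0': {
        'w': np.ones((2, 3), dtype=np.float32),
        'b': np.zeros((3,), dtype=np.float32),
    },
}

# Replace every array leaf with a placeholder that records only its
# shape and dtype; no array data is kept.
target = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), item)
```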

metadata(directory)[source][source]#

Returns metadata about the saved item.

Return type:

Any

finalize(directory)[source][source]#

Optional, custom checkpoint finalization callback.

Will be called on a single worker after all workers have finished writing.

Parameters:

directory (Path) – the directory where the checkpoint is located.

Return type:

None

close()[source][source]#

Closes the CheckpointHandler.

PyTreeCheckpointHandler#

class orbax.checkpoint.PyTreeCheckpointHandler(aggregate_filename=None, concurrent_gb=96, use_ocdbt=True, use_zarr3=False, primary_host=0, type_handler_registry=<orbax.checkpoint.type_handlers._TypeHandlerRegistryImpl object>)[source][source]#

A CheckpointHandler implementation for any PyTree structure.

See JAX documentation for more information on what constitutes a “PyTree”. This handler is capable of saving and restoring any leaf object for which a TypeHandler (see documentation) is registered. By default, TypeHandlers for standard types such as np.ndarray, jax.Array, Python scalars, and others are registered.

As with all CheckpointHandler subclasses, PyTreeCheckpointHandler should only be used in conjunction with a Checkpointer (or subclass). By itself, the CheckpointHandler is non-atomic.

Example:

ckptr = Checkpointer(PyTreeCheckpointHandler())

__init__(aggregate_filename=None, concurrent_gb=96, use_ocdbt=True, use_zarr3=False, primary_host=0, type_handler_registry=<orbax.checkpoint.type_handlers._TypeHandlerRegistryImpl object>)[source][source]#

Creates PyTreeCheckpointHandler.

Parameters:
  • aggregate_filename (Optional[str]) – name that the aggregated checkpoint should be saved as.

  • concurrent_gb (int) – maximum number of GB that may be read concurrently. Can help reduce the possibility of OOMs when large checkpoints are restored.

  • use_ocdbt (bool) – enables Tensorstore OCDBT driver. This option allows using a different checkpoint format which is faster to read and write, as well as more space efficient.

  • use_zarr3 (bool) – if True, uses Zarr version 3; otherwise uses Zarr version 2.

  • primary_host (Optional[int]) – the host id of the primary host. Defaults to 0. If set to None, all hosts are considered primary. This is useful when all hosts work only with local storage.

  • type_handler_registry (TypeHandlerRegistry) – a type_handlers.TypeHandlerRegistry. If not specified, the global type handler registry will be used.
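The primary_host rule described in the list above reduces to a one-line predicate. `is_primary_host` is a hypothetical helper written for illustration; it is not part of the Orbax API.

```python
from typing import Optional


def is_primary_host(process_index: int, primary_host: Optional[int]) -> bool:
    """Hypothetical helper: None means every host counts as primary."""
    return primary_host is None or process_index == primary_host


# With the default primary_host=0, only process 0 is primary.
default_case = [is_primary_host(i, 0) for i in range(3)]
# With primary_host=None, every process is primary (local-storage case).
local_storage_case = [is_primary_host(i, None) for i in range(3)]
```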

async async_save(directory, item=None, save_args=None, args=None)[source][source]#

Saves a PyTree to a given directory.

This operation is compatible with a multi-host, multi-device setting. Tree leaf values must be supported by the type_handler_registry given in the constructor. Standard supported types include Python scalars, np.ndarray, jax.Array, and strings.

After saving, all files will be located in “directory/”. The exact files that are saved depend on the specific combination of options, including use_ocdbt. A JSON metadata file will be present to store the tree structure. In addition, a msgpack file may be present, allowing users to store aggregated values (see below).

Example usage:

ckptr = Checkpointer(PyTreeCheckpointHandler())
item = {
    'layer0': {
        'w': np.ndarray(...),
        'b': np.ndarray(...),
    },
    'layer1': {
        'w': np.ndarray(...),
        'b': np.ndarray(...),
    },
}
# Note: save_args may be None if no customization is desired for saved
# parameters.
# In this case, we "aggregate" small parameters into a single file to
# allow for greater file read/write efficiency (and potentially less
# wasted space). With OCDBT format active, this parameter is obsolete.
save_args = jax.tree_util.tree_map(
    lambda x: SaveArgs(aggregate=x.size < some_size), item)
# Eventually calls through to `async_save`.
ckptr.save(path, item, save_args)

Parameters:
  • directory (Path) – save location directory.

  • item (Optional[Any]) – Deprecated, use args.

  • save_args (Optional[PyTreeSaveArgs]) – Deprecated, use args.

  • args (Optional[PyTreeSaveArgs]) – PyTreeSaveArgs (see below).

Return type:

Optional[List[Future]]

Returns:

A Future that will commit the data to directory when awaited. Copying the data from its source will be awaited in this function.

save(directory, item=None, save_args=None, args=None)[source][source]#

Saves the provided item. See async_save.

restore(directory, item=None, restore_args=None, transforms=None, transforms_default_to_original=True, legacy_transform_fn=None, args=None)[source][source]#

Restores a PyTree from the checkpoint directory at the given path.

In the most basic case, only directory is required. The tree will be restored exactly as saved, and all leaves will be restored as the correct types (assuming the tree metadata is present).

However, restore_args is often required as well. This PyTree gives a RestoreArgs object (or subclass) for every leaf in the tree. Many types, such as string or np.ndarray, do not require any special options for restoration. When restoring an individual leaf as jax.Array, however, some properties may be required.

One example is sharding, which defines how a jax.Array in the restored tree should be partitioned. mesh and mesh_axes can also be used to specify the partitioning, but sharding is preferred because mesh and mesh_axes can only construct a jax.sharding.NamedSharding. For more information, see ArrayTypeHandler documentation and JAX sharding documentation.

Example:

ckptr = Checkpointer(PyTreeCheckpointHandler())
restore_args = {
    'layer0': {
        'w': RestoreArgs(),
        'b': RestoreArgs(),
    },
    'layer1': {
        'w': ArrayRestoreArgs(
            # Restores as jax.Array, regardless of how it was saved.
            restore_type=jax.Array,
            sharding=jax.sharding.Sharding(...),
            # Warning: may truncate or pad!
            global_shape=(x, y),
          ),
        'b': ArrayRestoreArgs(
            restore_type=jax.Array,
            sharding=jax.sharding.Sharding(...),
            global_shape=(x, y),
          ),
    },
}
ckptr.restore(path, restore_args=restore_args)

Providing item is typically only necessary when restoring a custom PyTree class (or when using transformations). In this case, the restored object will take on the same structure as item.

Example:

@flax.struct.dataclass
class TrainState:
  layer0: dict[str, jax.Array]
  layer1: dict[str, jax.Array]

ckptr = Checkpointer(PyTreeCheckpointHandler())
train_state = TrainState(
    layer0={
        'w': jax.Array(...),  # zeros
        'b': jax.Array(...),  # zeros
    },
    layer1={
        'w': jax.Array(...),  # zeros
        'b': jax.Array(...),  # zeros
    },
)
restore_args = jax.tree_util.tree_map(_make_restore_args, train_state)
ckptr.restore(path, item=train_state, restore_args=restore_args)
# restored tree is of type `TrainState`.

Parameters:
  • directory (Path) – saved checkpoint location directory.

  • item (Optional[Any]) – Deprecated, use args.

  • restore_args (Optional[Any]) – Deprecated, use args.

  • transforms (Optional[Any]) – Deprecated, use args.

  • transforms_default_to_original (bool) – See transform_utils.apply_transformations.

  • legacy_transform_fn (Optional[Callable[[Any, Any, Any], Tuple[Any, Any]]]) – Deprecated, use args.

  • args (Optional[PyTreeRestoreArgs]) – PyTreeRestoreArgs (see below).

Return type:

Any

Returns:

A PyTree matching the structure of item.

Raises:
  • FileNotFoundError – directory does not exist or is missing required files.

  • ValueError – transforms is provided without item.

  • ValueError – transforms contains elements with multi_value_fn.

CompositeCheckpointHandler#

class orbax.checkpoint.CompositeCheckpointHandler(*item_names, composite_options=CompositeOptions(primary_host=0, active_processes=None), **items_and_handlers)[source][source]#

CheckpointHandler for saving multiple items.

As with all CheckpointHandler implementations, use only in conjunction with an instance of AbstractCheckpointer.

CompositeCheckpointHandler allows dealing with multiple items of different types or logical distinctness, such as training state (PyTree), dataset iterator, JSON metadata, or anything else. The items managed by the CompositeCheckpointHandler must be specified at initialization.

For an individual item, CompositeCheckpointHandler provides two mechanisms for ensuring that the object gets saved and restored as the correct type and with the correct logic. The item-specific handler can either be (1) specified when the CompositeCheckpointHandler is created, or (2) it can be deferred (you just need to give the item name up-front). When deferred, the handler will be determined from which CheckpointArgs are provided during the first call to save or restore.

Usage:

ckptr = ocp.Checkpointer(
    ocp.CompositeCheckpointHandler('state', 'metadata')
)
ckptr.save(directory,
    ocp.args.Composite(
        # The handler now knows `state` uses `StandardCheckpointHandler`
        # and `metadata` uses `JsonCheckpointHandler`. Any subsequent calls
        # will have to conform to this assumption.
        state=ocp.args.StandardSave(pytree),
        metadata=ocp.args.JsonSave(metadata),
    )
)

restored: ocp.args.Composite = ckptr.restore(directory,
    ocp.args.Composite(
        state=ocp.args.StandardRestore(abstract_pytree),
        # Only provide the restoration arguments you actually need.
        metadata=ocp.args.JsonRestore(),
    )
)
restored.state ...
restored.metadata ...

# Skip restoring `metadata` (you can save a subset of items too, in a
# similar fashion).
restored: ocp.args.Composite = ckptr.restore(directory,
    ocp.args.Composite(
        state=ocp.args.StandardRestore(abstract_pytree),
    )
)
restored.state ...
restored.metadata ... # Error

# If the per-item handler doesn't require any extra information in order to
# restore, in many cases you can use the following pattern if you just want
# to restore everything:
restored: ocp.args.Composite = ckptr.restore(directory)
restored.state ...
restored.metadata ...

ckptr = ocp.Checkpointer(
    ocp.CompositeCheckpointHandler(
        'state',
        metadata=ocp.JsonCheckpointHandler()
    )
)
ckptr.save(directory,
    ocp.args.Composite(
        state=ocp.args.StandardSave(pytree),
        # Error because `metadata` was specified to use JSON save/restore
        # logic, not `StandardCheckpointHandler`.
        metadata=ocp.args.StandardSave(metadata),
    )
)

__init__(*item_names, composite_options=CompositeOptions(primary_host=0, active_processes=None), **items_and_handlers)[source][source]#

Constructor.

All items must be provided up-front, at initialization.

Parameters:
  • *item_names – A list of string item names that this handler will manage.

  • composite_options (CompositeOptions) – Options.

  • **items_and_handlers – A mapping of item name to CheckpointHandler instance, which will be used as the handler for objects of the corresponding name.
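The two ways of binding an item to a handler (deferred by name, or fixed at construction) can be pictured with a toy registry. Everything in this sketch is hypothetical: `ToyComposite`, the `Toy*` handler and argument classes, and the `handler_type` attribute exist only to illustrate the rule, and the real class resolves handlers from CheckpointArgs through its own mechanism.

```python
class ToyJsonHandler:
    """Hypothetical handler type."""


class ToyStandardHandler:
    """Hypothetical handler type."""


class ToyJsonSave:
    handler_type = ToyJsonHandler  # args type -> handler type


class ToyStandardSave:
    handler_type = ToyStandardHandler


class ToyComposite:
    def __init__(self, *item_names, **items_and_handlers):
        # Positional names start without a handler (deferred).
        self._handler_types = {name: None for name in item_names}
        # Keyword items are bound to their handler type immediately.
        for name, handler in items_and_handlers.items():
            self._handler_types[name] = type(handler)

    def check(self, **args) -> dict:
        for name, arg in args.items():
            if name not in self._handler_types:
                raise ValueError(f"unknown item: {name}")
            wanted = type(arg).handler_type
            bound = self._handler_types[name]
            if bound is None:
                # First use fixes the handler for all later calls.
                self._handler_types[name] = wanted
            elif bound is not wanted:
                raise ValueError(
                    f"{name} is bound to {bound.__name__}, "
                    f"not {wanted.__name__}")
        return dict(self._handler_types)


deferred = ToyComposite('state', 'metadata')
bound = deferred.check(state=ToyStandardSave(), metadata=ToyJsonSave())

eager = ToyComposite('state', metadata=ToyJsonHandler())
try:
    eager.check(metadata=ToyStandardSave())
    mismatch_error = None
except ValueError as e:
    mismatch_error = str(e)
```

The second case corresponds to the error shown at the end of the usage example: once an item is tied to a handler, arguments for a different handler are rejected.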

async async_save(directory, args)[source][source]#

Saves multiple items to individual subdirectories.

Return type:

Optional[List[Future]]

save(*args, **kwargs)[source][source]#

Saves synchronously.

restore(directory, args=None)[source][source]#

Restores the provided item synchronously.

Parameters:
  • directory (Path) – Path to restore from.

  • args (Optional[CompositeArgs]) – CompositeArgs object used to restore individual sub-items. May be None or “empty” (CompositeArgs()), which is used to indicate that the handler should restore everything. If an individual item was not specified during the save, None will be returned for that item’s entry.

Return type:

CompositeArgs

Returns:

A CompositeResults object with keys matching CompositeArgs, or with keys for all known items as specified at creation.

metadata(directory)[source][source]#

Returns metadata about the saved item.

Ideally, this is a cheap way to collect information about the checkpoint without requiring a full restoration.

Parameters:

directory (Path) – the directory where the checkpoint is located.

Return type:

CompositeArgs

Returns:

item metadata

finalize(directory)[source][source]#

Optional, custom checkpoint finalization callback.

Will be called on a single worker after all workers have finished writing.

Parameters:

directory (Path) – the directory where the checkpoint is located.

close()[source][source]#

Closes the CheckpointHandler.

JsonCheckpointHandler#

class orbax.checkpoint.JsonCheckpointHandler(filename=None, primary_host=0)[source][source]#

Saves nested dictionary using json.

__init__(filename=None, primary_host=0)[source][source]#

Initializes JsonCheckpointHandler.

Parameters:
  • filename (Optional[str]) – optional file name given to the written file; defaults to ‘metadata’

  • primary_host (Optional[int]) – the host id of the primary host. Defaults to 0. If set to None, all hosts are considered primary. This is useful when all hosts work only with local storage.

save(directory, item=None, args=None)[source][source]#

Saves the given item.

Parameters:
  • directory (Path) – save location directory.

  • item (Optional[Mapping[str, Any]]) – Deprecated, use args instead.

  • args (Optional[JsonSaveArgs]) – JsonSaveArgs (see below).

restore(directory, item=None, args=None)[source][source]#

Restores json mapping from directory.

item is unused.

Parameters:
  • directory (Path) – restore location directory.

  • item (Optional[Mapping[str, Any]]) – unused

  • args (Optional[JsonRestoreArgs]) – unused

Return type:

bytes

Returns:

Binary data read from directory.

ArrayCheckpointHandler#

class orbax.checkpoint.ArrayCheckpointHandler(checkpoint_name=None)[source][source]#

Handles saving and restoring individual arrays and scalars.

__init__(checkpoint_name=None)[source][source]#

Initializes the handler.

Parameters:

checkpoint_name (Optional[str]) – Provides a name for the directory under which Tensorstore files will be saved. Defaults to ‘checkpoint’.

async async_save(directory, item=None, save_args=None, args=None)[source][source]#

Saves an object asynchronously.

Parameters:
  • directory (Path) – Folder in which to save.

  • item (Union[int, float, number, ndarray, Array, None]) – Deprecated, use args.

  • save_args (Optional[SaveArgs]) – Deprecated, use args.

  • args (Optional[ArraySaveArgs]) – An ocp.array_checkpoint_handler.ArraySaveArgs (see below).

Return type:

Optional[List[Future]]

Returns:

A list of commit futures which can be run to complete the save.

save(directory, *args, **kwargs)[source][source]#

Saves an array synchronously.

restore(directory, item=None, restore_args=None, args=None)[source][source]#

Restores an object.

Parameters:
  • directory (Path) – folder from which to read.

  • item (Union[int, float, number, ndarray, Array, None]) – Deprecated, use args.

  • restore_args (Optional[RestoreArgs]) – Deprecated, use args.

  • args (Optional[ArrayRestoreArgs]) – An ocp.array_checkpoint_handler.ArrayRestoreArgs object (see below).

Return type:

Union[int, float, number, ndarray, Array]

Returns:

The restored object.

finalize(directory)[source][source]#

Optional, custom checkpoint finalization callback.

Will be called on a single worker after all workers have finished writing.

Parameters:

directory (Path) – the directory where the checkpoint is located.

close()[source][source]#

See superclass documentation.

ProtoCheckpointHandler#

class orbax.checkpoint.ProtoCheckpointHandler(filename, primary_host=0)[source][source]#

Serializes/deserializes protocol buffers.

__init__(filename, primary_host=0)[source][source]#

Initializes ProtoCheckpointHandler.

Parameters:
  • filename (str) – file name given to the written file.

  • primary_host (Optional[int]) – primary host to write on. If None, writes on all hosts.

async async_save(directory, item=None, args=None)[source][source]#

Saves the given proto.

Parameters:
  • directory (Path) – save location directory.

  • item (Optional[Message]) – Deprecated, use args.

  • args (Optional[ProtoSaveArgs]) – ProtoSaveArgs (see below).

Return type:

Optional[List[Future]]

Returns:

A commit future.

save(*args, **kwargs)[source][source]#

Saves the provided item.

restore(directory, item=None, args=None)[source][source]#

Restores the proto from directory.

Parameters:
  • directory (Path) – restore location directory.

  • item (Optional[Type[Message]]) – Deprecated, use args.

  • args (Optional[ProtoRestoreArgs]) – ProtoRestoreArgs (see below).

Returns:

The deserialized proto read from directory if item is not None

close()[source][source]#

Closes the CheckpointHandler.

JaxRandomKeyCheckpointHandler#

class orbax.checkpoint.JaxRandomKeyCheckpointHandler(key_name=None)[source][source]#

Handles saving and restoring an individual JAX random key in either typed or untyped format.
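For readers unfamiliar with the two key formats, the sketch below shows how they differ in JAX itself and how one converts to the other. It uses only jax.random functions and does not involve the handler; it assumes a JAX version recent enough to provide typed keys, and it does not show how the handler stores either form.

```python
import jax
import jax.numpy as jnp

typed_key = jax.random.key(0)        # typed key array
untyped_key = jax.random.PRNGKey(0)  # untyped key: a raw uint32 array

# Convert between the two representations.
raw = jax.random.key_data(typed_key)       # typed -> raw uint32 data
rewrapped = jax.random.wrap_key_data(raw)  # raw -> typed

# A typed key has a PRNG key dtype; an untyped key is plain uint32.
typed_is_key = bool(jnp.issubdtype(typed_key.dtype, jax.dtypes.prng_key))
untyped_is_key = bool(jnp.issubdtype(untyped_key.dtype, jax.dtypes.prng_key))
```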

save(directory, *args, **kwargs)[source]#

Saves a random key synchronously.

async async_save(directory, args)[source]#

Saves a random key asynchronously.

Parameters:
  • directory (Path) – Folder in which to save.

  • args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.

Return type:

Optional[List[Future]]

Returns:

A list of commit futures which can be run to complete the save.

restore(directory, args)[source]#

Restores a random key.

Parameters:
  • directory (Path) – folder from which to read.

  • args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.

Return type:

Any

Returns:

The restored object.

NumpyRandomKeyCheckpointHandler#

class orbax.checkpoint.NumpyRandomKeyCheckpointHandler(key_name=None)[source][source]#

Saves a NumPy random key in legacy or non-legacy format.
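NumPy itself exposes its global random state in two forms, a legacy tuple and a dict. The sketch below assumes that these are the two formats the description refers to. It uses only numpy.random and does not show how the handler stores either form.

```python
import numpy as np

np.random.seed(0)
legacy_state = np.random.get_state(legacy=True)  # tuple form
dict_state = np.random.get_state(legacy=False)   # dict form

first = np.random.rand()

# Either representation restores the generator to the same point.
np.random.set_state(legacy_state)
replayed_from_legacy = np.random.rand()

np.random.set_state(dict_state)
replayed_from_dict = np.random.rand()
```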

save(directory, *args, **kwargs)[source]#

Saves a random key synchronously.

async async_save(directory, args)[source]#

Saves a random key asynchronously.

Parameters:
  • directory (Path) – Folder in which to save.

  • args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.

Return type:

Optional[List[Future]]

Returns:

A list of commit futures which can be run to complete the save.

restore(directory, args)[source]#

Restores a random key.

Parameters:
  • directory (Path) – folder from which to read.

  • args (CheckpointArgs) – An ocp.checkpoint_args.CheckpointArgs.

Return type:

Any

Returns:

The restored object.