orbax.checkpoint V1 API Reference

Submodules

Top-level Symbols

Types

orbax.checkpoint.v1.PLACEHOLDER = Ellipsis

Loading

orbax.checkpoint.v1.load(path, abstract_state=None, *, checkpointable_name='AUTO')

Loads a PyTree.

Loads from a PyTree checkpoint. A PyTree checkpoint is a path containing a subdirectory named by checkpointable_name (default: AUTO, which resolves the name automatically). See checkpointable_name for more details.

This function must be called on all available controller processes.

The operation blocks until complete. For improved performance, consider using load_async() instead.

If abstract_state is not provided, the PyTree will be loaded exactly as saved.

IMPORTANT: Loading is more brittle and error-prone when not providing abstract_state. Always provide abstract_state if possible. Note that you can always obtain the tree structure from a saved checkpoint using metadata().

Providing the abstract_state guarantees two things:

1. The restored tree will exactly match the structure of abstract_state (or raise an error if it is impossible to guarantee this). For example, if abstract_state is a custom object registered as a PyTree, the checkpoint will be restored as the same object, if possible.

2. The leaves of the restored tree will be restored with the properties indicated by the abstract leaves. For example, if a leaf in abstract_state is a jax.ShapeDtypeStruct, the restored leaf will be a jax.Array with the same shape and dtype. Each AbstractLeaf has a corresponding Leaf that is restored. See orbax.checkpoint.v1.tree for a table of standard supported leaf types.

Example Usage:

Load a saved PyTree with and without providing its abstract structure:

import jax
import jax.numpy as jnp
from orbax.checkpoint import v1 as ocp

path = '/tmp/my_checkpoint'

# Save a checkpoint
state = {'a': jnp.arange(8), 'b': jnp.zeros(4)}
ocp.save(path, state)

# Load the checkpoint
# Highly recommended to provide the abstract pytree (structure/shapes)
abstract_state = jax.eval_shape(lambda: state)

# Method A: Load using the abstract structure.
# With checkpointable_name='AUTO' (the default), the PyTree subdirectory
# inside 'path' is located automatically.
restored = ocp.load(path, abstract_state)

# Method B: Infer the structure from the files on disk. Not recommended for
# production use cases or for complex trees.
restored_inferred = ocp.load(path)

Parameters:
  • path (UnionType[Path, str]) – The path to load the checkpoint from. This path must contain a subdirectory with name provided by checkpointable_name. See checkpointable_name for more details.

  • abstract_state (Union[PyTreeOf[UnionType[AbstractArray, AbstractShardedArray, int, float, number, bytes, bool, str]], CheckpointMetadata[PyTreeOf[UnionType[AbstractArray, AbstractShardedArray, int, float, number, bytes, bool, str]]], None]) – Provides a tree structure for the checkpoint to be restored into. May be omitted to load exactly as saved, but this is much more brittle than providing the tree.

  • checkpointable_name (UnionType[str, None]) – The name of the checkpointable to load. A subdirectory with this name must exist in path. If None, path itself is expected to contain all files relevant for loading the PyTree (for example, manifest.ocdbt, _METADATA, ocp.process_X) rather than a subdirectory. Defaults to AUTO. In AUTO mode, the pytree checkpointable is discovered dynamically: the standard ‘pytree’ name takes priority if present; otherwise, any other valid pytree checkpointable names are sorted alphabetically and the first valid one is used; if no such checkpointable exists, path is interpreted as a flat V0 root layout.

Return type:

PyTreeOf[UnionType[Array, ndarray, int, float, number, bytes, bool, str]]

Returns:

The restored PyTree.
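The AUTO resolution order described for checkpointable_name can be sketched in plain Python. This illustrates the documented order only and is not Orbax's implementation: the helper names are hypothetical, and treating a subdirectory as a "valid pytree checkpointable" when it contains a _METADATA file is an assumption made for the sketch.

```python
import tempfile
from pathlib import Path
from typing import Optional

# Assumption for this sketch: a subdirectory counts as a "valid pytree
# checkpointable" if it contains a _METADATA file.
_ASSUMED_PYTREE_MARKER = "_METADATA"


def _looks_like_pytree(directory: Path) -> bool:
    return directory.is_dir() and (directory / _ASSUMED_PYTREE_MARKER).is_file()


def resolve_auto_checkpointable_name(path: Path) -> Optional[str]:
    """Returns the name AUTO would pick, or None for a flat V0 root."""
    # 1. The standard 'pytree' checkpointable takes priority.
    if _looks_like_pytree(path / "pytree"):
        return "pytree"
    # 2. Otherwise, use the alphabetically first valid pytree subdirectory.
    for child in sorted(path.iterdir(), key=lambda p: p.name):
        if _looks_like_pytree(child):
            return child.name
    # 3. Fall back to treating `path` itself as a flat V0 root layout.
    return None


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ("zeta", "alpha"):
        (root / name).mkdir()
        (root / name / _ASSUMED_PYTREE_MARKER).touch()
    print(resolve_auto_checkpointable_name(root))  # alpha
```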

orbax.checkpoint.v1.load_async(path, abstract_state=None, *, checkpointable_name='state')

Loads a PyTree asynchronously. Currently has limited support.

Return type:

AsyncResponse[PyTreeOf[UnionType[Array, ndarray, int, float, number, bytes, bool, str]]]

orbax.checkpoint.v1.load_checkpointables(path, abstract_checkpointables=None)

Loads checkpointables.

See documentation for save_checkpointables() for more context on what a checkpointable is.

This function can load any checkpoint saved by save_checkpointables() (or save()). The path should contain one subdirectory per checkpointable, each named after that checkpointable.

This function must be called on all available controller processes.

The operation blocks until complete. For improved performance, consider using load_checkpointables_async() instead.

If abstract_checkpointables is not provided, the checkpointables will be loaded exactly as saved.

IMPORTANT: Loading is more brittle and error-prone when not providing abstract_checkpointables. Always provide abstract_checkpointables if possible. Note that you can always obtain the information about the checkpointables using checkpointables_metadata().

If abstract_checkpointables is provided, the value provided for each key is treated as the abstract type for the given checkpointable. For example, for a PyTree of jax.Array, the corresponding abstract checkpointable is a PyTree of jax.ShapeDtypeStruct. None is always a valid abstract checkpointable, which just indicates that the checkpointable should be loaded exactly as saved.

The keys provided in abstract_checkpointables may be any subset of the checkpointables in the checkpoint. Checkpointables whose names are not keys of abstract_checkpointables are not loaded.
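The selection rules above (no dict means load everything as saved, None means load that entry as saved, and omitted names are skipped) can be modeled as a small function. This is a hypothetical model of the documented behavior: the function name is invented, and raising KeyError for a name that is absent from the checkpoint is an assumption; Orbax may report this differently.

```python
from typing import Any, Mapping, Optional


def plan_checkpointable_load(
    saved_names: set,
    abstract_checkpointables: Optional[Mapping[str, Any]] = None,
) -> dict:
    """Maps each checkpointable that would be loaded to how it is loaded."""
    if abstract_checkpointables is None:
        # Nothing provided: every checkpointable is loaded exactly as saved.
        return {name: "as saved" for name in sorted(saved_names)}
    plan = {}
    for name, abstract in abstract_checkpointables.items():
        if name not in saved_names:
            # Assumption: requesting an unknown name is an error.
            raise KeyError(f"Checkpointable {name!r} is not in the checkpoint.")
        # None is always valid and means "load exactly as saved".
        plan[name] = "as saved" if abstract is None else "into abstract"
    # Saved names missing from abstract_checkpointables are skipped.
    return plan


saved = {"model", "optimizer", "dataset"}
print(plan_checkpointable_load(saved, {"model": object(), "dataset": None}))
# {'model': 'into abstract', 'dataset': 'as saved'}
```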

Example Usage:

Load checkpointables from a saved checkpoint:

import grain
import jax
import jax.numpy as jnp
from orbax.checkpoint import v1 as ocp

path = '/tmp/my_checkpoint_step_100'

# Save multiple components (checkpointables)
params = {'w': jnp.ones((8, 8)), 'b': jnp.zeros(8)}
opt_state = {'count': jnp.array(100)}

# Set up a Grain dataset iterator (a stateful checkpointable)
dataset_iter = iter(
    grain.MapDataset.range(30)
    .batch(3)
    .map(lambda x: x.tolist())
)

ocp.save_checkpointables(path, {
    'model': params,
    'optimizer': opt_state,
    'dataset': dataset_iter,
})

# Load the checkpointables
abstract_params = jax.eval_shape(lambda: params)
abstract_opt = jax.eval_shape(lambda: opt_state)

abstract_checkpointables = {
    'model': abstract_params,
    'optimizer': abstract_opt,
    # The dataset is restored statefully: pass an initialized iterator, and
    # its position is set to the one recorded in the checkpoint on restore.
    'dataset': dataset_iter,
}

# Load all components
restored = ocp.load_checkpointables(path, abstract_checkpointables)

# Load only a subset
restored_subset = ocp.load_checkpointables(
    path,
    {'model': abstract_params}
)

Parameters:
  • path (UnionType[Path, str]) – The path to load the checkpoint from. This path must contain a subdirectory for each checkpointable.

  • abstract_checkpointables (Union[dict[str, AbstractCheckpointable], CheckpointMetadata[dict[str, AbstractCheckpointable]], None]) – A dictionary of abstract checkpointables. Dictionary keys represent the names of the checkpointables, while the values are the abstract checkpointable objects themselves.

Return type:

dict[str, Checkpointable]

Returns:

A dictionary of checkpointables. Dictionary keys represent the names of the checkpointables, while the values are the checkpointable objects themselves.

Raises:

FileNotFoundError – If the checkpoint path does not exist.

orbax.checkpoint.v1.load_checkpointables_async(path, abstract_checkpointables=None)

Loads checkpointables asynchronously. Not yet implemented.

Return type:

AsyncResponse[dict[str, Checkpointable]]

Saving

orbax.checkpoint.v1.save(path, state, *, checkpointable_name='state', overwrite=False, custom_metadata=None)

Saves a PyTree.

The operation blocks until complete. For improved performance, consider using save_async() instead. This function should be called on all available controller processes.

Example usage:

Simple save of a dictionary containing JAX arrays:

import jax.numpy as jnp
from orbax.checkpoint import v1 as ocp

state = {
    'params': {
        'w': jnp.ones((8, 8)),
        'b': jnp.zeros(8),
    },
    'step': 100,
}
# Saves to /tmp/my_checkpoint/
ocp.save('/tmp/my_checkpoint', state)

Parameters:
  • path (UnionType[Path, str]) – The path to save the checkpoint to.

  • state (PyTreeOf[UnionType[Array, ndarray, int, float, number, bytes, bool, str]]) – The PyTree to save. This may be any JAX PyTree (including custom objects registered as PyTrees) consisting of supported leaf types. See orbax.checkpoint.v1.tree for a table of standard supported leaf types.

  • checkpointable_name (str) – The name of the checkpointable to save a pytree under. Defaults to ‘pytree’.

  • overwrite (bool) – If True, fully overwrites an existing checkpoint in path. Otherwise, raises an error if the checkpoint already exists.

  • custom_metadata (UnionType[list[JsonValue], dict[str, JsonValue], None]) – User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.
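Because custom_metadata must be JSON-serializable and Orbax treats it as opaque, it can be worth validating it before calling save(). A minimal sketch using only the standard library; validate_custom_metadata is a hypothetical helper, and Orbax's own validation may be stricter or looser.

```python
import json
from typing import Any


def validate_custom_metadata(custom_metadata: Any) -> None:
    """Raises TypeError or ValueError if custom_metadata is not valid JSON data."""
    if custom_metadata is None:
        return  # None means "no custom metadata".
    if not isinstance(custom_metadata, (dict, list)):
        raise TypeError("custom_metadata must be a dict, a list, or None.")
    if isinstance(custom_metadata, dict) and not all(
        isinstance(key, str) for key in custom_metadata
    ):
        raise TypeError("Top-level custom_metadata keys must be strings.")
    # json.dumps raises TypeError for unsupported values; allow_nan=False
    # additionally rejects NaN and infinities, which strict JSON forbids.
    json.dumps(custom_metadata, allow_nan=False)


validate_custom_metadata({"experiment": "baseline", "lr": 1e-3, "tags": ["a", "b"]})
```

If the call returns without raising, the same dictionary can be passed as custom_metadata to save().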

orbax.checkpoint.v1.save_async(path, state, *, checkpointable_name='state', overwrite=False, custom_metadata=None)

Saves a PyTree asynchronously.

Unlike save(), this function returns immediately after the save operation is scheduled (except for certain operations, like device-to-host copying of on-device arrays, which must happen on the main thread). Further writing operations continue in a background thread. An AsyncResponse is returned that can be used to block until the save is complete (using response.result()). Make sure to wait for completion before attempting to load the checkpoint or exiting the program. This function should be called on all available controller processes.

Example usage:

Simple save of a dictionary containing JAX arrays asynchronously:

import jax.numpy as jnp
from orbax.checkpoint import v1 as ocp

state = {
    'params': {
        'w': jnp.ones((8, 8)),
        'b': jnp.zeros(8),
    },
    'step': 100
}
# Saves to /tmp/my_checkpoint/
response = ocp.save_async(
    '/tmp/my_checkpoint', state
)

# Perform other work here...

# Wait for completion only when necessary
response.result()

Parameters:
  • path (UnionType[Path, str]) – The path to save the checkpoint to.

  • state (PyTreeOf[UnionType[Array, ndarray, int, float, number, bytes, bool, str]]) – The PyTree to save. This may be any JAX PyTree (including custom objects registered as PyTrees) consisting of supported leaf types. See orbax.checkpoint.v1.tree for a table of standard supported leaf types.

  • checkpointable_name (str) – The name of the checkpointable to save a pytree under. Defaults to ‘pytree’.

  • overwrite (bool) – If True, fully overwrites an existing checkpoint in path. Otherwise, raises an error if the checkpoint already exists.

  • custom_metadata (UnionType[list[JsonValue], dict[str, JsonValue], None]) – User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.

Return type:

AsyncResponse[None]

Returns:

An AsyncResponse that can be used to block until the save is complete. Blocking can be done using response.result(), which returns None.
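The save_async() contract (a blocking snapshot on the calling thread, background writes, and result() to wait) resembles the standard concurrent.futures pattern. The sketch below is a toy analogue built on ThreadPoolExecutor, not Orbax's implementation; toy_save_async and the JSON file format are assumptions for illustration.

```python
import json
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path

_executor = ThreadPoolExecutor(max_workers=1)


def toy_save_async(path: Path, state: dict) -> Future:
    """Snapshots `state` on the calling thread, then writes it in the background."""
    # Blocking step: serialize now, so later mutations of `state` by the caller
    # cannot change what gets written.
    payload = json.dumps(state)

    def _write() -> None:
        path.write_text(payload)

    # Non-blocking step: the write runs on a worker thread.
    return _executor.submit(_write)


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "state.json"
    state = {"step": 100}
    response = toy_save_async(target, state)
    state["step"] = 101  # Does not affect the snapshot being written.
    response.result()  # Block before reading the file or exiting.
    print(json.loads(target.read_text()))  # {'step': 100}
```

Calling result() before reading the file mirrors the requirement to wait for completion before loading the checkpoint or exiting the program.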

orbax.checkpoint.v1.save_checkpointables(path, checkpointables, *, overwrite=False, custom_metadata=None)

Saves a dictionary of checkpointables.

A “checkpointable” refers to a logical piece of the checkpoint that is distinct in some way from other pieces. Checkpointables are separable; they may or may not be loaded concurrently and some may be omitted from the checkpoint entirely. Checkpointables are often represented by different types, and have different representations on disk. The quintessential example is model params vs. dataset.

For example, one might do:

ocp.save_checkpointables(
    path,
    {
        'params': pytree_of_arrays,
        'dataset': pygrain.DatasetIterator(...),
    }
)

It is also possible to do:

train_state = {
    'params': params_pytree_of_arrays,
    'opt_state': opt_state_pytree_of_arrays,
    'step': step,
    ...
}
ocp.save_checkpointables(path, train_state)

This is not ideal, because it then becomes difficult to run transformations that involve the entire train state (see the load_and_transform API).

This function should be called on all available controller processes.

Parameters:
  • path (UnionType[Path, str]) – The path to save the checkpoint to.

  • checkpointables (dict[str, Checkpointable]) – A dictionary of checkpointables. Dictionary keys represent the names of the checkpointables, while the values are the checkpointable objects themselves.

  • overwrite (bool) – If True, fully overwrites an existing checkpoint in path. Otherwise, raises an error if the checkpoint already exists.

  • custom_metadata (UnionType[list[JsonValue], dict[str, JsonValue], None]) – User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.

Return type:

None
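Since each checkpointable is stored in its own subdirectory (the layout load_checkpointables() expects), skipping a checkpointable at load time amounts to not reading its subdirectory. A toy sketch of that layout, assuming JSON-serializable values and a hypothetical data.json file per subdirectory; Orbax's real on-disk formats differ per checkpointable type.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Iterable, Optional


def toy_save_checkpointables(path: Path, checkpointables: dict) -> None:
    """Writes each checkpointable into its own subdirectory named after it."""
    for name, value in checkpointables.items():
        subdir = path / name
        subdir.mkdir(parents=True)
        (subdir / "data.json").write_text(json.dumps(value))


def toy_load_checkpointables(
    path: Path, names: Optional[Iterable[str]] = None
) -> dict:
    """Reads only the requested subdirectories (all of them if names is None)."""
    if names is None:
        names = sorted(p.name for p in path.iterdir() if p.is_dir())
    return {n: json.loads((path / n / "data.json").read_text()) for n in names}


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "step_100"
    toy_save_checkpointables(ckpt, {"model": {"w": [1, 2]}, "dataset": {"position": 7}})
    print(sorted(p.name for p in ckpt.iterdir()))  # ['dataset', 'model']
    print(toy_load_checkpointables(ckpt, ["model"]))  # {'model': {'w': [1, 2]}}
```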

orbax.checkpoint.v1.save_checkpointables_async(path, checkpointables, *, overwrite=False, custom_metadata=None)

Saves a dictionary of checkpointables asynchronously.

See save_checkpointables() documentation.

Unlike save_checkpointables(), this function returns immediately after the save operation is scheduled (except for certain operations, like device-to-host copying of on-device arrays, which must happen on the main thread). Further writing operations continue in a background thread. An AsyncResponse is returned that can be used to block until the save is complete (using response.result()). Make sure to wait for completion before attempting to load the checkpoint or exiting the program. This function should be called on all available controller processes.

Example usage:

Saving multiple distinct components (e.g. model parameters and dataset iterator) asynchronously:

import grain
import jax.numpy as jnp
from orbax.checkpoint import v1 as ocp

path = '/tmp/my_checkpoint_step_100'

# Set up components
params = {'w': jnp.ones((8, 8)), 'b': jnp.zeros(8)}

# Set up a Grain iterator (a stateful checkpointable)
dataset_iter = iter(
    grain.MapDataset.range(30)
    .batch(3)
    .map(lambda x: x.tolist())
)

# Save multiple components
checkpointables = {
    'model': params,
    'dataset': dataset_iter,
}

# Start the async save
response = ocp.save_checkpointables_async(path, checkpointables)

# Perform other operations here...

# Wait for the save to finish
response.result()

Parameters:
  • path (UnionType[Path, str]) – The path to save the checkpoint to.

  • checkpointables (dict[str, Checkpointable]) – A dictionary of checkpointables. Dictionary keys represent the names of the checkpointables, while the values are the checkpointable objects themselves.

  • overwrite (bool) – If True, fully overwrites an existing checkpoint in path. Otherwise, raises an error if the checkpoint already exists.

  • custom_metadata (UnionType[list[JsonValue], dict[str, JsonValue], None]) – User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.

Return type:

AsyncResponse[None]

Returns:

An AsyncResponse that can be used to block until the save is complete. Blocking can be done using response.result(), which returns None.

Metadata

orbax.checkpoint.v1.metadata(path, checkpointable_name='AUTO')

Loads the PyTree metadata from a checkpoint.

This function retrieves metadata for a PyTree checkpoint, returning an object of type CheckpointMetadata[PyTreeMetadata]. Please see documentation on this class for further details.

In short, the returned object contains a metadata attribute (among other attributes like timestamps), which is an instance of PyTreeMetadata. The PyTreeMetadata describes information specific to the PyTree itself. The most important such property is the PyTree structure, which is a tree structure matching the structure of the checkpointed PyTree, with leaf metadata objects describing each leaf.

For example:

metadata = ocp.metadata(path)  # CheckpointMetadata[PyTreeMetadata]
metadata.init_timestamp_nsecs  # Checkpoint creation timestamp.
metadata.metadata  # PyTreeMetadata: the PyTree structure, with per-leaf metadata.

The metadata can then be used to inform checkpoint loading. For example:

import dataclasses

import jax
from orbax.checkpoint import v1 as ocp

metadata = ocp.metadata(path)
restored = ocp.load(path, metadata)

# Load with altered properties.
def _get_abstract_array(arr):
  # Assumes all checkpoint leaves are array types.
  new_dtype = ...
  new_sharding = ...
  return jax.ShapeDtypeStruct(arr.shape, new_dtype, sharding=new_sharding)

metadata = dataclasses.replace(
    metadata,
    metadata=jax.tree.map(_get_abstract_array, metadata.metadata),
)
ocp.load(path, metadata)

Parameters:
  • path (UnionType[Path, str]) – The path to the checkpoint.

  • checkpointable_name (UnionType[str, None]) – The name of the checkpointable to load. A subdirectory with this name must exist in path. If None, path itself is expected to contain all files relevant for loading the PyTree (for example, manifest.ocdbt, _METADATA, ocp.process_X) rather than a subdirectory. Defaults to AUTO. In AUTO mode, the pytree checkpointable is discovered dynamically: the standard ‘pytree’ name takes priority if present; otherwise, any other valid pytree checkpointable names are sorted alphabetically and the first valid one is used; if no such checkpointable exists, path is interpreted as a flat V0 root layout.

Return type:

CheckpointMetadata[PyTreeOf[UnionType[AbstractArray, AbstractShardedArray, int, float, number, bytes, bool, str]]]

Returns:

A CheckpointMetadata[PyTreeMetadata] object.
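The dataclasses.replace pattern in the metadata() example can be exercised without JAX or a saved checkpoint. In this sketch, ToyLeafMetadata stands in for a leaf metadata object, ToyCheckpointMetadata for CheckpointMetadata, and tree_map for jax.tree.map; all three are hypothetical stand-ins, not Orbax types.

```python
import dataclasses
from typing import Any, Callable, Optional


@dataclasses.dataclass(frozen=True)
class ToyLeafMetadata:
    shape: tuple
    dtype: str


@dataclasses.dataclass(frozen=True)
class ToyCheckpointMetadata:
    metadata: Any  # Nested dicts whose leaves are ToyLeafMetadata.
    init_timestamp_nsecs: Optional[int] = None


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Minimal stand-in for jax.tree.map, supporting nested dicts only."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    return fn(tree)


def with_dtype(meta: ToyCheckpointMetadata, dtype: str) -> ToyCheckpointMetadata:
    """Returns a copy whose leaves keep their shapes but use `dtype`."""
    # dataclasses.replace copies every other field (e.g. timestamps) unchanged.
    return dataclasses.replace(
        meta,
        metadata=tree_map(
            lambda leaf: dataclasses.replace(leaf, dtype=dtype), meta.metadata
        ),
    )


meta = ToyCheckpointMetadata(
    metadata={"layer0": {"bias": ToyLeafMetadata((4,), "float32")}},
    init_timestamp_nsecs=123,
)
print(with_dtype(meta, "bfloat16").metadata["layer0"]["bias"].dtype)  # bfloat16
```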

orbax.checkpoint.v1.checkpointables_metadata(path)

Loads all checkpointables metadata from a checkpoint.

This function is a more general version of metadata(). The same CheckpointMetadata object is returned (with properties like init_timestamp_nsecs, as shown above), but its core metadata property is a dictionary mapping checkpointable names to their metadata. This mirrors the return value of load_checkpointables(), which similarly returns a dictionary mapping checkpointable names to their loaded values.

For example:

ocp.save_checkpointables(path, {
    'foo': Foo(),
    'bar': Bar(),
})
metadata = ocp.checkpointables_metadata(path)
metadata.metadata  # {'foo': AbstractFoo(), 'bar': AbstractBar()}
Parameters:

path (UnionType[Path, str]) – The path to the checkpoint.

Return type:

CheckpointMetadata[dict[str, AbstractCheckpointable]]

Returns:

A CheckpointMetadata[dict[str, AbstractCheckpointable]] object.

orbax.checkpoint.v1.PyTreeMetadata

alias of PyTreeOf[AbstractArray | AbstractShardedArray | int | float | number | bytes | bool | str]

final class orbax.checkpoint.v1.CheckpointMetadata(path, *, metadata, init_timestamp_nsecs=None, commit_timestamp_nsecs=None, custom_metadata=None)

Represents complete metadata describing a checkpoint.

Note that this class is generic over a type parameter, CheckpointableMetadataT. This will typically be either PyTreeMetadata (see above) or dict[str, AbstractCheckpointable].

CheckpointMetadata can be obtained from either of two functions: metadata() and checkpointables_metadata(). See their documentation for more information and usage instructions.

If the checkpoint contains a PyTree, this metadata can be accessed via:

metadata = ocp.metadata(path)

# Inspect various properties
metadata.init_timestamp_nsecs

# Inspect the tree structure
metadata.metadata
metadata.metadata['layer0']['bias'].shape
metadata.metadata['layer0']['bias'].dtype

The checkpoint metadata can also be accessed more generically via:

metadata = ocp.checkpointables_metadata(path)

metadata.metadata.keys()  # == ['pytree', 'dataset', etc.]
metadata.metadata['pytree']  # instance of PyTreeMetadata

metadata

Metadata for the checkpointable.

init_timestamp_nsecs

The timestamp when the uncommitted checkpoint was initialized, specified in nanoseconds since the epoch. Defaults to None.

commit_timestamp_nsecs

The commit timestamp of a checkpoint, specified in nanoseconds since the epoch. Defaults to None.

custom_metadata

User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.

Path Utilities

orbax.checkpoint.v1.is_orbax_checkpoint(path)

Returns True if the path is an Orbax checkpoint.

Return type:

bool

Synchronization

class orbax.checkpoint.v1.AsyncResponse(*args, **kwargs)

Checkpointable Handlers

class orbax.checkpoint.v1.StatefulCheckpointable(*args, **kwargs)

An interface that defines save/load logic for a checkpointable object.

class orbax.checkpoint.v1.CheckpointableHandler(*args, **kwargs)

An interface that defines save/load logic for a checkpointable object.

NOTE: Prefer to use StatefulCheckpointable interface when possible.

A PyTree of arrays, representing model parameters, is the most basic “checkpointable”. A singular array is also a checkpointable.

In most contexts, when dealing with just a PyTree, the API of choice is:

ocp.save(directory, pytree)

The concept of “checkpointable” is not so obvious in this case. When dealing with multiple objects, we can use:

ocp.save_checkpointables(
    directory,
    dict(
        pytree=model_params,
        dataset=dataset_iterator,
        # other checkpointables, e.g. extra metadata, etc.
    ),
)

Now, it is easy to simply skip loading the dataset, as is commonly desired when running evals or inference:

ocp.load_checkpointables(
    directory,
    dict(
        pytree=abstract_model_params,
    ),
)
# Equivalently,
ocp.load(directory, abstract_model_params)

For the methods defined in this Protocol (save, load), the logic in the method body is executed on the main thread and blocks until it finishes. Additional logic can run in the background if the method returns an Awaitable (which may itself produce a result).
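The split between blocking work in the method body and background work in the returned Awaitable can be illustrated with asyncio. ToyJsonHandler and its save/load signatures are hypothetical simplifications, not the actual CheckpointableHandler method signatures; save is written here as a plain method that returns a coroutine.

```python
import asyncio
import json
import tempfile
from pathlib import Path
from typing import Awaitable


class ToyJsonHandler:
    """Hypothetical handler: blocking snapshot, then a background write."""

    def save(self, directory: Path, checkpointable: dict) -> Awaitable[None]:
        # Blocking part, executed before save() returns: serialize now so the
        # caller may mutate `checkpointable` afterwards.
        payload = json.dumps(checkpointable)

        async def _write_in_background() -> None:
            # Background part: the file write is offloaded to a worker thread.
            await asyncio.to_thread((directory / "data.json").write_text, payload)

        return _write_in_background()

    def load(self, directory: Path) -> dict:
        return json.loads((directory / "data.json").read_text())


with tempfile.TemporaryDirectory() as tmp:
    directory = Path(tmp)
    handler = ToyJsonHandler()
    data = {"w": [1, 2, 3]}
    pending = handler.save(directory, data)
    data["w"].append(4)  # Safe: the payload was captured synchronously.
    asyncio.run(pending)  # Drive the background work to completion.
    print(handler.load(directory))  # {'w': [1, 2, 3]}
```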

Let’s look at some suggestions on how to implement a CheckpointableHandler.

To create a custom handler, you must define a class that implements the methods defined in this Protocol. The class should be generic over the concrete type Checkpointable (the object being saved/loaded) and the abstract type AbstractCheckpointable (the lightweight metadata representation).

Crucially, once implemented, the handler must be registered with the global registry or a context-local registry so that save_checkpointables and load_checkpointables can automatically detect and use it for the corresponding types. Use orbax.checkpoint.v1.handlers.register_handler for global registration, or provide handlers via orbax.checkpoint.v1.context.CheckpointablesOptions for context-local registration.

First, take a look at orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py for some toy implementations used for unit testing.

Here are some details on how to implement is_handleable and is_abstract_handleable.

For example, a handler may be defined as follows:

class FooHandler(CheckpointableHandler[Foo, AbstractFoo]):

  def is_handleable(self, checkpointable: Foo) -> bool:
    return isinstance(checkpointable, Foo)

  def is_abstract_handleable(
      self, abstract_checkpointable: AbstractFoo) -> bool:
    return isinstance(abstract_checkpointable, AbstractFoo)

This is simple because the handler only works with Foo and AbstractFoo. But the handler may work on more generic types. In a toy example, let’s say we’ve developed an improved way of storing very large arrays, which is still suboptimal for more normal-sized arrays. We can implement the handler as:

class FooHandler(CheckpointableHandler[jax.Array, jax.ShapeDtypeStruct]):

  def is_handleable(self, checkpointable: jax.Array) -> bool:
    return (
        isinstance(checkpointable, jax.Array)
        and checkpointable.size > LARGE_ARRAY_THRESHOLD
    )

  def is_abstract_handleable(
      self, abstract_checkpointable: jax.ShapeDtypeStruct) -> bool:
    return (
        isinstance(abstract_checkpointable, jax.ShapeDtypeStruct)
        and abstract_checkpointable.size > LARGE_ARRAY_THRESHOLD
    )

In many cases, no information is needed for loading. In this case, AbstractCheckpointable may be defined as None. For example:

class FooHandler(CheckpointableHandler[Foo, None]):

  def is_handleable(self, checkpointable: Foo) -> bool:
    return isinstance(checkpointable, Foo)

  def is_abstract_handleable(self, abstract_checkpointable: None) -> bool:
    return abstract_checkpointable is None
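Registration lets save_checkpointables and load_checkpointables pick a handler by asking each registered handler whether it can handle a given object. The sketch below models that dispatch with a hypothetical ToyRegistry. The "most recently registered handler wins" policy, the list-based stand-in for large arrays, and the use of an int length as the abstract type are all assumptions, not Orbax's registry semantics.

```python
from typing import Any, Protocol

LARGE_ARRAY_THRESHOLD = 1_000  # Hypothetical cutoff, as in the example above.


class ToyHandler(Protocol):
    def is_handleable(self, checkpointable: Any) -> bool: ...

    def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool: ...


class ListHandler:
    """Handles any list; its abstract form is the list length (an int)."""

    def is_handleable(self, checkpointable: Any) -> bool:
        return isinstance(checkpointable, list)

    def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool:
        return isinstance(abstract_checkpointable, int)


class LargeListHandler:
    """Only claims lists longer than LARGE_ARRAY_THRESHOLD."""

    def is_handleable(self, checkpointable: Any) -> bool:
        return (isinstance(checkpointable, list)
                and len(checkpointable) > LARGE_ARRAY_THRESHOLD)

    def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool:
        return (isinstance(abstract_checkpointable, int)
                and abstract_checkpointable > LARGE_ARRAY_THRESHOLD)


class ToyRegistry:
    def __init__(self) -> None:
        self._handlers: list = []

    def register(self, handler: ToyHandler) -> None:
        self._handlers.append(handler)

    def resolve(self, checkpointable: Any) -> ToyHandler:
        # Assumed policy: the most recently registered matching handler wins.
        for handler in reversed(self._handlers):
            if handler.is_handleable(checkpointable):
                return handler
        raise ValueError(f"No handler for {type(checkpointable).__name__}.")

    def resolve_abstract(self, abstract_checkpointable: Any) -> ToyHandler:
        for handler in reversed(self._handlers):
            if handler.is_abstract_handleable(abstract_checkpointable):
                return handler
        raise ValueError(f"No handler for {type(abstract_checkpointable).__name__}.")


registry = ToyRegistry()
registry.register(ListHandler())
registry.register(LargeListHandler())
print(type(registry.resolve([0] * 10)).__name__)  # ListHandler
print(type(registry.resolve([0] * 5_000)).__name__)  # LargeListHandler
```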

Context

final class orbax.checkpoint.v1.Context(context=None, *, pytree_options=None, array_options=None, async_options=None, multiprocessing_options=None, file_options=None, checkpointables_options=None, pathways_options=None, checkpoint_layout=None, deletion_options=None, memory_options=None, safetensors_options=None)

Context for customized checkpointing.

This class manages the configuration options (e.g., async, multiprocessing, array handling) used during Orbax checkpoint operations.

Creating a new Context within an existing Context sets all parameters from scratch by default. To inherit properties from a parent Context, you must explicitly pass the parent context as the first argument. The new context will inherit the parent’s properties, except for any options explicitly provided as keyword arguments to the child context.
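The inheritance rule (a fresh Context starts from defaults, while a Context created from a parent copies the parent's options and overrides only those passed explicitly) can be modeled with a frozen dataclass. ToyContext and its create() constructor are hypothetical; the real Context takes the parent as its first positional argument.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass(frozen=True)
class ToyContext:
    """Hypothetical stand-in holding two of Context's option fields."""

    pytree_options: Optional[Any] = None
    array_options: Optional[Any] = None

    @classmethod
    def create(cls, parent: Optional["ToyContext"] = None, **overrides: Any) -> "ToyContext":
        # No parent: every option not passed explicitly takes its default.
        base = parent if parent is not None else cls()
        # With a parent: copy its options, replacing only the ones passed in.
        return dataclasses.replace(base, **overrides)


outer = ToyContext.create(pytree_options="custom-pytree")
print(ToyContext.create(array_options="custom-array"))
# ToyContext(pytree_options=None, array_options='custom-array')
print(ToyContext.create(outer, array_options="custom-array"))
# ToyContext(pytree_options='custom-pytree', array_options='custom-array')
```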

WARNING: The context is thread-local and is not shared across threads. The entire context block must be executed within the same thread. If you dispatch a checkpointing operation to a worker thread (e.g., via ThreadPoolExecutor), that thread will not inherit the context and will fall back to default settings.

Note: When testing or mixing checkpointer instances and free functions, explicitly wrap free functions inside their own with ocp.Context(…) block, or pass explicit contexts to Checkpointer constructors, to ensure each actor receives its correct active configuration independent of the surrounding context.

Example

Basic usage and explicit inheritance:

from orbax.checkpoint import v1 as ocp

# Basic usage
with ocp.Context(pytree_options=ocp.options.PyTreeOptions()):
  ocp.save(directory, tree)

# Inheriting properties from an existing context
with ocp.Context(pytree_options=ocp.options.PyTreeOptions()) as outer_ctx:
  # inner_ctx inherits pytree_options, but overrides/adds array_options
  with ocp.Context(outer_ctx,
      array_options=ocp.options.ArrayOptions()
      ) as inner_ctx:
    ocp.save(directory, tree)

Context is not shared across threads:

from concurrent.futures import ThreadPoolExecutor
from orbax.checkpoint import v1 as ocp

executor = ThreadPoolExecutor(max_workers=1)
with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions()
):  # Thread #1 creates Context.
  # The following save call is executed in Thread #2, which sees
  # a "default" Context, NOT the one created above.
  executor.submit(ocp.save, directory, tree)
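One way to follow the earlier Note in a worker thread is to re-enter the context inside the submitted function. The sketch below models a thread-local context with threading.local; toy_context, current_options, and run_with_options are hypothetical names. With Orbax, the analogous step would be opening a with ocp.Context(...) block inside the function passed to executor.submit.

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Optional

_local = threading.local()


@contextmanager
def toy_context(options: str) -> Iterator[None]:
    """Hypothetical thread-local context, similar in spirit to ocp.Context."""
    previous = getattr(_local, "options", None)
    _local.options = options
    try:
        yield
    finally:
        _local.options = previous


def current_options() -> Optional[str]:
    # None plays the role of "default settings" in this sketch.
    return getattr(_local, "options", None)


def run_with_options(options: str, fn: Callable[..., Any], *args: Any) -> Any:
    """Enters the context inside the worker thread, then calls fn."""
    with toy_context(options):
        return fn(*args)


with ThreadPoolExecutor(max_workers=1) as executor:
    with toy_context("custom"):
        naive = executor.submit(current_options).result()
        fixed = executor.submit(run_with_options, "custom", current_options).result()
print(naive, fixed)  # None custom
```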
pytree_options

Options for PyTree checkpointing. See PyTreeOptions.

array_options

Options for saving and loading arrays (and array-like objects). See ArrayOptions.

async_options

Options for controlling asynchronous behavior. See AsyncOptions.

multiprocessing_options

Options for multiprocessing behavior. See MultiprocessingOptions.

file_options

Options for working with the file system. See FileOptions.

checkpointables_options

Options for controlling checkpointables behavior. See CheckpointablesOptions.

pathways_options

Options for Pathways checkpointing. See PathwaysOptions.

checkpoint_layout

The layout of the checkpoint. Defaults to ORBAX. See CheckpointLayout.

deletion_options

Options for controlling deletion behavior. See DeletionOptions.

memory_options

Options for controlling memory limits during save / load. See MemoryOptions.