ocp.v1.partial module
Public API for partially saving checkpoints.
Saving
- orbax.checkpoint.experimental.v1.partial.save(path, state, *, custom_metadata=None)
Partially saves a PyTree.
This function allows for incrementally updating a checkpoint. It is designed to be called multiple times. The first call initiates a new partial save “session” in a temporary location. Subsequent calls will update this session by modifying the checkpoint in place.
The operation is atomic; if it is interrupted, the previous version of the partial save will be preserved.
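A common way to obtain this kind of atomicity is to write the complete new version to a temporary location and then swap it in with a single rename. The sketch below illustrates that general technique as a toy analogy only; it is not Orbax's implementation. It assumes the whole "checkpoint" is one JSON file on a local POSIX file system, and the name toy_partial_save is hypothetical. Each call builds the new version from the previous one plus the additions, so an interruption before os.replace leaves the previous version intact.

```python
import json
import os
import tempfile
from pathlib import Path


def toy_partial_save(state_file: Path, updates: dict) -> None:
    """Toy analogy of an atomic, additions-only incremental save (not Orbax code)."""
    # Read the previous version, if any.
    previous = json.loads(state_file.read_text()) if state_file.exists() else {}
    overlap = previous.keys() & updates.keys()
    if overlap:
        # Mirrors the documented rule that replacements are not supported.
        raise ValueError(f"replacements are not supported: {sorted(overlap)}")
    new_version = {**previous, **updates}

    # Write the complete new version to a temporary file next to the target...
    fd, tmp_name = tempfile.mkstemp(dir=str(state_file.parent), suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(new_version, f)
            f.flush()
            os.fsync(f.fileno())
        # ...then swap it in atomically. An interruption before this point
        # leaves the previous version on disk untouched.
        os.replace(tmp_name, state_file)
    except BaseException:
        # Remove the leftover temporary file; the previous version survives.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
```

Because the rename is the only step that changes what readers see, a crash at any earlier point costs only the temporary file.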
IMPORTANT: The checkpoint is not finalized at the target path until finalize() is called. The intermediate checkpoints are temporary and should not be used directly.

### Workflow
A typical partial save workflow involves one or more calls to save() followed by a single call to finalize():

```python
import orbax.checkpoint.experimental.v1 as ocp  # Assumed import path (see the signature above).

path = '/path/to/my/checkpoint'

# The first call creates a temporary directory:
# '/path/to/my/checkpoint.partial_save'
# Note: the exact temporary directory name is an implementation detail that
# depends on the file system and should not be relied on.
ocp.partial.save(path, {'layer1': ..., 'step': 1})

# A subsequent call reads the previous version and applies new updates
# to the temporary directory:
# '/path/to/my/checkpoint.partial_save'
ocp.partial.save(path, {'layer2': ..., 'metrics': ...})

# This call commits the latest version to the final destination at
# '/path/to/my/checkpoint'.
ocp.partial.finalize(path)
```
### Additions vs. Replacements
The provided state represents a set of updates:
- If a key in state (e.g., 'metrics') does not exist in the on-disk checkpoint, it is treated as an addition. A call that contains only additions has a key set disjoint from that of the on-disk PyTree.
- If a key (e.g., 'step') already exists, the update is treated as a replacement of its value; the key sets of the on-disk PyTree and the provided state then overlap. Replacements are currently NOT supported. Please reach out to the Orbax team if you need this functionality.
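The addition-versus-replacement rule can be checked on the caller's side before calling save(). The helper below is a hypothetical sketch: it assumes the state is a nested dict and that keys are compared as leaf paths. How Orbax itself matches keys inside nested PyTrees is not specified here, so treat this as an approximation.

```python
from typing import Any, Mapping, Set, Tuple

KeyPath = Tuple[str, ...]


def leaf_paths(tree: Mapping[str, Any], prefix: KeyPath = ()) -> Set[KeyPath]:
    """Collects the key paths of all leaves in a nested dict."""
    paths: Set[KeyPath] = set()
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, Mapping):
            # Recurse into sub-trees; an empty dict contributes no paths here.
            paths |= leaf_paths(value, path)
        else:
            paths.add(path)
    return paths


def classify_updates(
    on_disk: Mapping[str, Any], state: Mapping[str, Any]
) -> Tuple[Set[KeyPath], Set[KeyPath]]:
    """Hypothetical pre-check: splits `state` into (additions, replacements)."""
    existing = leaf_paths(on_disk)
    incoming = leaf_paths(state)
    return incoming - existing, incoming & existing


on_disk = {'layer1': {'w': 0.0}, 'step': 1}
print(classify_updates(on_disk, {'step': 2}))  # (set(), {('step',)})
```

A non-empty second set signals an update that partial saving does not currently accept.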
See save() for general PyTree saving documentation.

- Parameters:
  - path (Path | str) – The path to save the checkpoint to.
  - state (PyTreeOf[Array | ndarray | int | float | number | bytes | bool | str]) – A PyTree representing the additions to be applied to the on-disk checkpoint.
  - custom_metadata (list[JsonValue] | dict[str, JsonValue] | None) – User-provided custom metadata. This will be merged with any existing custom metadata. Values from this dictionary will overwrite existing values if keys conflict.
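The merge rule for custom_metadata can be pictured as a top-level dictionary merge in which the new values win. This is an assumption-laden sketch: merge_custom_metadata is a hypothetical name, it handles only dict metadata (not the list form the type also allows), and whether Orbax merges nested dictionaries recursively is not stated here. The json.dumps call simply rejects values that are not JSON-serializable before merging.

```python
import json
from typing import Any, Dict, Optional


def merge_custom_metadata(
    existing: Optional[Dict[str, Any]], new: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Hypothetical shallow merge: keys in `new` overwrite keys in `existing`."""
    if new is None:
        return existing
    # Fail early on values that cannot be stored as JSON (e.g. sets, objects).
    json.dumps(new)
    if existing is None:
        return dict(new)
    return {**existing, **new}


merged = merge_custom_metadata(
    {'run': 'baseline', 'lr': 0.1}, {'lr': 0.01, 'note': 'warmup done'}
)
print(merged)  # {'run': 'baseline', 'lr': 0.01, 'note': 'warmup done'}
```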
- orbax.checkpoint.experimental.v1.partial.save_async(path, state, *, custom_metadata=None)
Partially saves a PyTree asynchronously.
Unlike save(), this function returns an AsyncResponse immediately after scheduling the save operation. The actual writing to disk happens in a background thread. You can use response.result() to block until the operation is complete.

This function allows for incrementally updating a checkpoint. It is designed to be called multiple times. The first call initiates a new partial save “session” in a temporary location. Subsequent calls will update this session by creating a new version that includes all previous changes plus the new ones.
The operation is atomic; if it is interrupted, the previous version of the partial save will be preserved.
IMPORTANT: The checkpoint is not finalized at the target path until finalize() is called. The intermediate checkpoints are temporary and may be garbage collected in certain environments.

### Workflow
A typical partial save workflow involves one or more calls to save_async() followed by a single call to finalize():

```python
import orbax.checkpoint.experimental.v1 as ocp  # Assumed import path (see the signature above).

path = '/path/to/my/checkpoint'

# The first call creates a temporary directory and returns immediately.
response1 = ocp.partial.save_async(path, {'layer1': ..., 'step': 1})

# A subsequent call also returns immediately. Orbax ensures that this
# operation waits for the first one to complete before starting.
response2 = ocp.partial.save_async(
    path, {'layer2': ..., 'metrics': ...}
)

# Wait for all async partial saves to complete before finalizing.
response1.result()
response2.result()

# This call commits the latest version to the final destination at
# '/path/to/my/checkpoint'.
ocp.partial.finalize(path)
```
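The ordering guarantee in this workflow behaves like a single-worker job queue: each call returns a handle at once, and the jobs run one after another. The following is only an analogy built on the standard library's concurrent.futures, not Orbax's AsyncResponse machinery; fake_save and run_two_saves are hypothetical names.

```python
import concurrent.futures
import time


def fake_save(log: list, name: str, seconds: float) -> None:
    # Stand-in for the disk writes of one partial save.
    time.sleep(seconds)
    log.append(name)


def run_two_saves() -> list:
    log: list = []
    # One worker thread runs submitted jobs one at a time, in submission order,
    # so the second job cannot start before the first one has finished.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        first = executor.submit(fake_save, log, 'first', 0.05)
        second = executor.submit(fake_save, log, 'second', 0.0)
        # submit() returns immediately; result() blocks until the job is done,
        # playing the role of response.result() in the workflow above.
        first.result()
        second.result()
    return log


print(run_two_saves())  # ['first', 'second']
```

Even though the second job is faster, it finishes second, because it is not allowed to start until the first one completes.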
### Additions vs. Replacements
The provided state represents a set of updates:
- If a key in state (e.g., 'metrics') does not exist in the on-disk checkpoint, it is treated as an addition.
- If a key (e.g., 'step') already exists, the update is treated as a replacement of its value. Replacements are currently NOT supported. Please reach out to the Orbax team if you need this functionality.
See save_async() for general PyTree saving documentation.

- Parameters:
  - path (Path | str) – The path to save the checkpoint to.
  - state (PyTreeOf[Array | ndarray | int | float | number | bytes | bool | str]) – The PyTree to save. This may be any JAX PyTree consisting of supported leaf types (see Leaf). Default supported leaf types include jax.Array, np.ndarray, simple types like int, float, str, and empty nodes.
  - custom_metadata (list[JsonValue] | dict[str, JsonValue] | None) – User-provided custom metadata. An arbitrary JSON-serializable dictionary the user can use to store additional information. The field is treated as opaque by Orbax.
- Return type:
  AsyncResponse[None]
- Returns:
  An AsyncResponse that can be used to block until the save is complete. Blocking can be done using response.result(), which returns None.
- Raises:
  - FileExistsError – If a finalized checkpoint already exists at path. To overwrite, it must be deleted first.
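Since a FileExistsError is raised when a finalized checkpoint already exists at path, overwriting means deleting that checkpoint first. On a local file system the standard library is enough, as in the hypothetical helper below. Deletion is irreversible, and remote storage backends would need their own deletion calls.

```python
import shutil
from pathlib import Path
from typing import Union


def delete_finalized_checkpoint(path: Union[str, Path]) -> bool:
    """Hypothetical helper: removes whatever exists at `path` on a local disk.

    Returns True if something was deleted, False if `path` did not exist.
    Irreversible; call it only when overwriting is really intended.
    """
    target = Path(path)
    if target.is_dir():
        shutil.rmtree(target)  # A checkpoint is normally a directory tree.
        return True
    if target.exists():
        target.unlink()
        return True
    return False
```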
- orbax.checkpoint.experimental.v1.partial.finalize(path)
Finalizes a partially-saved checkpoint, making it permanent and readable.
This function commits all changes made during a partial save session, concluding the transaction. It should be called once after all desired save() operations are complete.

The finalization process is atomic. It renames the temporary, versioned partial save directory to the final target path, making the updated checkpoint “live”.
IMPORTANT: Until finalize is called, the checkpoint at the target path is not created or modified. All changes are buffered in a temporary location. This function is what makes those changes permanent.
### Example

```python
import orbax.checkpoint.experimental.v1 as ocp  # Assumed import path (see the signature above).

path = '/path/to/my/checkpoint'

# These calls write to a temporary, versioned directory, not the final path.
ocp.partial.save(path, {'step': 1})
ocp.partial.save(path, {'metrics': ...})

# This call performs the atomic rename, making the checkpoint available at
# '/path/to/my/checkpoint'.
ocp.partial.finalize(path)
```
- Parameters:
  - path (Path | str) – The final, target path of the checkpoint to be finalized. This should be the same path that was passed to the save() (or save_async()) calls.
- Raises:
  - FileExistsError – If a finalized checkpoint already exists at path. To overwrite, it must be deleted first.
  - FileNotFoundError – If no partial save session is found for the given path. This can happen if save() was not called first.
- Return type:
  None
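The documented behavior of finalize(), an atomic rename plus two error cases, can be approximated on a local file system as follows. This is a hypothetical sketch, not Orbax's implementation: toy_finalize is an illustrative name, and the '.partial_save' suffix is borrowed from the workflow comments, which state that the real temporary directory name is an implementation detail. The existence checks and the rename are not one atomic step, so a concurrent writer could slip in between them.

```python
import os
from pathlib import Path
from typing import Union


def toy_finalize(path: Union[str, Path], tmp_suffix: str = '.partial_save') -> None:
    """Hypothetical stand-in for finalize(): renames the session directory."""
    target = Path(path)
    session = target.with_name(target.name + tmp_suffix)
    if target.exists():
        raise FileExistsError(f'{target} already exists; delete it first.')
    if not session.exists():
        raise FileNotFoundError(f'No partial save session found for {target}.')
    # Within a single POSIX file system, rename is atomic: readers see either
    # no checkpoint at `target` or the complete one.
    os.rename(session, target)
```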