Partial Saving#

As deep learning models grow, often to hundreds of billions of parameters, managing their checkpoints becomes a significant challenge. Modifying a large checkpoint, even for a small change like adding metrics or a single layer, traditionally requires an inefficient “load-modify-save” cycle. This cycle consumes a lot of memory and I/O bandwidth, because the entire multi-terabyte checkpoint must be loaded from storage, changed in memory, and written back out.

Partial saving is designed to solve this problem by allowing you to modify a checkpoint without loading the entire object into host memory. It dramatically reduces peak memory usage, minimizes redundant I/O, and simplifies common model update workflows.

The Core Concept: The Partial Save Session#

Partial saving operates on a “session” or “transaction” model. Instead of overwriting your checkpoint directly, Orbax stages all changes in a temporary, in-progress location. The workflow consists of two stages:

  1. Incremental Updates: Calls to functions like ocp.partial.save contribute data to an in-progress checkpointing session. These changes are staged in a temporary location and are not yet visible at the final checkpoint path. From the user’s perspective, the first save call simply begins this incremental process, and subsequent calls add to it.

  2. Finalization: A concluding call to ocp.partial.finalize completes the session. This action commits all the staged changes, making the checkpoint available at its final destination and ready for consumption.

This approach ensures that the modification process is safe and atomic. If the process is interrupted before finalization, your original checkpoint remains untouched.
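The two-stage session can be pictured with a toy model built only on the Python standard library. This is a sketch of the stage-then-rename idea, not Orbax's implementation: the helper names `toy_partial_save` and `toy_finalize` and the JSON file layout are hypothetical, and only the `.partial_save` suffix is taken from this guide.

```python
import json
import os
import tempfile
from pathlib import Path

STAGING_SUFFIX = ".partial_save"  # mirrors the suffix shown in this guide


def toy_partial_save(path: Path, updates: dict) -> None:
    """Stage `updates` next to `path`; `path` itself is not created."""
    staging = path.with_name(path.name + STAGING_SUFFIX)
    staging.mkdir(parents=True, exist_ok=True)
    # Each call writes one more numbered file into the staging directory.
    index = len(list(staging.glob("update_*.json")))
    (staging / f"update_{index}.json").write_text(json.dumps(updates))


def toy_finalize(path: Path) -> None:
    """Commit the session by renaming the staging directory to `path`."""
    staging = path.with_name(path.name + STAGING_SUFFIX)
    os.rename(staging, path)


root = Path(tempfile.mkdtemp())
ckpt = root / "ckpt"

# Stage 1: incremental updates land in the staging directory only.
toy_partial_save(ckpt, {"step": 10000})
toy_partial_save(ckpt, {"note": "second update"})
staged_only = (not ckpt.exists()) and (root / "ckpt.partial_save").exists()

# Stage 2: finalization makes everything visible at the final path at once.
toy_finalize(ckpt)
committed = ckpt.exists() and not (root / "ckpt.partial_save").exists()
num_updates = len(list(ckpt.glob("update_*.json")))
```

Because nothing is written to the final path until the rename, a reader of `ckpt` sees either no checkpoint or the complete one, never a half-written state.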

Note: Partial saving currently does NOT support replacing data written out in previous save calls. If you have a need for Partial Saving Replacement (as opposed to the currently supported Partial Saving Addition), please reach out to the Orbax Checkpointing team so that development of this feature can be prioritized.

Without partial saving, the canonical way to replace values is to load the model, update the values in memory, and then save the result back out.

API and Basic Usage: Adding to a Checkpoint#

The partial saving API lives in the orbax.checkpoint.v1.partial module. With the import used in this guide (from orbax.checkpoint import v1 as ocp), you’ll typically access it as ocp.partial.

The only use case currently supported is adding new data (leaves or subtrees) to an existing PyTree checkpoint. The PyTree passed to a save call represents a set of updates. If a key does not exist in the on-disk checkpoint, it is treated as an addition. If a key already exists, the update is treated as a replacement, which is not yet allowed and raises a NotImplementedError.
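The addition-versus-replacement rule can be illustrated with plain nested dictionaries. The function below is a hypothetical stand-in written for this explanation, not Orbax code: `merge_additions` and its error message are assumptions, and it treats a shared subtree (such as `params`) as something to descend into, as the examples in this guide suggest.

```python
def merge_additions(existing: dict, updates: dict, prefix: str = "") -> dict:
    """Return `existing` plus `updates`, refusing to overwrite any leaf."""
    merged = dict(existing)
    for key, value in updates.items():
        key_path = f"{prefix}/{key}" if prefix else key
        if key not in merged:
            # New key: a pure addition.
            merged[key] = value
        elif isinstance(merged[key], dict) and isinstance(value, dict):
            # Shared subtree: descend and apply the same rule to its children.
            merged[key] = merge_additions(merged[key], value, key_path)
        else:
            # Existing leaf: this would be a replacement.
            raise NotImplementedError(
                f"Replacing existing key '{key_path}' is not supported."
            )
    return merged


on_disk = {"params": {"layer0": [0, 1, 2]}, "step": 10000}

# Adding a sibling leaf under an existing subtree is accepted.
merged = merge_additions(on_disk, {"params": {"layer1": [1.0, 1.0]}})

# Supplying a key that already holds a value is rejected.
try:
    merge_additions(on_disk, {"step": 20000})
    replacement_error = None
except NotImplementedError as e:
    replacement_error = str(e)
```

Here `merged` contains both `layer0` and `layer1`, while the attempt to supply a new `step` is rejected and `on_disk` is left unchanged.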

Code Example: A Simple Addition Workflow#

Let’s start with an initial training state, then update that state with new data in a separate step.

from orbax.checkpoint import v1 as ocp
import numpy as np
import jax
from etils import epath

Initial Save#

Let’s say we have an initial training state. The first call creates a temporary directory (e.g., /tmp/partial-saving/partial_save/ckpt.partial_save) and saves the initial state there.

path = epath.Path('/tmp/partial-saving/partial_save/ckpt')
path.parent.rmtree(missing_ok=True)

initial_state = {
    'params': {
        'layer0': np.arange(8),
    },
    'step': 10000,
}

ocp.partial.save(path, initial_state)
assert not path.exists()
assert (path.parent / (path.name + '.partial_save')).exists()

Add More Data#

After training some more, we have a new layer ready to be added. A subsequent call adds the new layer to the same temporary directory. Orbax merges the new PyTree with the existing one.

new_state = {
    'params': {
        'layer1': np.ones(4),
    },
}

ocp.partial.save(path, new_state)
assert not path.exists()
assert (path.parent / (path.name + '.partial_save')).exists()

Aside: Loading Before Finalizing#

Before finalizing the checkpoint, let’s see what happens if we try to load the partial checkpoint.

try:
  ocp.load(path)
except Exception as e:
  print("LOAD ERROR")
  print(e)
LOAD ERROR
Could not recognize the checkpoint at /tmp/partial-saving/partial_save/ckpt as a valid Orbax checkpoint. If you are trying to load a checkpoint that does not conform to the standard Orbax format, use `ocp.Context(layout=...)` to specify the expected checkpoint layout.

Finalize the Checkpoint#

Calling ocp.partial.finalize atomically renames the temporary directory to the final path, making it a complete, readable checkpoint.

ocp.partial.finalize(path)
assert not (path.parent / (path.name + '.partial_save')).exists()
assert path.exists()

Verify the Result#

Now, we can load the checkpoint and see the merged result.

restored_state = ocp.load(path)

expected_state = {
  'params': {
    'layer0': np.array([0, 1, 2, 3, 4, 5, 6, 7]),
    'layer1': np.array([1., 1., 1., 1.])
  },
  'step': 10000,
}

def is_equal(x, y):
  if isinstance(x, np.ndarray):
    assert np.allclose(x, y)
  else:
    assert x == y

jax.tree.map(is_equal, restored_state, expected_state)
restored_state
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/partial-saving/partial_save/ckpt. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
{'params': {'layer0': array([0, 1, 2, 3, 4, 5, 6, 7]),
  'layer1': array([1., 1., 1., 1.])},
 'step': 10000}

API Reference#

  • ocp.partial.save() / ocp.partial.save_async(): Saves a PyTree to the temporary partial save location. These functions can be called multiple times.

  • ocp.partial.finalize(): Commits the transaction, making the checkpoint permanent at the specified path. This must be called to complete the process.

Advanced Workflow: Combining Partial Saving and Partial Restore#

When combined with Partial Restore, this feature can enable highly efficient, targeted updates to massive checkpoints with a minimal memory footprint. You can use Partial Restore for a memory-efficient read, perform modifications, and then use Partial Save for a flexible and efficient write.

Use Case: Creating an Inference-Ready Checkpoint#

Imagine you have a 2TB training checkpoint containing model params and a bulky optimizer_state. You want to create a smaller, inference-ready checkpoint that:

  • Contains only the params.

  • Has an updated encoder_stack within the params from a recent fine-tuning run.

This entire process can be done without ever loading the massive optimizer_state into memory.

from orbax.checkpoint import v1 as ocp
import numpy as np
import jax
from etils import epath

Setup#

Create a large, multi-part “base” checkpoint to simulate a real scenario. This represents a very large model, but we only write it to disk. The update workflow never loads it all at once; the only full load is the one in this setup code, which prints the original values for later comparison.

base_path = epath.Path('/tmp/partial-saving/base_model/ckpt')
base_path.rmtree(missing_ok=True)

sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(
        'model',
    ),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
base_model_state = {
    'params': {
        'large_embedding_table': np.ones((1024, 1024)), # A large array
        'encoder_stack': {f'layer_{i}': np.random.rand(2) for i in range(4)}, # The part we will replace
        'classification_head': np.random.rand(8),
    },
    'optimizer_state': [np.random.rand(128) for _ in range(16)],
}
base_model_state = jax.tree.map(create_sharded_array, base_model_state)
ocp.save(base_path, base_model_state)

abstract_base_model_state = jax.tree.map(
    ocp.arrays.to_shape_dtype_struct,
    base_model_state
)
init_ckpt = ocp.load(base_path, abstract_base_model_state)
print("\n--- Setup ---")
print(f"Optimizer state exists in initial checkpoint: {'optimizer_state' in init_ckpt}")
print(f"Model version exists in initial checkpoint: {'model_version' in init_ckpt}")
for layer, weights in init_ckpt['params']['encoder_stack'].items():
    print(f"Original {layer}: {weights}")
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/partial-saving/base_model/ckpt. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
--- Setup ---
Optimizer state exists in initial checkpoint: True
Model version exists in initial checkpoint: False
Original layer_0: [0.89399326 0.7459455 ]
Original layer_1: [0.05897963 0.3974571 ]
Original layer_2: [0.5589441  0.42310825]
Original layer_3: [0.16605228 0.5758625 ]

The Efficient Update Workflow#

Use Partial Restore (Omission mode) to load only the params. First, create a reference PyTree that contains only the params structure; this tells Orbax to ignore everything else (like optimizer_state). Then enable partial loading through ocp.Context so that nodes missing from the reference can be omitted.

inference_path = epath.Path('/tmp/partial-saving/inference_model/ckpt')
inference_path.parent.rmtree(missing_ok=True)

abstract_params = jax.tree.map(
    ocp.arrays.to_shape_dtype_struct, {'params': base_model_state['params']}
)

with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(
        loading=ocp.options.PyTreeOptions.Loading(partial_load=True)
    )
):
    loaded_params = ocp.load(base_path, abstract_params)
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/partial-saving/base_model/ckpt. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.

At this point, params is in memory, but optimizer_state was never loaded.
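The effect of the reference PyTree can be mimicked with a toy in which every stored leaf is a function that records when it is read. This is an illustrative sketch only: `lazy_leaf`, `select_like`, and the dictionary layout are hypothetical and say nothing about how Orbax reads data from storage.

```python
reads = []  # names of the leaves that were actually read


def lazy_leaf(name, value):
    """Wrap `value` so that reading it is recorded in `reads`."""
    def read():
        reads.append(name)
        return value
    return read


def select_like(stored: dict, reference: dict) -> dict:
    """Read only the leaves of `stored` that appear in `reference`."""
    selected = {}
    for key, ref_value in reference.items():
        if isinstance(ref_value, dict):
            selected[key] = select_like(stored[key], ref_value)
        else:
            selected[key] = stored[key]()  # the only place a leaf is read
    return selected


stored = {
    "params": {
        "encoder_stack": {"layer_0": lazy_leaf("layer_0", [0.1, 0.2])},
        "classification_head": lazy_leaf("classification_head", [0.3]),
    },
    "optimizer_state": {"moment": lazy_leaf("moment", [0.0] * 4)},
}

# The reference names only what we want; leaf values are placeholders.
reference = {
    "params": {
        "encoder_stack": {"layer_0": None},
        "classification_head": None,
    }
}

loaded = select_like(stored, reference)
```

After the call, `reads` lists only the two `params` leaves. The `optimizer_state` leaf is never touched because the reference does not mention it.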

Update and Partial Save#

  1. Modify the loaded parameters in memory.

  2. Add new metadata that might be useful for inference.

  3. Use Partial Save to write the modified params and new metadata to the new inference checkpoint location.

  4. Finalize the new, smaller, inference-ready checkpoint.

save_params = {}  # Tracks everything saved; used later to build the abstract tree for verification

metadata = {'model_version': 'v1.2-finetuned'}
save_params = ocp.tree.merge(save_params, metadata)
ocp.partial.save(inference_path, metadata)  # Initial partial save for metadata

for layer, weights in loaded_params['params']['encoder_stack'].items():
  # Simulate a fine-tuning update by perturbing the loaded weights.
  new_weights = weights + np.random.rand(2)
  stack_layer = {
      'params': {
          'encoder_stack': {
              layer: jax.tree.map(
                  create_sharded_array, new_weights
              ),
          }
      },
  }
  save_params = ocp.tree.merge(save_params, stack_layer)
  ocp.partial.save(inference_path, stack_layer)  # One partial save per layer

ocp.partial.finalize(inference_path)
WARNING:absl:`memory_size` is not implemented for `TypeHandler` of type: <class 'orbax.checkpoint.experimental.v1._src.serialization.compatibility.CompatibleTypeHandler'>. Using the a default implementation to measure value memory consumption that may result in inaccurate estimation.
WARNING:absl:[process=0][thread=Thread-47 (_event_loop_runner)] Skipping merge of OCDBT checkpoints: No per-process OCDBT checkpoint subdirs found in /tmp/partial-saving/inference_model/ckpt.partial_save.orbax-checkpoint-tmp/state, 

Verification#

abstract_params = jax.tree.map(
    lambda x: (
        str()
        if isinstance(x, str)
        else ocp.arrays.to_shape_dtype_struct(x)
    ),
    save_params
)
final_ckpt = ocp.load(inference_path, abstract_params)

print("\n--- Verification ---")
print(f"Optimizer state exists in final checkpoint: {'optimizer_state' in final_ckpt}")
print(f"Model version: {final_ckpt['model_version']}")
for layer, weights in final_ckpt['params']['encoder_stack'].items():
    print(f"New {layer}: {weights}")
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/partial-saving/inference_model/ckpt. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
WARNING:absl:`memory_size` is not implemented for `TypeHandler` of type: <class 'orbax.checkpoint.experimental.v1._src.serialization.compatibility.CompatibleTypeHandler'>. Using the a default implementation to measure value memory consumption that may result in inaccurate estimation.
--- Verification ---
Optimizer state exists in final checkpoint: False
Model version: v1.2-finetuned
New layer_0: [1.4163686  0.88737243]
New layer_1: [0.825537  0.8961133]
New layer_2: [1.2545017 0.9375919]
New layer_3: [0.3955825 1.3053663]

In this workflow, we created a new, pruned, and modified checkpoint. The key efficiency gain came from using Partial Restore to load only the params, completely avoiding the memory cost of the massive optimizer_state.

Atomicity Guarantees#

The use of a temporary directory and an atomic rename operation during finalization guarantees safety. If your program crashes mid-save, the original checkpoint (if any) is unharmed, and the temporary directory can be safely deleted.
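Cleaning up after an interrupted session might look like the sketch below. It assumes the staging directory is the final path plus the `.partial_save` suffix, as in the examples in this guide. The helper `discard_partial_session` is hypothetical and not part of Orbax, and other temporary directory names (such as the one visible in the log output above) are not handled.

```python
import shutil
import tempfile
from pathlib import Path


def discard_partial_session(path: Path, suffix: str = ".partial_save") -> bool:
    """Delete the staging directory for `path`, if one was left behind.

    Returns True if a staging directory was removed, False otherwise.
    The checkpoint at `path` itself is never touched.
    """
    staging = path.with_name(path.name + suffix)
    if staging.is_dir():
        shutil.rmtree(staging)
        return True
    return False


# Simulate a crash: an original checkpoint plus a leftover staging directory.
root = Path(tempfile.mkdtemp())
original = root / "ckpt"
original.mkdir()
(original / "data.txt").write_text("original")
staging = root / "ckpt.partial_save"
staging.mkdir()
(staging / "update.txt").write_text("staged")

removed = discard_partial_session(original)        # staging dir existed
removed_again = discard_partial_session(original)  # nothing left to remove
```

The original checkpoint directory is still intact afterwards, and a new partial save session can start from a clean state.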