Checkpointing and Exporting JAX Models: An End-to-End Guide with Orbax#

This guide demonstrates a complete, end-to-end workflow for managing JAX models with the Orbax library, from robust training-time checkpointing to final model export. We simulate a Flax/Optax setup to show how the Checkpointer API enables policy-based management and restoration of training states. We then use the standalone save function (called as ocp.save in the code) to save the final parameters for inference. Finally, we export these parameters into a TensorFlow SavedModel with orbax-export.

1. Setup#

First, we set up the necessary environment by installing the required packages and importing the modules we’ll use throughout this guide.

Note: The following cells install the packages required for this guide. If you are running this within an internal Google environment where these dependencies are already available, these installation steps can be safely skipped.

Installation#

Install orbax-checkpoint for core checkpointing, flax and optax for the JAX model and optimizer, and orbax-export with tensorflow for exporting to the SavedModel format.

!pip install orbax-checkpoint flax optax
Requirement already satisfied: orbax-checkpoint in /home/docs/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.12/site-packages (0.11.40)
Requirement already satisfied: flax in /home/docs/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.12/site-packages (0.12.7)
Requirement already satisfied: optax in /home/docs/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.12/site-packages (0.2.8)
!pip install orbax-export tensorflow
Requirement already satisfied: orbax-export in /home/docs/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.12/site-packages (0.0.8)
Requirement already satisfied: tensorflow in /home/docs/checkouts/readthedocs.org/user_builds/orbax/envs/latest/lib/python3.12/site-packages (2.20.0rc0)

Imports#

from orbax.checkpoint import v1 as ocp
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
import optax
import os
import shutil
from etils import epath
from jax import tree_util

Helper for Directory Management#

A utility function to ensure a clean state for our checkpointing directories during each run of this tutorial.

def cleanup_directory_if_exists(path_str):
    """Removes a directory if it exists."""
    path = epath.Path(path_str)
    if path.exists():
        shutil.rmtree(path)

tutorial_base_dir = epath.Path('/tmp/orbax_tutorial')
cleanup_directory_if_exists(str(tutorial_base_dir))
tutorial_base_dir.mkdir(parents=True, exist_ok=True)
print(f"Tutorial artifacts will be saved under: {tutorial_base_dir}")
Tutorial artifacts will be saved under: /tmp/orbax_tutorial

2. Define a Simulated JAX State#

We’ll construct a PyTree representing our model’s training state. This typically includes model parameters, optimizer state, and the current training step.

Define a Model and Training State#

We will define a basic Flax model, initialize its parameters, and create an Optax optimizer. The complete training state (model parameters, optimizer state, and step count) is stored in a Python dictionary. Sharding is applied to array elements using jax.device_put.

# Model Hyperparameters
input_dim = 64
hidden_dim = 32
output_dim = 10
batch_size_for_init = 4

class SimpleFlaxModel(nn.Module):
    hidden_dim: int
    output_dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim, name="d1")(x)
        x = nn.relu(x)
        return nn.Dense(features=self.output_dim, name="d2")(x)

key = jax.random.PRNGKey(0)
model_instance = SimpleFlaxModel(hidden_dim, output_dim)

# Initialize model parameters with dummy data.
dummy_input_for_flax_init = jnp.ones((batch_size_for_init, input_dim))
initial_model_params_template = model_instance.init(key, dummy_input_for_flax_init)['params']
np_params = jax.tree_util.tree_map(np.array, initial_model_params_template)

# Initialize the optimizer state.
optimizer_instance = optax.adam(1e-3)
np_opt_state_template = optimizer_instance.init(initial_model_params_template)
# Convert array-like leaves to NumPy arrays; leaves without a `shape` attribute are left as-is.
np_opt_state = jax.tree_util.tree_map(lambda x: np.array(x) if hasattr(x, 'shape') else x, np_opt_state_template)

# Define sharding for the model (replicated across all devices).
mesh = jax.sharding.Mesh(jax.devices(), ('data',))
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

# Group components and apply sharding to all NumPy arrays in the PyTree.
pytree_components_np = {
    'params': np_params,
    'opt_state': np_opt_state,
}
pytree_components_jax = jax.tree_util.tree_map(lambda x: jax.device_put(x, replicated_sharding) if isinstance(x, np.ndarray) else x, pytree_components_np)

# Combine everything into the final training state PyTree.
simulated_train_state = {**pytree_components_jax, 'step': 0}
print("Initialized JAX training state PyTree with explicit sharding.")
Initialized JAX training state PyTree with explicit sharding.
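
The conditional conversions above (transform array leaves, leave everything else untouched) can be illustrated without JAX. The following is a minimal sketch, assuming a hypothetical `map_leaves` helper that stands in for jax.tree_util.tree_map on nested dictionaries and tuples. It is not how JAX implements PyTrees, and plain Python lists stand in for arrays.

```python
def map_leaves(fn, tree):
    """Applies fn to every leaf of a nested dict/tuple structure (hypothetical helper)."""
    if isinstance(tree, dict):
        return {k: map_leaves(fn, v) for k, v in tree.items()}
    if isinstance(tree, tuple):
        return tuple(map_leaves(fn, v) for v in tree)
    return fn(tree)

# A toy training state: lists stand in for arrays, the int is a non-array leaf.
toy_state = {
    'params': {'d1': {'kernel': [1.0, 2.0], 'bias': [0.0]}},
    'opt_state': ({'mu': [0.0, 0.0]},),
    'step': 0,
}

# Transform only list leaves, mirroring the `isinstance(x, np.ndarray)` guard above.
scaled = map_leaves(
    lambda x: [2 * v for v in x] if isinstance(x, list) else x,
    toy_state,
)
print(scaled['params']['d1']['kernel'])  # [2.0, 4.0]
print(scaled['step'])  # 0
```

The guard inside the lambda is what keeps non-array leaves such as the step counter unchanged while the structure of the tree is preserved.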

3. Orbax Checkpointing Workflow#

This section covers managing checkpoints during a simulated training loop using Checkpointer. This API is designed for common training scenarios and allows for powerful configuration through save policies. For a more comprehensive introduction to the Checkpointer API, refer to the Orbax Checkpoint 101 guide.

Create a Checkpointing Directory#

We’ll create a dedicated directory to store our training checkpoints and define a constant for our save interval.

training_ckpt_dir = tutorial_base_dir / 'simulated_training_ckpts'
cleanup_directory_if_exists(str(training_ckpt_dir))
training_ckpt_dir.mkdir(parents=True, exist_ok=True)

SAVE_INTERVAL_STEPS = 2

Checkpointing During a Simulated Training Loop#

We use Checkpointer as a context manager and configure it with a FixedIntervalPolicy. Inside the loop, save(...) is called on every step, but the policy ensures that a checkpoint is written to disk only when its condition is met (here, step % 2 == 0).

# A simplified function to simulate a single training step.
def train_step_for_loop(state):
  new_state = state.copy() # Work with a mutable copy of the state dict.
  new_state['step'] += 1
  # For this demo, we simulate param changes by adding small random noise.
  key_for_noise = jax.random.PRNGKey(state['step'])
  new_state['params'] = jax.tree_util.tree_map(
        lambda p: p + 0.001 * jax.random.normal(key_for_noise, p.shape, p.dtype),
        state['params']
    )
  return new_state

current_loop_state = tree_util.tree_map(lambda x: x, simulated_train_state) # Start with a fresh copy.
num_training_steps = 7

print(f"Simulating {num_training_steps} training steps...")

with ocp.training.Checkpointer(
    directory=str(training_ckpt_dir),
    save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(SAVE_INTERVAL_STEPS)
) as ckptr:
    for _ in range(num_training_steps):
        step_to_save_at = current_loop_state['step']

        # `save` takes the current step, the state to save, and optional metrics.
        saved = ckptr.save(step_to_save_at, current_loop_state, metrics={'accuracy': 0.85})

        if saved: # Will be True if the save_decision_policy decided to save.
            print(f"  Saved checkpoint for step {step_to_save_at}...")

        current_loop_state = train_step_for_loop(current_loop_state)
Simulating 7 training steps...
  Saved checkpoint for step 0...
  Saved checkpoint for step 2...
  Saved checkpoint for step 4...
  Saved checkpoint for step 6...
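
The saved steps in the output above follow directly from the interval rule. As a minimal sketch, assuming a hypothetical `should_save` function that stands in for the policy's decision (it is not Orbax's implementation), the same list of steps can be reproduced in plain Python:

```python
def should_save(step: int, interval: int) -> bool:
    """Hypothetical stand-in for a fixed-interval save decision."""
    return step % interval == 0

interval = 2
num_training_steps = 7

# The loop above calls save with steps 0..6, each time before the simulated update.
saved_steps = [s for s in range(num_training_steps) if should_save(s, interval)]
print(saved_steps)  # [0, 2, 4, 6]
```

Because save is requested before each simulated update, step 0 (the initial state) is included and step 7 is never offered to the policy.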

Resuming from a Checkpoint#

To resume training, we call load on a Checkpointer. Orbax can automatically find the latest completed checkpoint. We provide an abstract_state (an empty or example version of our state) to guide the restoration process and ensure the data is loaded with the correct structure and sharding.

with ocp.training.Checkpointer(directory=str(training_ckpt_dir)) as ckptr:
    print(f"Restore from the latest checkpoint in {training_ckpt_dir}...")

    # It returns None if no checkpoint is found.
    resumed_train_state = ckptr.load(
        abstract_state=simulated_train_state # Provide an abstract state for structure and sharding.
    )

# If a checkpoint was successfully loaded, resumed_train_state will not be None.
if resumed_train_state is not None:
    print(f"Restored state successfully. Resuming from step: {resumed_train_state['step']}")
    with ocp.training.Checkpointer(directory=str(training_ckpt_dir)) as ckptr:
        assert resumed_train_state['step'] == ckptr.latest.step
else:
    # If no checkpoint was found, fall back to the initial state.
    print("No checkpoint found to restore; using initial state.")
    resumed_train_state = simulated_train_state
Restore from the latest checkpoint in /tmp/orbax_tutorial/simulated_training_ckpts...
Restored state successfully. Resuming from step: 6
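
The "find the latest checkpoint, otherwise fall back" pattern can be sketched independently of Orbax. The example below is a simplified illustration that assumes each checkpoint lives in a subdirectory named after its step number. That naming scheme and the `latest_step` helper are assumptions made for this sketch; Orbax's actual on-disk layout and its handling of incomplete checkpoints may differ.

```python
import tempfile
from pathlib import Path
from typing import Optional

def latest_step(ckpt_dir: Path) -> Optional[int]:
    """Returns the largest integer-named subdirectory, or None if there is none."""
    steps = [int(p.name) for p in ckpt_dir.iterdir() if p.is_dir() and p.name.isdigit()]
    return max(steps) if steps else None

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    empty_result = latest_step(root)  # No checkpoints yet.
    for step in (0, 2, 4, 6):
        (root / str(step)).mkdir()
    (root / 'notes').mkdir()  # Non-numeric names are ignored.
    found = latest_step(root)

print(empty_result)  # None
print(found)  # 6
```

Returning None for an empty directory lets the caller decide what to do, just as the cell above falls back to the initial state.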

4. Saving Final JAX Parameters for Export#

After training, you often need to save just the final model parameters for inference or export. For this, Orbax provides the standalone save function (called as ocp.save below), which is well suited to one-off saves without the overhead of training policies. See the Checkpointing PyTrees guide for more details on this lower-level API.

Extract Final Parameters for Saving#

We extract the learned parameters from our final training state, as this is the only part we need for inference.

final_params_save_dir = tutorial_base_dir / 'exported_model_params_orbax'
final_model_params_to_save = current_loop_state['params']
print("Final model parameters extracted for saving.")
Final model parameters extracted for saving.

Using save for the Final Save#

save writes the given PyTree directly to the specified directory. It’s a straightforward way to persist the final artifacts of a training process.

# Ensure a clean state by removing the directory if it exists from a previous run.
cleanup_directory_if_exists(str(final_params_save_dir))

print(f"Saving final parameters to: {final_params_save_dir}...")
ocp.save(
    path=final_params_save_dir,
    state=final_model_params_to_save,
    overwrite=True  # Overwrite any existing checkpoint in this directory.
)
print("Final model parameters saved via `save`.")
Saving final parameters to: /tmp/orbax_tutorial/exported_model_params_orbax...
Final model parameters saved via `save`.

Loading Exported Parameters (Verification)#

We can use load to read the parameters back and verify that the save operation was successful. Again, we pass an abstract_state to guide the restoration.

if final_params_save_dir.exists() and len(os.listdir(str(final_params_save_dir))) > 0:
    print(f"Loading parameters from {final_params_save_dir} for verification...")
    loaded_final_params = ocp.load(
        final_params_save_dir,
        abstract_state=final_model_params_to_save # Use instance as a template for structure and sharding.
    )
    # Check that the loaded parameters match the original ones.
    params_match = jax.tree_util.tree_all(
        jax.tree_util.tree_map(jnp.array_equal, final_model_params_to_save, loaded_final_params)
    )
    print(f"Verification: {'PASSED' if params_match else 'FAILED'}")
else:
    print("Saved parameters directory not found or empty. Skipping verification.")
WARNING:absl:TensorStore data files not found in checkpoint path /tmp/orbax_tutorial/exported_model_params_orbax. This may be a sign of a malformed checkpoint, unless your checkpoint consists entirely of strings or other non-standard PyTree leaves.
Loading parameters from /tmp/orbax_tutorial/exported_model_params_orbax for verification...
Verification: PASSED
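
When a round-trip check like the one above fails, it helps to know which leaf differs rather than only seeing FAILED. The following is a minimal sketch in plain Python, assuming a hypothetical `mismatched_paths` helper and using nested dictionaries of lists as stand-ins for parameter trees. It is an illustration, not part of Orbax or JAX.

```python
def mismatched_paths(a, b, prefix=()):
    """Returns the key paths at which two nested dicts differ (hypothetical helper)."""
    if isinstance(a, dict) and isinstance(b, dict):
        paths = []
        for key in sorted(set(a) | set(b)):
            if key not in a or key not in b:
                paths.append(prefix + (key,))  # Key missing on one side.
            else:
                paths.extend(mismatched_paths(a[key], b[key], prefix + (key,)))
        return paths
    return [] if a == b else [prefix]

original = {'d1': {'kernel': [0.1, 0.2], 'bias': [0.0]}, 'd2': {'bias': [1.0]}}
round_tripped = {'d1': {'kernel': [0.1, 0.2], 'bias': [0.0]}, 'd2': {'bias': [1.0]}}
perturbed = {'d1': {'kernel': [0.1, 0.25], 'bias': [0.0]}, 'd2': {'bias': [1.0]}}

print(mismatched_paths(original, round_tripped))  # []
print(mismatched_paths(original, perturbed))  # [('d1', 'kernel')]
```

An empty list corresponds to a passing verification, and any returned path points at the parameter to inspect.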

5. Exporting to TensorFlow SavedModel#

This section demonstrates converting the saved JAX model parameters into the TensorFlow SavedModel format using the orbax-export library. This is a common step when a JAX model needs to be served through tooling that consumes SavedModels.

from orbax.export import ExportManager, JaxModule, ServingConfig
from orbax.export.validate.validation_manager import ValidationManager
import tensorflow as tf
import traceback

Define JAX Model Apply Function and Pre/Post-processing for Export#

For orbax-export, we need to provide a JAX function that takes (params, inputs). We can also define TensorFlow-based pre-processing and post-processing functions, which will be included in the SavedModel’s computation graph.

# `model_instance` was defined in Section 2 (the SimpleFlaxModel instance).
# `final_model_params_to_save` contains the parameters we want to export from Section 4.

# JAX Apply Function: The core JAX logic for the model's forward pass.
@jax.jit
def jax_model_apply_fn_for_export(params, inputs):
  """A JAX function with the signature (params, inputs) for orbax-export."""
  return model_instance.apply({'params': params}, inputs)


# Optional: TF Pre-processing Function.
def tf_preprocess_fn_for_export(input_tensor: tf.Tensor) -> tf.Tensor:
  """Normalizes the raw input tensor. Orbax-export will trace this into a graph."""
  return tf.cast(input_tensor, tf.float32) / 255.0


# Optional: TF Post-processing Function.
def tf_postprocess_fn_for_export(output_tensor: tf.Tensor) -> dict[str, tf.Tensor]:
  """Packages the model output into a dictionary. Orbax-export will trace this."""
  return {'predictions': output_tensor}

print("JAX apply function and plain TF pre/post-processing functions defined for export.")
JAX apply function and plain TF pre/post-processing functions defined for export.
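
Conceptually, the three functions above form a pipeline: raw inputs are pre-processed, passed through the model, and the result is post-processed. The sketch below shows that ordering in plain Python. The `toy_apply` model, `toy_params`, and the `serve` wrapper are hypothetical stand-ins invented for illustration; they are not part of orbax-export, which builds the equivalent chain inside the SavedModel graph.

```python
def preprocess(raw_rows):
    """Scales raw values from [0, 255] to [0, 1], like the TF preprocessor above."""
    return [[float(v) / 255.0 for v in row] for row in raw_rows]

def toy_apply(params, rows):
    """Hypothetical stand-in for the model: a weighted sum per row plus a bias."""
    return [sum(w * v for w, v in zip(params['w'], row)) + params['b'] for row in rows]

def postprocess(outputs):
    """Packages outputs in a dictionary, like the TF postprocessor above."""
    return {'predictions': outputs}

def serve(params, raw_rows):
    """Chains the three stages: preprocess, then apply, then postprocess."""
    return postprocess(toy_apply(params, preprocess(raw_rows)))

toy_params = {'w': [1.0, 1.0], 'b': 0.5}
result = serve(toy_params, [[255, 0], [0, 0]])
print(result)  # {'predictions': [1.5, 0.5]}
```

Keeping normalization inside the exported pipeline means callers can send raw inputs and do not need to replicate the scaling themselves.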

Create JaxModule and ServingConfig#

JaxModule wraps the JAX function and its parameters. ServingConfig defines the input signature for the SavedModel and specifies which pre/post-processing functions to use for a given serving signature key (e.g., serving_default).

# Create the JaxModule, which encapsulates the JAX function and its parameters.
jax_module_for_export = JaxModule(
    params=final_model_params_to_save,
    apply_fn=jax_model_apply_fn_for_export,
    input_polymorphic_shape=f'(b, {input_dim})',
    jax2tf_kwargs={'with_gradient': False, 'native_serialization': False}
)

# This tells orbax-export how to trace the Python preprocessor function.
tf_input_signature = [
    tf.TensorSpec(shape=[None, input_dim], dtype=tf.float32)
]

# Create a serving configuration that bundles the signature key, input specs,
# and our Python processing functions.
serving_config = ServingConfig(
    signature_key='serving_default',
    input_signature=tf_input_signature,
    tf_preprocessor=tf_preprocess_fn_for_export,
    tf_postprocessor=tf_postprocess_fn_for_export
)
print("JaxModule and ServingConfig created successfully.")
JaxModule and ServingConfig created successfully.

Export to TensorFlow SavedModel#

The ExportManager takes the JaxModule and a list of ServingConfig objects to build and save the final TensorFlow SavedModel.

# Define the directory to save the final exported model.
saved_model_dir = tutorial_base_dir / 'tf_saved_model_orbax_export'
cleanup_directory_if_exists(str(saved_model_dir))
saved_model_dir.mkdir(parents=True, exist_ok=True)

# The ExportManager orchestrates the JAX-to-TF conversion and saving process.
export_manager = ExportManager(jax_module_for_export, [serving_config])
print(f"Exporting SavedModel to: {saved_model_dir}")
try:
    export_manager.save(str(saved_model_dir))
    print("Model exported successfully to SavedModel format.")
    print(f"Contents of {saved_model_dir}: {os.listdir(str(saved_model_dir))}")
except Exception as e:
    print(f"ERROR during SavedModel export: {e}")
    # `traceback` was imported at the top of this section.
    traceback.print_exc()
Exporting SavedModel to: /tmp/orbax_tutorial/tf_saved_model_orbax_export
Model exported successfully to SavedModel format.
Contents of /tmp/orbax_tutorial/tf_saved_model_orbax_export: ['fingerprint.pb', 'assets', 'saved_model.pb', 'variables']

Validate the Exported Model#

A critical final step is to verify that the exported TensorFlow model produces the same results as the original JAX model. We use the ValidationManager, which compares the outputs of the JAX model and the loaded TF SavedModel for a given batch of inputs and generates a detailed report.

# Prepare a batch of test inputs. These should be "raw" (pre-preprocessing).
validation_batch_size = 4
raw_validation_inputs = np.random.rand(validation_batch_size, input_dim).astype(np.float32) * 255.0

# To match the positional signature, pass the inputs as a list of lists.
validation_mgr = ValidationManager(
    module=jax_module_for_export,
    serving_configs=[serving_config],
    model_inputs=[[raw_validation_inputs]]
)

# Load the candidate model we want to validate.
loaded_tf_model = tf.saved_model.load(str(saved_model_dir))

# Run the validation, which compares the JAX and TF outputs.
print("\nRunning validation...")
validation_reports = validation_mgr.validate(loaded_tf_model)

# Check the report. The report is a dict keyed by the signature_key.
report = validation_reports['serving_default']

# The report status is an enum. We check its string name for a simple pass/fail result.
if report.status.name == 'Pass':
    print(f"VERIFICATION PASSED! Status: {report.status.name}")
else:
    print(f"VERIFICATION FAILED! Status: {report.status.name}")

# The report can be printed as a JSON string for detailed inspection of differences and latencies.
print("\nValidation Report:")
print(report.to_json(indent=2))
Running validation...
VERIFICATION PASSED! Status: Pass

Validation Report:
{
  "outputs": {
    "FloatingPointDiffReport": {
      "total": 40,
      "max_diff": 0.0,
      "max_rel_diff": 0.0,
      "all_close": true,
      "all_close_absolute_tolerance": 1e-07,
      "all_close_relative_tolerance": 1e-07
    },
    "NonFloatingPointDiffReport": {
      "total_flattened_tensors": 0,
      "mismatches": 0,
      "mismatch_ratio": 0.0,
      "max_non_floating_mismatch_ratio": 0.01
    }
  },
  "latency": {
    "baseline": {
      "num_batches": 1,
      "avg_in_ms": 1.2004375457763672,
      "p90_in_ms": 1.2004375457763672,
      "p99_in_ms": 1.2004375457763672
    },
    "candidate": {
      "num_batches": 1,
      "avg_in_ms": 2.152681350708008,
      "p90_in_ms": 2.152681350708008,
      "p99_in_ms": 2.152681350708008
    }
  },
  "xprof_url": {
    "baseline": "N/A",
    "candidate": "N/A"
  },
  "metadata": {
    "baseline": {},
    "candidate": {}
  },
  "status": 1
}
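
The floating-point section of the report can be read as an element-wise comparison under absolute and relative tolerances. The total of 40 is consistent with 4 validation rows times 10 model outputs. The sketch below is a simplified illustration, assuming a hypothetical `diff_report` helper and the closeness rule used by numpy.allclose (|baseline - candidate| <= atol + rtol * |baseline|). Whether ValidationManager applies exactly this rule is an assumption.

```python
def diff_report(baseline, candidate, atol=1e-7, rtol=1e-7):
    """Summarizes element-wise differences between two flat lists of floats."""
    max_diff = 0.0
    max_rel_diff = 0.0
    all_close = True
    for b, c in zip(baseline, candidate):
        diff = abs(b - c)
        max_diff = max(max_diff, diff)
        if b != 0:
            max_rel_diff = max(max_rel_diff, diff / abs(b))
        # Assumed closeness rule, following the numpy.allclose convention.
        if diff > atol + rtol * abs(b):
            all_close = False
    return {
        'total': len(baseline),
        'max_diff': max_diff,
        'max_rel_diff': max_rel_diff,
        'all_close': all_close,
    }

report_same = diff_report([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
report_off = diff_report([1.0, 2.0, 3.0], [1.0, 2.0, 3.001])
print(report_same['all_close'])  # True
print(report_off['all_close'])  # False
```

With tolerances as tight as 1e-07, even a difference in the third decimal place is flagged, which is why a max_diff of 0.0 in the report above is a strong signal that the exported model matches the JAX baseline.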