# Orbax Checkpoint Benchmark Framework
Measure the performance (and shape) of Orbax checkpoint operations (save, restore, reshard, broadcast, …) across model sizes, mesh topologies, and option combinations, whether locally, on a multi-host cluster, or on Pathways/cloud.
A benchmark has two parts, and you provide both:

- ① A benchmark class: a small Python class whose `test_fn` contains the code you want to measure.
- ② A YAML config: selects which class to run, supplies its options, and describes the checkpoint(s) and mesh(es) to run against.
Given those, the framework does everything else: it expands every option combination, runs each one (optionally several times), collects a rich metric suite per host, aggregates it across hosts, and writes results to the logs and TensorBoard. It can also capture a baseline and diff later runs against it.

⚠️ The YAML alone is not enough. It always points at a benchmark class; the measured code lives in that class's `test_fn`, not in the config.
## Capabilities at a glance

| Capability | What you get | Where |
|---|---|---|
| Write your own benchmark | A class + `test_fn` | §3 |
| Reuse built-ins | Save/restore, resharding, restore+broadcast, … | §4 |
| One-line metric capture | `with metrics.measure("op"):` | §3 |
| Rich metric suite | time, host & device memory, throughput, per-stage timings, TensorStore I/O, compile-cache | §7 |
| Parameter sweeps | list-valued options × meshes × checkpoints (Cartesian product) | §8 |
| Repeats + cross-host aggregation | `num_repeats`; min/mean/max/… across hosts | §8 |
| Synthetic or real data | generate from a `spec` or load from a `path` | §5 |
| Multi-topology meshes | list several meshes; incompatible ones are auto-skipped | §5 |
| Baselines (A/B) | capture with `--baseline_capture_path`, compare with `--baseline_path` | §8 |
| Profiler traces + HLO | `enable_trace`, `--enable_hlo_dump` | §8 |
| TensorBoard | scalars, HParams, profile traces, inventory cards | §8 |
| Pathways & cloud | auto backend init; colocated-Python load dispatcher; XPK launcher | §6 |
## 1. 60-second tour
The fastest path: reuse the built-in save/restore benchmark.

Step 1: write a config (`sweep.yaml`):
```yaml
suite_name: "ocdbt vs non-ocdbt"
num_repeats: 3
checkpoint_config:  # synthetic data; no real model needed
  spec:
    params: {dtype: bfloat16, shape: [8192, 8192], sharding: [fsdp]}
mesh_config:
  mesh_axes: ["fsdp"]
  ici_parallelism: {"fsdp": 8}
benchmarks:
  - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark"
    options:
      use_ocdbt: [true, false]  # a list, so it is swept
```
Step 2: run it (by file, from the repo root):
```shell
python checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py \
  --config_file=sweep.yaml --output_directory=/tmp/bench/
```
Step 3: read the results. Per-operation timing, throughput, memory, and I/O are printed to the logs and written to TensorBoard:
```shell
tensorboard --logdir=/tmp/bench/tensorboard/
```
Two benchmarks ran (`use_ocdbt` true/false), each 3 times, with results aggregated across hosts.
To benchmark something the built-ins don't cover, write your own class (see §3).
## 2. How it works
The pieces (in `orbax/checkpoint/_src/testing/benchmarks/`):
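| Object | Role |
|---|---|
| `BenchmarkOptions` | Frozen dataclass of your knobs. Any list field becomes a sweep axis. |
| `BenchmarksGenerator` | The class you write/reuse. Its `test_fn` holds the measured code; its `generate()` expands options × meshes × checkpoints into benchmarks. |
| `metrics.measure()` | The one-line context manager you wrap measured code in. |
| `TestContext` | What `test_fn` receives: the pytree, path, options, and mesh for this run. |
| `TestResult` | What `test_fn` returns: your `Metrics`. |
| `TestSuite` | Orchestrator built from your YAML: expands, repeats, aggregates, baselines. |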
```text
your YAML ──► TestSuite ──► your BenchmarksGenerator.generate()
                                 │  options × meshes × checkpoints
                                 ▼
               [ Benchmark, … ] ──► Benchmark.run() × num_repeats
                                          │
          ┌──────── your test_fn(TestContext) ────────┐
          │  with metrics.measure("save"): ...        │  ← the code you measure
          │  return TestResult(metrics)               │
          └───────────────────────────────────────────┘
                                          │
     cross-host aggregation ──► logs + TensorBoard (+ optional baseline A/B)
```
Everything outside the box is handled for you. Your job: the test_fn and the
YAML that drives it.
## 3. Writing a benchmark
A benchmark class is three small pieces: an options dataclass, the `@benchmark_options` decorator, and a `BenchmarksGenerator` with a `test_fn`.
### Capturing metrics: the `measure()` block
Wrap any block in `metrics.measure("name")` and the framework captures the whole default metric suite around it (time, host & device memory, compile-cache, TensorStore I/O, and Orbax's per-stage timings/throughput):
```python
with metrics.measure("save"):  # one line captures the whole metric suite
  ocp.save(path, pytree)
```
The name becomes the operation prefix on every metric the block emits (`save_0_basics/time_s`, …). Use one `measure()` per operation you want broken out.

💡 `measure()` is not checkpoint-specific; it times any block. Time, memory, device, and compile-cache metrics populate for anything; the Orbax save/load breakdown tags simply stay empty unless the block does Orbax I/O.

Blocks nest, so you can capture an aggregate and each step at once:
```python
with metrics.measure("train_steps"):  # whole-loop aggregate
  for step in range(options.num_steps):
    with metrics.measure(f"step_{step}"):  # per-step breakdown
      params, opt_state = train_step(params, opt_state, next(data_iter))
```
This yields `train_steps_0_basics/time_s` alongside `step_0_…`, `step_1_…` (the collectors are reentrant-safe, so nesting is fine).
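The naming and nesting behaviour can be illustrated with a toy stand-in. `ToyMetrics` below is hypothetical: it records only wall-clock time, under the `<operation>_0_basics/time_s` tag the framework uses, and says nothing about how the real collectors work internally.

```python
import contextlib
import time


class ToyMetrics:
    """Hypothetical stand-in for Metrics: records only wall-clock time."""

    def __init__(self) -> None:
        self.results: dict[str, float] = {}

    @contextlib.contextmanager
    def measure(self, operation: str):
        # Each block keeps its own start time, so nested blocks don't interfere.
        start = time.perf_counter()
        try:
            yield
        finally:
            # Tag follows the <operation>_<namespace>/<metric> convention.
            self.results[f"{operation}_0_basics/time_s"] = time.perf_counter() - start


metrics = ToyMetrics()
with metrics.measure("train_steps"):  # whole-loop aggregate
    for step in range(2):
        with metrics.measure(f"step_{step}"):  # per-step breakdown
            sum(range(10_000))  # placeholder for a real train step

print(sorted(metrics.results))
# ['step_0_0_basics/time_s', 'step_1_0_basics/time_s', 'train_steps_0_basics/time_s']
```

Because the outer block starts before and ends after every inner block, its time is always at least as large as any single step's.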
### A complete example
```python
import dataclasses

import jax
from orbax.checkpoint import v1 as ocp
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib


# ① Your knobs. A list-typed field is swept (see §8).
@dataclasses.dataclass(frozen=True)
class SaveRestoreOptions(benchmarks_core.BenchmarkOptions):
  use_ocdbt: bool | list[bool] = True

  def is_valid(self) -> bool:  # optional: drop nonsensical combinations
    return True


# ② Bind the options to the generator.
@benchmarks_core.benchmark_options(SaveRestoreOptions)
class SaveRestoreBenchmark(benchmarks_core.BenchmarksGenerator):

  # ③ The measured code. Called once per (option × mesh × checkpoint)
  #    combination, on every repeat.
  def test_fn(
      self, context: benchmarks_core.TestContext
  ) -> benchmarks_core.TestResult:
    metrics = metric_lib.Metrics()
    options = context.options  # a SaveRestoreOptions instance
    pytree = context.pytree  # generated from checkpoint_config.spec
    abstract = jax.tree.map(ocp.arrays.to_shape_dtype_struct, pytree)
    path = context.path / "ckpt"  # a fresh per-run directory

    with ocp.Context(
        array_options=ocp.options.ArrayOptions(
            saving=ocp.options.ArrayOptions.Saving(use_ocdbt=options.use_ocdbt)
        )
    ):
      with metrics.measure("save"):
        ocp.save(path, pytree)
      with metrics.measure("load"):
        _ = ocp.load(path, abstract_state=abstract)

    return benchmarks_core.TestResult(metrics=metrics)
```
That's the whole benchmark. The config (§5) decides `context.pytree` / `context.mesh`, the repeat count, and which `use_ocdbt` values to sweep.
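### `test_fn` input: `TestContext`

| Field | Meaning |
|---|---|
| `pytree` | Checkpoint data the framework generated from `checkpoint_config.spec` (or loaded from `checkpoint_config.path`). |
| `path` | A fresh per-run working directory. |
| `options` | The resolved options for this sweep point. |
| `mesh` | The device mesh built from `mesh_config`. |
| | Which repeat this is (or …). |
| | Profiler-trace directory for operation … |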
### `test_fn` output: `TestResult`
Return `TestResult(metrics=metrics)`. The framework fills in the run path and a checkpoint inventory (bytes/files) automatically. If `test_fn` raises, the error is recorded and the run exits non-zero, so you don't need to catch exceptions yourself.
## 4. Reusing a built-in benchmark
♻️ For common cases, just reference a built-in by its import path in a `generator:` entry under `benchmarks:`.

### `v1.benchmark.Benchmark`: save + restore round-trip
The canonical generator: it saves the pytree, then restores it, measuring `save_blocking`, `save_background`, and `load`. Options (`v1.benchmark.BenchmarkOptions`):
| Option | Default | Meaning |
|---|---|---|
| | | Use … |
| `use_ocdbt` | | Use the OCDBT driver. |
| `use_zarr3` | | Use Zarr v3. |
| | | Compress array data. |
| | | Parallelize writes across replicas. |
| | | Separate folder per replica (needs …). |
| | | Primary host loads, then broadcasts. |
| | | Array chunk size. |
| | | Concurrency budgets (GiB). |
| | | Add the … |
| `enable_trace` | | Profiler traces (§8). |
### Other built-ins

| Generator | Measures |
|---|---|
| `v1.resharding_benchmark` (`ReshardingBenchmark`) | Loading an existing checkpoint into a target sharding (see recipe C in §9). |
| `v1.restore_and_broadcast_benchmark` | Restore-on-one-replica-then-broadcast. |
| `v1.replica_parallel_multislice_benchmark` | Replica-parallel saving across slices. |

Other generators in this directory (`checkpoint_manager_benchmark`, `array_handler_benchmark`, `pytree_checkpoint_benchmark`, `emergency_checkpoint_manager_benchmark`, `single_replica_benchmark`, `lustre_benchmark`, `pytorch_checkpoint_benchmark`, …) target specific subsystems and follow the same pattern.
## 5. The config file
A single YAML file that selects and parametrizes your benchmark class.

### Top-level keys
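| Key | Required | Default | Meaning |
|---|---|---|---|
| `suite_name` | yes | n/a | Human-readable run name. |
| `num_repeats` | no | | Times to run each generated benchmark. |
| `checkpoint_config` / `checkpoint_configs` | no | one empty config | Checkpoint(s) to save/load. The plural form is swept. |
| `mesh_config` / `mesh_configs` | no | none | Device mesh(es). The plural form is swept. Omitted → … |
| `benchmarks` | yes | n/a | List of `{generator, options}` entries. |
| | no | none | Write captured baseline JSON here (overridden by `--baseline_capture_path`). |
| | no | none | Compare against this stored baseline (overridden by `--baseline_path`). |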
### `benchmarks`
```yaml
benchmarks:
  - generator: "my.module.SaveRestoreBenchmark"  # import path (built-in or yours)
    options:
      use_ocdbt: [true, false]  # any list is swept
```
### `checkpoint_config`
Describes `context.pytree`, which is either generated from a `spec` or loaded from a `path` (set exactly one).
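| Field | Default | Meaning |
|---|---|---|
| `spec` | | Synthetic pytree: leaf name → `{dtype, shape, sharding}`, as in the §1 example. |
| `path` | | Load an existing checkpoint. Mutually exclusive with `spec`. |
| | | Seed for synthetic generation (deterministic). |
| `sharding_config_path` | | Per-tensor target sharding JSON, used with `path`. |
| `load_with_colocated_python` | | On Pathways, load via the colocated-Python dispatcher (§6). |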
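As a sketch of the rules in this table, assuming a `spec` maps leaf names to `{dtype, shape, sharding}` dicts as in the §1 config: the config accepts exactly one of `spec` or `path`, and a synthetic leaf's size follows from its shape and dtype. `ToyCheckpointConfig`, `ITEMSIZE`, and `leaf_bytes` are hypothetical helpers, not the framework's generator.

```python
from __future__ import annotations

import dataclasses
import math

# Bytes per element for a few dtypes (an illustrative subset, not exhaustive).
ITEMSIZE = {"bfloat16": 2, "float16": 2, "float32": 4, "int32": 4, "float64": 8}


@dataclasses.dataclass(frozen=True)
class ToyCheckpointConfig:
    """Hypothetical mirror of checkpoint_config: exactly one of spec / path."""

    spec: dict | None = None
    path: str | None = None

    def __post_init__(self):
        if (self.spec is None) == (self.path is None):
            raise ValueError("set exactly one of `spec` or `path`")


def leaf_bytes(leaf: dict) -> int:
    """Total bytes of one synthetic leaf described by {dtype, shape, sharding}."""
    return math.prod(leaf["shape"]) * ITEMSIZE[leaf["dtype"]]


cfg = ToyCheckpointConfig(
    spec={"params": {"dtype": "bfloat16", "shape": [8192, 8192], "sharding": ["fsdp"]}}
)
total = sum(leaf_bytes(leaf) for leaf in cfg.spec.values())
print(total)  # 134217728 bytes, i.e. 128 MiB
```

On the `fsdp: 8` mesh from §1, each device would then hold one eighth of that leaf (16 MiB), assuming `sharding: [fsdp]` splits the first dimension across the `fsdp` axis.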
### `mesh_config`
Translated into `context.mesh`.

| Field | Default | Meaning |
|---|---|---|
| `mesh_axes` | n/a | Parallelism dimension names, e.g. `["fsdp"]`. |
| `ici_parallelism` | | Per-axis degree within a slice, e.g. `{"fsdp": 8}`. |
| | | Per-axis degree across slices (multi-slice). |
| | | Allow splitting physical axes. |
| | | Treat processes as the outer-network unit. |
💡 The product of all axis degrees must equal the device count, or the mesh is skipped. List several meshes and only the ones that fit the live hardware will run.
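The skip rule amounts to a product check. This sketch assumes the across-slice degrees arrive as a second axis-to-degree mapping; `dcn_parallelism` is a placeholder name, since the table above does not name that field, and `mesh_fits` is a hypothetical helper.

```python
from __future__ import annotations

import math


def mesh_fits(ici_parallelism: dict[str, int],
              dcn_parallelism: dict[str, int] | None,
              device_count: int) -> bool:
    """True if the product of every axis degree equals the device count."""
    degrees = [*ici_parallelism.values(), *(dcn_parallelism or {}).values()]
    return math.prod(degrees) == device_count


# The meshes from the §1 tour and recipe C, checked against 8 devices.
candidates = [{"fsdp": 8}, {"data": 1, "fsdp": 16, "tensor": 1}]
runnable = [m for m in candidates if mesh_fits(m, None, device_count=8)]
print(runnable)  # [{'fsdp': 8}]
```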
### Running it
Run the `run_benchmarks.py` script directly (paths relative to the repo root):
```shell
python checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py \
  --config_file=<config.yaml> \
  --output_directory=<dir> \
  [flags...]
```
| Flag | Required | Meaning |
|---|---|---|
| `--config_file` | yes | Path to the YAML config. |
| `--output_directory` | yes | Where results, TensorBoard logs, and traces go. |
| | no | Local scratch directory (some checkpoint-manager benchmarks). |
| `--enable_hlo_dump` | no | Dump XLA HLO protos to `<output_directory>/hlo_dump/`. |
| `--remove_repeated_dir` | no | Delete the generated per-repeat checkpoint directories. |
| `--baseline_capture_path` | no | Write each benchmark's baseline JSON (cross-host aggregated) here. |
| `--baseline_path` | no | Compare against the stored baseline and log the per-metric delta. |
The runner enables jax_enable_x64. On a single process it runs locally; on
CPU, simulate devices with
XLA_FLAGS=--xla_force_host_platform_device_count=8.
## 6. Running on Pathways and cloud
The same config and the same benchmark class run unchanged on a local host, a multi-process cluster, or the Pathways single-controller backend; the framework adapts automatically.

- **Automatic backend init.** `run_benchmarks.py` detects the active backend: in a multi-process JAX cluster it calls `jax.distributed.initialize()` from the standard env vars (`JAX_COORDINATOR_ADDRESS`, `JAX_PROCESS_ID`, `JAX_NUM_PROCESSES`, `JAX_COORDINATOR_PORT`); if Pathways is in use (`pathwaysutils.is_pathways_backend_used()`), it initializes Pathways instead. No change to your config or class is needed.
- **Colocated-Python load dispatcher.** On Pathways, the framework's checkpoint loader runs through Orbax's colocated-Python path: set `load_with_colocated_python: true` in `checkpoint_config` (with a `path`). The loader (`checkpoint_generation.load_checkpoint`) then builds the Pathways `CheckpointingImpl` with colocated Python and registers the colocated type handlers, so deserialization runs colocated with the TPU workers, dispatched from the single controller. This is the production Pathways load path, measured end to end.

  ```yaml
  checkpoint_config:
    path: "gs://my-bucket/ckpt/items"
    sharding_config_path: "gs://my-bucket/sharding/abstract_state.json"
    load_with_colocated_python: true  # colocated-Python dispatcher on Pathways
  ```

- **Cloud benchmarking via XPK.** The `xpk/` launcher runs a suite on a GKE/Pathways cluster: `xpk/launch_xpk.py --enable_pathways` provisions the Pathways server / proxy / colocated-Python sidecar images and runs `run_benchmarks` on the cluster (the sidecar executes the colocated-Python code). For the end-to-end setup see `xpk/PathwaysColocatedPythonGuide.md`; for the launchers see `xpk/README.md` (GKE/XPK) and `tpu_vm/README.md` (single TPU VM).
## 7. Metrics reference
Every result from a `measure("<operation>", metric_keys)` block is a TensorBoard scalar named:
```text
<operation>_<namespace>/<metric>      e.g. load_0_basics/time_s
```
The <operation> prefix is the string you passed to measure().
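When post-processing exported scalars, the tag can be split back into its parts. A sketch, assuming every namespace has the `<digits>_<lowercase name>` form used in the namespace table below; because operation names may contain underscores (`save_blocking`, `step_0`), the regex pins the namespace to the segment just before `/` and lets the operation take everything in front of it. `make_tag` and `parse_tag` are hypothetical helpers.

```python
import re

# <operation>_<namespace>/<metric>, where a namespace looks like "<digits>_<name>".
TAG_RE = re.compile(r"^(?P<op>.+)_(?P<ns>\d+_[a-z_]+)/(?P<metric>.+)$")


def make_tag(operation: str, namespace: str, metric: str) -> str:
    """Build a scalar tag from its three parts."""
    return f"{operation}_{namespace}/{metric}"


def parse_tag(tag: str) -> tuple[str, str, str]:
    """Split a scalar tag into (operation, namespace, metric)."""
    match = TAG_RE.match(tag)
    if match is None:
        raise ValueError(f"not a metric tag: {tag!r}")
    return match["op"], match["ns"], match["metric"]


print(make_tag("load", "0_basics", "time_s"))  # load_0_basics/time_s
print(parse_tag("step_0_0_basics/time_s"))     # ('step_0', '0_basics', 'time_s')
print(parse_tag("save_blocking_7_memory/device_hbm_peak_diff_gb"))
# ('save_blocking', '7_memory', 'device_hbm_peak_diff_gb')
```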
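### Collectors
`measure()` with no `metric_keys` uses the defaults: `time`, `rss`, `jax_monitoring`, `device_memory`, `tensorstore`. `tracemalloc` is opt-in.

| Key | Records |
|---|---|
| `time` | Wall-clock duration of the block. |
| `rss` | Host RSS memory delta. |
| `jax_monitoring` | `jax.monitoring` events: Orbax's per-stage save/load timings and throughput, plus compile-cache statistics. |
| `device_memory` | Device (HBM) memory usage, e.g. the peak delta. |
| `tensorstore` | TensorStore kvstore op counts, cache hit/miss, tcmalloc deltas (whitelisted). |
| `tracemalloc` | Python allocation peak + top sites (opt-in, via `metric_keys`). |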
Scope it explicitly with a list:
metrics.measure("load", ["time", "device_memory"]).
### Namespaces
Results group into ordered namespaces so the dashboard reads top-to-bottom:

| Namespace | Source | Representative metrics (units) |
|---|---|---|
| `0_basics` | time, rss | `time_s` (s) |
| `2_save_breakdown` | jax_monitoring | |
| `3_load_breakdown` | jax_monitoring | |
| `4_throughput` | jax_monitoring | `*_gbps` |
| `5_…` | jax_monitoring | |
| `6_tensorstore` | tensorstore | |
| `7_memory` | device_memory, tracemalloc | `device_hbm_peak_diff_gb` (GB) |
| `8_jax` | jax_monitoring | `cache_hit_rate` |

`0_basics`, `6_tensorstore`, and `7_memory` populate for any measured block; the `2_`–`5_` and `8_` namespaces come from Orbax's `jax.monitoring` events, so they populate when the block performs Orbax I/O. The collectors live in `core/metric.py`; the event → tag map is in `core/jax_monitoring_tags.py` (unmapped events still surface).
### Headline metrics
🎯 The five to watch:

- `<op>_4_throughput/*_gbps`: effective bandwidth (the primary perf signal).
- `<op>_0_basics/time_s`: wall-clock time per operation.
- `<op>_2_save_breakdown/*` and `<op>_3_load_breakdown/*`: where the time went.
- `<op>_7_memory/device_hbm_peak_diff_gb`: peak device-memory cost.
- `<op>_8_jax/cache_hit_rate`: compilation-cache effectiveness.
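As a back-of-the-envelope check on these numbers, effective bandwidth is bytes moved divided by wall-clock time. This sketch assumes decimal gigabytes (10^9 bytes) and a hit rate of hits / (hits + misses); the framework's exact definitions of `*_gbps` and `cache_hit_rate` may differ, so treat it as a sanity check rather than a re-implementation. Both functions are hypothetical helpers.

```python
def effective_gbps(num_bytes: int, seconds: float) -> float:
    """Bytes moved per second, in decimal GB/s (assumption: 1 GB = 1e9 bytes)."""
    return num_bytes / seconds / 1e9


def cache_hit_rate(hits: int, misses: int) -> float:
    """Fraction of compilation-cache lookups that hit (0.0 if there were none)."""
    total = hits + misses
    return hits / total if total else 0.0


# The 128 MiB bfloat16 leaf from the §1 tour, saved in a hypothetical 0.5 s:
print(round(effective_gbps(134_217_728, 0.5), 3))  # 0.268
print(cache_hit_rate(hits=3, misses=1))            # 0.75
```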
## 8. Features
### Parameter sweeps
Any options field set to a list is a sweep axis; the benchmark runs the
Cartesian product of all axes. checkpoint_configs / mesh_configs (plural)
sweep too. is_valid() on your options drops invalid combinations.
```yaml
options:
  use_ocdbt: [true, false]
  use_zarr3: [true, false]  # → 4 benchmarks (2 × 2)
```
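The expansion can be sketched in plain Python: each list-valued field contributes an axis, `itertools.product` enumerates the combinations, and `is_valid()` discards the ones that make no sense. `ToyOptions`, its validity rule, and `expand` are hypothetical; they mirror the behaviour described here, not the framework's source.

```python
from __future__ import annotations

import dataclasses
import itertools


@dataclasses.dataclass(frozen=True)
class ToyOptions:
    """Hypothetical options class; list-valued fields are sweep axes."""

    use_ocdbt: bool | list[bool] = True
    use_zarr3: bool | list[bool] = False

    def is_valid(self) -> bool:
        # Hypothetical rule, only to show filtering: zarr3 requires ocdbt.
        return self.use_ocdbt or not self.use_zarr3


def expand(options: ToyOptions) -> list[ToyOptions]:
    """Cartesian product over every list-valued field, filtered by is_valid()."""
    axes = {}
    for field in dataclasses.fields(options):
        value = getattr(options, field.name)
        axes[field.name] = value if isinstance(value, list) else [value]
    points = []
    for values in itertools.product(*axes.values()):
        candidate = ToyOptions(**dict(zip(axes, values)))
        if candidate.is_valid():
            points.append(candidate)
    return points


points = expand(ToyOptions(use_ocdbt=[True, False], use_zarr3=[True, False]))
print(len(points))  # 3: four combinations, one rejected by is_valid()
```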
### Repeats & cross-host aggregation
`num_repeats` runs each benchmark N times; metrics are aggregated across hosts (min/mean/max/…), so multi-process runs report cluster-wide statistics, not just rank 0. `--remove_repeated_dir` cleans up per-repeat checkpoint directories.
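Conceptually, the reduction turns one metric dictionary per process into per-tag summary statistics. A sketch with made-up timings: in a real run the per-host values first have to be gathered across processes, which this single-process sketch does not show, and the framework reports more statistics than these three. `aggregate_across_hosts` is a hypothetical helper.

```python
import statistics


def aggregate_across_hosts(
    per_host: dict[int, dict[str, float]],
) -> dict[str, dict[str, float]]:
    """Reduce {process_index: {tag: value}} to {tag: {"min", "mean", "max"}}."""
    tags = sorted({tag for metrics in per_host.values() for tag in metrics})
    summary = {}
    for tag in tags:
        values = [metrics[tag] for metrics in per_host.values() if tag in metrics]
        summary[tag] = {
            "min": min(values),
            "mean": statistics.fmean(values),
            "max": max(values),
        }
    return summary


# Made-up timings from two processes.
per_host = {0: {"save_0_basics/time_s": 1.0}, 1: {"save_0_basics/time_s": 3.0}}
print(aggregate_across_hosts(per_host))
# {'save_0_basics/time_s': {'min': 1.0, 'mean': 2.0, 'max': 3.0}}
```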
### Baselines (capture / compare)
```shell
# Capture on the baseline revision (writes cross-host aggregates as <git_sha>.json):
... --baseline_capture_path=gs://bucket/baselines/my_suite/

# Compare a later revision (logs a per-metric delta):
... --baseline_path=gs://bucket/baselines/my_suite/<git_sha>.json
```
If no git sha resolves, the baseline is written as unknown.json.
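A sketch of the two ideas: the file-naming rule stated above, and a per-metric percentage delta. The JSON layout here (a flat `{tag: value}` mapping) and the percentage formula are assumptions for illustration; the real baseline files hold cross-host aggregates whose exact schema this document does not describe. `baseline_filename` and `percent_delta` are hypothetical helpers.

```python
from __future__ import annotations

import json
import pathlib
import tempfile


def baseline_filename(git_sha: str | None) -> str:
    """Name of the captured baseline file: <git_sha>.json, else unknown.json."""
    return f"{git_sha or 'unknown'}.json"


def percent_delta(baseline: dict[str, float],
                  candidate: dict[str, float]) -> dict[str, float]:
    """Relative change (%) per tag present in both runs; zero baselines skipped."""
    return {
        tag: 100.0 * (candidate[tag] - baseline[tag]) / baseline[tag]
        for tag in sorted(baseline.keys() & candidate.keys())
        if baseline[tag] != 0
    }


with tempfile.TemporaryDirectory() as tmp:
    # Round-trip a made-up baseline through JSON, as a capture/compare pair would.
    path = pathlib.Path(tmp) / baseline_filename(None)
    path.write_text(json.dumps({"save_0_basics/time_s": 2.0}))
    stored = json.loads(path.read_text())

print(path.name)                                              # unknown.json
print(percent_delta(stored, {"save_0_basics/time_s": 2.5}))  # {'save_0_basics/time_s': 25.0}
```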
### Profiler traces & HLO dumps
- `enable_trace: true` (an options field) captures a `jax.profiler` trace per measured operation, surfaced as its own run in the TensorBoard Profile tab. Only the first repeat is traced by default; `trace_every_repeat: true` traces every repeat.
- `--enable_hlo_dump` writes XLA HLO protos to `<output_directory>/hlo_dump/`.
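The default tracing policy reduces to a small rule. A sketch, assuming repeats are numbered from 0; `traced_repeats` is a hypothetical helper. Tracing a single repeat presumably keeps profiler overhead out of the remaining repeats' timings.

```python
def traced_repeats(num_repeats: int, enable_trace: bool,
                   trace_every_repeat: bool = False) -> list[int]:
    """Repeat indices that would capture a profiler trace under this policy."""
    if not enable_trace:
        return []
    if trace_every_repeat:
        return list(range(num_repeats))
    return list(range(min(num_repeats, 1)))  # only the first repeat


print(traced_repeats(3, enable_trace=True))                           # [0]
print(traced_repeats(3, enable_trace=True, trace_every_repeat=True))  # [0, 1, 2]
```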
### TensorBoard output
Under <output_directory>/tensorboard/: scalars (every metric), HParams
(the option combination per benchmark, for filtering/comparison), profile
traces (when enable_trace is on), and markdown cards (checkpoint
inventory + run summary).
## 9. Recipes
### A: Sweep storage options (built-in class)
See the 60-second tour (§1).

### B: Run a benchmark class you wrote
With `SaveRestoreBenchmark` from §3 importable as `my.module.SaveRestoreBenchmark`:
```yaml
suite_name: "my save/restore"
num_repeats: 3
checkpoint_config:
  spec: {params: {dtype: float32, shape: [4096, 4096], sharding: [data]}}
mesh_config:
  mesh_axes: ["data"]
  ici_parallelism: {"data": 8}
benchmarks:
  - generator: "my.module.SaveRestoreBenchmark"
    options: {use_ocdbt: [true, false]}
```
### C: Load a real checkpoint into a target sharding (built-in)
```yaml
suite_name: "resharding"
num_repeats: 10
mesh_config:
  mesh_axes: ["data", "fsdp", "tensor"]
  ici_parallelism: {"data": 1, "fsdp": 16, "tensor": 1}
benchmarks:
  - generator: "orbax.checkpoint._src.testing.benchmarks.v1.resharding_benchmark.ReshardingBenchmark"
    options:
      reference_checkpoint_path: "gs://my-bucket/ckpt/items"
      reference_sharding_path: "gs://my-bucket/sharding/abstract_state.json"
```
### D: Capture a baseline, then compare
```shell
RB=checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py

# baseline revision:
python $RB --config_file=sweep.yaml --output_directory=/tmp/baseline/ \
  --baseline_capture_path=gs://my-bucket/baselines/sweep/

# candidate revision:
python $RB --config_file=sweep.yaml --output_directory=/tmp/candidate/ \
  --baseline_path=gs://my-bucket/baselines/sweep/<git_sha>.json
```
## 10. Output layout
```text
<output_directory>/
├── tensorboard/             # scalars, HParams, markdown cards
│   └── <benchmark>__<op>/   # per-operation profiler traces (enable_trace)
├── hlo_dump/                # XLA HLO protos (--enable_hlo_dump)
└── <benchmark>/repeat_*/    # per-run checkpoint dirs (unless --remove_repeated_dir)
```
Metrics are also printed to the logs as a per-process report after each
benchmark, and (with --baseline_path) as a per-metric delta table.
## 11. Cheat sheet
```text
RUN       python checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py \
            --config_file=cfg.yaml --output_directory=/tmp/out/ [--baseline_capture_path=… | --baseline_path=…]
MEASURE   with metrics.measure("op"): <code>     # captures the whole suite; blocks nest
          return TestResult(metrics=metrics)
CLASS     @benchmarks_core.benchmark_options(MyOptions)
          class MyBenchmark(benchmarks_core.BenchmarksGenerator):
            def test_fn(self, context) -> benchmarks_core.TestResult: ...
CONFIG    suite_name / num_repeats / checkpoint_config(.spec|.path) / mesh_config / benchmarks:[{generator, options}]
SWEEP     any list-valued option, or checkpoint_configs / mesh_configs
DEFAULT   metrics: time, rss, jax_monitoring, device_memory, tensorstore (+ tracemalloc opt-in)
HEADLINE  <op>_4_throughput/*_gbps · <op>_0_basics/time_s · <op>_7_memory/device_hbm_peak_diff_gb · <op>_8_jax/cache_hit_rate
PATHWAYS  auto-init on Pathways · checkpoint_config.load_with_colocated_python: true · xpk/launch_xpk.py --enable_pathways
```
## 12. Source map
Every moving part, on google/orbax main:

Core framework (`core/`):
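| File | Contains |
|---|---|
| `core.py` | `BenchmarkOptions`, `@benchmark_options`, `BenchmarksGenerator`, `TestContext`, `TestResult` |
| `metric.py` | `Metrics`, `measure()`, and the metric collectors |
| `jax_monitoring_tags.py` | event → TensorBoard-tag map |
| | YAML → … |
| `checkpoint_generation.py` | synthetic generation + the Pathways colocated-Python load path |
| | baseline capture / compare |
| | cross-host aggregation + TensorBoard writing |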
Entrypoint: `run_benchmarks.py` (CLI flags, distributed / Pathways init)

Built-in generators (`v1/`): `benchmark.py` · `resharding_benchmark.py` · `restore_and_broadcast_benchmark.py` · `replica_parallel_multislice_benchmark.py`

Cloud / launchers: `xpk/` (GKE/XPK; `launch_xpk.py`, `PathwaysColocatedPythonGuide.md`) · `tpu_vm/` (single TPU VM)