Orbax is an umbrella namespace providing common training utilities for JAX users. It includes multiple distinct but interrelated libraries.


orbax-checkpoint: a flexible and customizable API for managing checkpoints that consist of various user-defined objects, in multi-host, multi-device settings.


orbax-export: a library for exporting JAX models to the TensorFlow SavedModel format.


There is no single orbax package, but rather a separate package for each functionality provided by the Orbax namespace.

The latest release of orbax-checkpoint can be installed from PyPI using

pip install orbax-checkpoint

You may also install directly from GitHub using the following command, which installs the most recent version of orbax-checkpoint.

pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'

NOTE: Certain edge cases of orbax-checkpoint may not work on Windows, and there are currently no plans to support them.

Similarly, orbax-export can be installed from PyPI using

pip install orbax-export

To install the most recent version directly from GitHub, use the following command.

pip install 'git+https://github.com/google/orbax/#subdirectory=export'




Please report any issues or request support using our issue tracker.