Orbax

Orbax is an umbrella namespace providing common training utilities for JAX users. It includes multiple distinct but interrelated libraries.

Checkpointing

A flexible and customizable API for managing checkpoints consisting of various user-defined objects in multi-host, multi-device settings.

Exporting

A library for exporting JAX models to TensorFlow SavedModel format.

Installation

There is no single orbax package. Instead, each library in the Orbax namespace is published as a separate package.

The latest release of orbax-checkpoint can be installed from PyPI using

pip install orbax-checkpoint

You may also install directly from GitHub using the following command, which provides the most recent development version of orbax-checkpoint.

pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'

NOTE: Certain edge cases of orbax-checkpoint may not work on Windows, and support for them is not currently planned.

Similarly, orbax-export can be installed from PyPI using

pip install orbax-export

To install the most recent development version directly from GitHub, use the following command.

pip install 'git+https://github.com/google/orbax/#subdirectory=export'

Checkpointing

Announcements
Getting Started
API Overview
API Refactor
Checkpointing PyTrees of Arrays
Optimized Checkpointing
Custom Handlers
Transformations
Preemption Tolerance
Async Checkpointing
API Reference

Exporting

Getting Started
API Reference

Support

Please report any issues or request support using our issue tracker.