Orbax

Orbax is a modular and customizable JAX checkpointing library built for high performance at scale, allowing for distributed array storage and checkpoint lifecycle management.

We are focused on providing a JAX-native approach to model persistence and recovery, with the goal of an API that is easy to use, highly performant, and maximally compatible across the JAX ecosystem.

Performance

Checkpointing with Orbax is fast and memory-efficient, enabling quick start-up and minimal impact on training.

Distributed

Orbax abstracts away the details of persisting distributed arrays and provides a unified API for single- and multi-process checkpointing.

Management

Orbax facilitates checkpointing in a training loop, e.g. through metadata management, garbage collection, and saving policies.

Flexibility

Orbax features out-of-the-box support for advanced workflows like topology-agnostic loading (resharding), partial loading, incremental saving, and more.

Extensibility

Orbax provides extensibility for user-defined types and logic through customizable handler interfaces.

Exporting

Orbax provides an associated library, orbax-export, for exporting JAX models to the TensorFlow SavedModel format.

Support

Please report any issues or request support using our issue tracker.