Orbax
Orbax is a modular and customizable JAX checkpointing library built for high performance at scale, allowing for distributed array storage and checkpoint lifecycle management.
We focus on a JAX-native approach to model persistence and recovery, with the goal of an API that is easy to use, highly performant, and maximally compatible across the JAX ecosystem.
Checkpointing with Orbax is fast and memory-efficient, allowing for quick start-up and minimal training impact.
Orbax abstracts the details of persisting distributed arrays and provides a unified API for single- and multi-process checkpointing.
Orbax facilitates checkpointing in a training loop, e.g. through metadata management, garbage collection, and saving policies.
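To make the idea of saving policies and garbage collection concrete, here is a minimal pure-Python sketch. It does not use Orbax and is not Orbax's implementation: `ToyCheckpointPolicy`, `run_toy_loop`, and the parameters `save_interval_steps` and `max_to_keep` are hypothetical names chosen for illustration, and no files are written. It only tracks which step numbers would be saved and which older ones would be deleted.

```python
from dataclasses import dataclass, field


@dataclass
class ToyCheckpointPolicy:
    """Hypothetical stand-in for a checkpoint manager's policy (not Orbax)."""

    save_interval_steps: int = 2  # assumed: save every N-th step
    max_to_keep: int = 3          # assumed: retain only the newest N saves
    kept_steps: list = field(default_factory=list)

    def should_save(self, step: int) -> bool:
        # Saving policy: save only on multiples of the interval.
        return step % self.save_interval_steps == 0

    def record_save(self, step: int) -> list:
        # Register a saved step and return the steps to garbage-collect.
        self.kept_steps.append(step)
        excess = len(self.kept_steps) - self.max_to_keep
        deleted = self.kept_steps[:excess] if excess > 0 else []
        self.kept_steps = self.kept_steps[len(deleted):]
        return deleted


def run_toy_loop(num_steps: int, policy: ToyCheckpointPolicy):
    """Simulate a training loop; return (kept steps, deleted steps)."""
    deleted = []
    for step in range(num_steps):
        # A real loop would run a training step here.
        if policy.should_save(step):
            deleted.extend(policy.record_save(step))
    return policy.kept_steps, deleted


kept, deleted = run_toy_loop(10, ToyCheckpointPolicy())
print("kept:", kept, "deleted:", deleted)
```

In this sketch the training loop only asks the policy whether to save. Retention is handled entirely inside the policy object, which is the kind of separation a checkpoint manager is meant to provide.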
Orbax features out-of-the-box support for advanced workflows like topology-agnostic loading (resharding), partial loading, incremental saving, and more.
Orbax provides extensibility for user-defined types and logic through customizable handler interfaces.
Orbax provides an associated library, orbax-export, for exporting JAX models to the TensorFlow SavedModel format.
Support
Please report any issues or request support using our issue tracker.