Configuring Specialized Features#
This guide provides a comprehensive overview of ocp.Context in Orbax v1. It explains the underlying architecture, demonstrates basic and advanced usage patterns, and outlines best practices for managing environment, runtime, and I/O configuration in your training loops.
1. Configuration Behavior#
Context objects and their underlying configuration option dataclasses (e.g., ArrayOptions, AsyncOptions, FileOptions) serve two distinct roles during their lifecycle:
- Standalone Configuration Templates (Mutable): When you instantiate ctx = ocp.Context(), it acts as an in-memory template. You can freely build, modify, and inspect its configuration parameters using a clean, mutable dot-notation syntax.
- Active Execution Policies (Frozen): Once a Context is bound to a context manager (with ctx:), it becomes the active runtime policy for all Orbax operations executed within that block.
Strict immutability during use#
To guarantee thread safety and prevent unpredictable mid-flight side effects, Orbax enforces a strict immutability invariant on active contexts.
Any attempt to mutate a configuration parameter while the context is active is intercepted immediately and raises a RuntimeError.
First, let’s set up a minimal environment for our examples:
from etils import epath
import jax
import jax.numpy as jnp
from orbax.checkpoint import v1 as ocp
# Minimal setup for examples
directory = epath.Path('/tmp/my_checkpoint_dir')
params = {'w': jnp.zeros((2, 3)), 'b': jnp.ones((3, 4))}
params_tree = {'params': params}
step = 0
print('Setup complete.')
Setup complete.
ctx = ocp.Context()
ctx.asynchronous.timeout_secs = 600  # Perfectly valid (configuring template)
with ctx:
    # Executing checkpoint operations...
    ocp.save(directory / 'sync', params_tree)
    # Attempting to mutate an active context is strictly prohibited:
    try:
        ctx.asynchronous.timeout_secs = 1200
    except RuntimeError as e:
        print(f"Caught expected error: {e}")
Caught expected error: Cannot mutate options of an active context. Configure before entering the `with` block.
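To make the freeze invariant concrete, the following is a minimal pure-Python sketch of how a context manager can reject writes while it is entered. It is a toy model, not Orbax's implementation: ToyContext, ToyOptions, and the default timeout of 300 are hypothetical names and values introduced only for illustration.

```python
class ToyOptions:
    """Hypothetical option group; stands in for something like AsyncOptions."""

    def __init__(self, owner, **values):
        # Bypass our own __setattr__ while wiring up internals.
        object.__setattr__(self, "_owner", owner)
        for name, value in values.items():
            object.__setattr__(self, name, value)

    def __setattr__(self, name, value):
        # Reject writes while the owning context is inside a `with` block.
        if self._owner.active:
            raise RuntimeError("Cannot mutate options of an active context.")
        object.__setattr__(self, name, value)


class ToyContext:
    """Hypothetical context: mutable template outside `with`, frozen inside."""

    def __init__(self):
        self._depth = 0  # number of `with` blocks currently entered
        self.asynchronous = ToyOptions(self, timeout_secs=300)

    @property
    def active(self):
        return self._depth > 0

    def __enter__(self):
        self._depth += 1
        return self

    def __exit__(self, exc_type, exc, tb):
        self._depth -= 1
        return False  # never swallow exceptions


toy = ToyContext()
toy.asynchronous.timeout_secs = 600  # allowed: still a template

with toy:
    try:
        toy.asynchronous.timeout_secs = 1200
        frozen_write_rejected = False
    except RuntimeError:
        frozen_write_rejected = True

toy.asynchronous.timeout_secs = 900  # this toy unfreezes once the block exits
```

In this sketch the template becomes mutable again after the block exits. That is a design choice of the toy; this guide does not state how Orbax behaves after exit.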
2. Basic Usage & Configuration#
Configuring a Context relies on a hierarchical, dot-notation namespace. You do not need to construct or pass complex option dataclasses directly.
Below is a summary of the available option types.
| Option Type | Dot Path | Description |
|---|---|---|
| | | Asynchronous checkpoint saving operations, including timeouts, workers, and background finalization behavior. |
| | | Multi-host and multi-process checkpointing behavior and barrier synchronization. |
| | | Underlying filesystem interactions, directory permissions, atomicity protocols, and customized path implementations. |
| | | PyTree-level saving, loading, and structural restoration behavior. |
| | | High-performance tensor and array I/O, storage formats, compression, sharding, and multi-replica load-and-broadcast behavior. |
| | | Handler resolution registries for custom checkpointable types. |
| | | Pathways-specific distributed checkpointing implementations. |
| | | Checkpoint cleanup and soft-deletion behavior across storage backends. |
| | | Concurrent I/O memory limits and prioritized transfer scheduling to prevent out-of-memory (OOM) errors during large checkpoint operations. |
| | | HuggingFace SafeTensors loading and conversion behavior. |
| | | Permanent on-disk serialization layout format. |
Example: Basic Configuration & Execution#
When using Checkpointer, you do not need to wrap your training loop inside with ctx:. You can pass the Context object directly into the Checkpointer constructor (context=ctx). For standalone free functions (ocp.save), you use with ctx: to bind the active context.
# 1. Instantiate the root Context
ctx = ocp.Context()
# 2. Configure options via mutable dot-notation
ctx.asynchronous.timeout_secs = 1200
ctx.asynchronous.create_directories_asynchronously = True
ctx.array.saving.use_zarr3 = True
ctx.array.saving.use_compression = False
ctx.array.loading.enable_padding_and_truncation = True
ctx.pytree.loading.partial_load = True
# 3a. Using Checkpointer (Pass context directly into constructor)
with ocp.training.Checkpointer(directory / 'ckptr', context=ctx) as ckptr:
    ckptr.save_checkpointables(step, {'params': params})
# 3b. Using Free Functions (Bind context via with block)
with ctx:
    ocp.save(directory / 'free', params_tree)
3. Advanced Usage: Inheritance & Customization#
Orbax Context supports powerful inheritance patterns, allowing you to branch configurations for specialized sub-tasks without duplicating code or risking side effects.
3.1 Context Inheritance#
To inherit properties from an existing parent Context, pass the parent context directly to the constructor (ctx2 = ocp.Context(ctx1)).
Orbax performs a deep copy of the parent’s option tree (while safely sharing immutable functions and callbacks by reference). The resulting child context inherits all parent properties but is completely decoupled, allowing you to mutate ctx2 independently without affecting ctx1.
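The copy semantics described above can be illustrated with the standard library alone. The sketch below is not Orbax code: it models an option tree as a plain nested dict (a hypothetical stand-in) and relies on the fact that copy.deepcopy duplicates containers but treats plain functions as atomic, so callbacks end up shared by reference.

```python
import copy


def default_creator(keypath, value):
    """Hypothetical callback stored somewhere in the option tree."""
    return None


# Toy stand-in for a parent context's option tree.
parent_options = {
    "asynchronous": {"timeout_secs": 1200},
    "array": {"saving": {"scoped_storage_options_creator": default_creator}},
}

# Nested dicts are duplicated; the function object is returned as-is.
child_options = copy.deepcopy(parent_options)
child_options["asynchronous"]["timeout_secs"] = 300  # only the child changes

parent_timeout = parent_options["asynchronous"]["timeout_secs"]
child_timeout = child_options["asynchronous"]["timeout_secs"]
callback_is_shared = (
    child_options["array"]["saving"]["scoped_storage_options_creator"]
    is parent_options["array"]["saving"]["scoped_storage_options_creator"]
)
```

Here the parent keeps its timeout of 1200 while the child holds 300, and both trees point at the same callback object.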
Inheritance in Checkpointer#
Note that when you pass a Context object into Checkpointer(..., context=ctx) (or when Checkpointer inherits an active context from a with ctx: block), Orbax automatically executes ocp.Context(ctx) under the hood. This means Checkpointer inherits all properties from your context but operates on a completely independent, unfrozen child copy, preserving perfect isolation.
# Parent context configures baseline rules
base_ctx = ocp.Context()
base_ctx.pytree.loading.partial_load = True
base_ctx.asynchronous.timeout_secs = 1200
# Checkpointer automatically branches a child context ocp.Context(base_ctx)
with ocp.training.Checkpointer(directory / 'child_ckptr', context=base_ctx) as ckptr:
    ckptr.save_checkpointables(step, {'params': params})
3.2 Scoped Storage Options (Per-Leaf Configuration)#
When saving complex PyTrees, you may want certain parameter leaves (e.g., large weight matrices) to use different storage rules (e.g., lower precision dtypes or specific chunk shapes) than smaller parameters (e.g., biases). You can achieve this using scoped_storage_options_creator.
def custom_storage_rules(keypath, value):
    # Downcast large weights to float16, leave biases as default
    if 'weight' in jax.tree_util.keystr(keypath):
        return ocp.options.ArrayOptions.Saving.StorageOptions(dtype=jnp.float16)
    return None  # Fall back to global storage_options
ctx = ocp.Context()
ctx.array.saving.storage_options.dtype = jnp.float32
ctx.array.saving.scoped_storage_options_creator = custom_storage_rules
with ctx:
    ocp.save(directory / 'scoped', params_tree)
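The per-leaf fallback logic can be sketched without Orbax or JAX. The code below is an illustrative model only: iter_leaves, resolve_storage_options, the toy tree, and the use of plain dicts for storage options are all assumptions made for this example. It also assumes a scoped result replaces the global options wholesale; this guide does not say whether Orbax merges the two field by field.

```python
def iter_leaves(tree, path=()):
    """Yield (keypath, leaf) pairs for a nested dict; keypath is a tuple of keys."""
    if isinstance(tree, dict):
        for key, subtree in tree.items():
            yield from iter_leaves(subtree, path + (key,))
    else:
        yield path, tree


def resolve_storage_options(tree, global_options, creator=None):
    """Pick options per leaf: the creator's result if not None, else the global ones."""
    resolved = {}
    for keypath, leaf in iter_leaves(tree):
        scoped = creator(keypath, leaf) if creator is not None else None
        resolved["/".join(keypath)] = scoped if scoped is not None else global_options
    return resolved


def downcast_weights(keypath, value):
    # Hypothetical rule: match on the final key of the path.
    if "weight" in keypath[-1]:
        return {"dtype": "float16"}
    return None  # fall back to the global options


toy_tree = {"params": {"dense_weight": [0.0] * 6, "bias": [1.0] * 3}}
resolved = resolve_storage_options(toy_tree, {"dtype": "float32"}, downcast_weights)
```

With this toy tree, params/dense_weight resolves to float16 and params/bias falls back to float32. Note that the setup tree defined earlier in this guide uses the keys 'w' and 'b', so a substring test for 'weight' would not match any of its leaves; a rule like this only takes effect when the key names actually contain the pattern.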
3.3 Custom Handler Registration#
Orbax provides a global registry for standard handlers (like PyTrees and JSON). However, when defining custom CheckpointableHandler types (as detailed in the Customization Guide), you can configure context-local registries to avoid polluting the global registry or causing conflicts across different modules.
By attaching a custom CheckpointableHandlerRegistry to ctx.checkpointables.registry, Orbax will resolve handlers for your custom checkpointables exclusively within that context’s scope.
# 1. Create a standalone local registry
local_registry = ocp.handlers.local_registry()
# 2. Register your custom handler (assuming MyCustomHandler is defined)
# local_registry.add(MyCustomHandler, checkpointable_name='custom_state')
# 3. Attach the local registry to your Context
ctx = ocp.Context()
ctx.checkpointables.registry = local_registry
# 4. Execute within the context scope; Orbax will now use local_registry
with ctx:
    # ocp.save_checkpointables(directory, dict(custom_state=my_custom_object))
    pass
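As a rough mental model of why a context-local registry avoids polluting global state, consider the toy resolver below. Everything in it (ToyRegistry, GLOBAL_HANDLERS, resolve_handler, the handler names) is hypothetical. It assumes one possible reading of "exclusively": when a local registry is attached, lookups go only to it. Whether Orbax falls back to the global registry is not stated in this guide.

```python
# Toy global registry with placeholder handler names.
GLOBAL_HANDLERS = {"pytree": "ToyPyTreeHandler", "json": "ToyJsonHandler"}


class ToyRegistry:
    """Hypothetical local registry keyed by checkpointable name."""

    def __init__(self):
        self._handlers = {}

    def add(self, handler, checkpointable_name):
        self._handlers[checkpointable_name] = handler
        return self  # allow chaining

    def get(self, checkpointable_name):
        return self._handlers[checkpointable_name]  # KeyError if unknown


def resolve_handler(checkpointable_name, local_registry=None):
    """Use the local registry when one is attached; otherwise the global one."""
    if local_registry is not None:
        return local_registry.get(checkpointable_name)
    return GLOBAL_HANDLERS[checkpointable_name]


class MyCustomHandler:
    """Placeholder for a user-defined handler type."""


local = ToyRegistry().add(MyCustomHandler, checkpointable_name="custom_state")
resolved_local = resolve_handler("custom_state", local)
resolved_global = resolve_handler("pytree")

# The custom handler never leaks into the global registry.
try:
    resolve_handler("custom_state")
    leaked = True
except KeyError:
    leaked = False
```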
4. Context Best Practices#
To ensure clean architectural design and prevent subtle concurrency or scoping bugs across multi-threaded, multi-host, or asynchronous workflows, adhere to the following core principles.
4.1 Concurrency in Asynchronous Workflows#
For standard asynchronous workflows where you configure a Context once at the start of your training script, spawning a dedicated child context (sub-context) is not required. The active context safely manages the background save operation, and Checkpointer automatically creates an isolated child context under the hood.
However, note that any background task or coroutine inheriting the active context receives a reference to the exact same Context instance in memory. If your advanced workflow requires the main thread to actively mutate the shared Context instance mid-flight after launching an async save, those mutations would propagate to the background task. In those rare cases, branching a child context (ctx2 = ocp.Context(main_ctx)) ensures perfect isolation.
Below is the standard, recommended pattern for asynchronous saving:
main_ctx = ocp.Context()
main_ctx.asynchronous.timeout_secs = 600
# Standard async save operates directly within the active context
with main_ctx:
    response = ocp.save_async(directory / 'async', params_tree)
    # Perform other concurrent training work...
    response.result()
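The shared-reference caveat described in this section can be reproduced with the standard library. The sketch below uses asyncio and contextvars as a stand-in for how an active context might be propagated; it makes no claim about Orbax internals, and active_options and background_save are hypothetical names. A task created while a mutable object is active sees later mutations of that object, whereas a task given a deep copy does not.

```python
import asyncio
import contextvars
import copy

# Toy "active context" slot; the stored value is a mutable dict.
active_options = contextvars.ContextVar("active_options")


async def background_save(seen):
    await asyncio.sleep(0)  # yield control once before reading
    seen.append(active_options.get()["timeout_secs"])


async def main():
    shared = {"timeout_secs": 600}
    seen_shared, seen_child = [], []

    # Task 1 inherits a reference to the very same dict.
    active_options.set(shared)
    task_shared = asyncio.create_task(background_save(seen_shared))

    # Task 2 inherits an isolated deep copy (the "child context").
    active_options.set(copy.deepcopy(shared))
    task_child = asyncio.create_task(background_save(seen_child))

    # Mutate the original after both tasks were launched.
    shared["timeout_secs"] = 1200
    await asyncio.gather(task_shared, task_child)
    return seen_shared, seen_child


seen_shared, seen_child = asyncio.run(main())
```

The task holding the shared reference observes 1200, while the task holding the copy still observes 600, which is the isolation that branching a child context is meant to provide.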
4.2 Thread Safety and Background Threads#
Orbax is not thread-safe in general. Calling Orbax operations from a custom background thread (e.g., via threading.Thread or ThreadPoolExecutor) is not supported and should be avoided. If you believe your specific architecture requires calling Orbax from a background thread, please consult with the Orbax team to discuss your use case.
4.3 Consistency in Distributed Multi-Host Environments#
In distributed training setups (e.g., multi-slice TPU pods or multi-node GPU clusters), Context settings like MultiprocessingOptions, AsyncOptions, and underlying storage sharding rules must be configured identically across all participating hosts. Mismatched configurations can lead to barrier deadlocks or corrupted metadata.
# Execute identical configuration logic across all participating hosts
cluster_ctx = ocp.Context()
cluster_ctx.multiprocessing.barrier_sync_key_prefix = "model_v1_sync_"
cluster_ctx.asynchronous.timeout_secs = 1200
# Ensure all hosts enter the managed block synchronously
with cluster_ctx:
    ocp.save(directory / 'cluster', params_tree)
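One way to catch configuration drift before it causes a barrier deadlock is to compare a fingerprint of the settings across hosts. The sketch below is a generic technique, not an Orbax feature: config_fingerprint and mismatched_hosts are hypothetical helpers, the options are modeled as plain dicts, and the hosts are simulated in a single process. In a real job the fingerprints would have to be exchanged through whatever collective or coordination service the setup already provides.

```python
import hashlib
import json


def config_fingerprint(options):
    """Hash a JSON-serializable options dict in a key-order-independent way."""
    canonical = json.dumps(options, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


def mismatched_hosts(fingerprints):
    """Return the host ids whose fingerprint differs from host 0's."""
    reference = fingerprints[0]
    return [host for host, fp in sorted(fingerprints.items()) if fp != reference]


# Simulated per-host settings; host 2 drifted on the timeout.
host_options = {
    0: {"barrier_sync_key_prefix": "model_v1_sync_", "timeout_secs": 1200},
    1: {"timeout_secs": 1200, "barrier_sync_key_prefix": "model_v1_sync_"},
    2: {"barrier_sync_key_prefix": "model_v1_sync_", "timeout_secs": 600},
}
fingerprints = {host: config_fingerprint(opts) for host, opts in host_options.items()}
bad_hosts = mismatched_hosts(fingerprints)
```

Host 1 lists the same settings in a different key order and still matches host 0; only host 2 is flagged.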
4.4 Configuration Scoping in Multi-Actor Scripts#
When authoring scripts or test suites that mix Checkpointer objects and standalone free functions (ocp.save), guarantee isolation by wrapping free functions in dedicated with ocp.Context(...): blocks or passing explicit contexts into Checkpointer constructors. This ensures each independent actor executes with its intended settings regardless of the outer scope.
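The scoping rule can be pictured with a small contextvars-based model. This is a sketch under stated assumptions, not Orbax code: toy_context, toy_save, and ToyCheckpointer are hypothetical, options are plain dicts, and the default timeout of 300 is invented. It shows a free function reading the innermost scope, and a checkpointer-like object using an explicit context regardless of the surrounding scope.

```python
import contextlib
import contextvars

# Toy ambient slot with an invented default.
_ACTIVE = contextvars.ContextVar("active_options", default={"timeout_secs": 300})


@contextlib.contextmanager
def toy_context(**options):
    """Bind options for the duration of the block, then restore the outer scope."""
    token = _ACTIVE.set(dict(options))
    try:
        yield
    finally:
        _ACTIVE.reset(token)


def toy_save():
    """Free function: reads whichever toy context is innermost."""
    return _ACTIVE.get()["timeout_secs"]


class ToyCheckpointer:
    def __init__(self, context=None):
        # An explicit context wins; otherwise snapshot the ambient one.
        source = context if context is not None else _ACTIVE.get()
        self._options = dict(source)

    def save(self):
        return self._options["timeout_secs"]


with toy_context(timeout_secs=600):
    outer = toy_save()
    with toy_context(timeout_secs=60):
        inner = toy_save()
    explicit = ToyCheckpointer(context={"timeout_secs": 1200}).save()
    ambient = ToyCheckpointer().save()
after = toy_save()
```

In this model the outer block yields 600, the nested block 60, the explicitly configured checkpointer 1200, the ambient checkpointer 600, and the call after all blocks falls back to the default of 300.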
4.5 Frozen Lineage Protection#
When instantiating a child context from a parent context (ctx2 = ocp.Context(ctx1)), Orbax guarantees strict lineage safety. If the parent ctx1 is currently bound to an active with block, ctx1 remains fully frozen against mutation.
Creating ctx2 produces a completely independent, unfrozen child copy that you can customize freely, but it does not unfreeze or compromise the parent ctx1. Attempting to mutate ctx1 while it is active is strictly blocked by the freeze guard.
parent_ctx = ocp.Context()
parent_ctx.asynchronous.timeout_secs = 600
with parent_ctx:
    # Safely branch an unfrozen child context from the active parent
    child_ctx = ocp.Context(parent_ctx)
    child_ctx.asynchronous.timeout_secs = 300  # Perfectly valid (modifying child)
    # The active parent context remains strictly frozen against mutation:
    try:
        parent_ctx.asynchronous.timeout_secs = 1200
    except RuntimeError as e:
        print(f"Caught expected error: {e}")
Caught expected error: Cannot mutate options of an active context. Configure before entering the `with` block.
5. Summary of Best Practices#
- Configure Early, Execute Late: Instantiate and configure your ocp.Context once at the top of your training script or sub-task.
- Dot-Notation: Use ctx.asynchronous.timeout_secs = ... rather than constructing complex option dataclasses.
- Branch via Inheritance: Use ctx2 = ocp.Context(ctx1) when you need modified settings for a specific evaluation or export step.
- Respect Immutability: Never attempt to mutate a Context while it is actively bound to a with block or being used by a background asynchronous save.
- Manage Checkpointer Lifecycle: Prefer with ocp.training.Checkpointer(...) as ckptr: to guarantee clean resource cleanup and complete outstanding writes.
- Consult API Reference: For in-depth configuration fields, sub-options, and advanced usage, refer to the Orbax Context Options API Reference.