DTensor utilities for multi-device/host export
Utilities for transforming distributed JAX arrays to DTensors.
- orbax.export.dtensor_utils.initialize_dtensor(reset_context=False)
Initialize a DTensor system for Orbax Export.
This function aligns the TensorFlow DTensor environment with the existing JAX process and device topology. It configures the DTensor accelerator system with one virtual CPU device per local JAX device, so that JAX and DTensor see the same distributed layout.
- Parameters:
reset_context (bool) – Reset the TensorFlow context along with DTensor initialization. The behavior of existing TensorFlow objects (e.g. Tensors) is undefined once the context is reset. Set this to True only as an escape hatch, when there is no clear way to refactor your code so that initialize_dtensor() is called before any TensorFlow API that initializes the context. See also dtensor.initialize_accelerator_system().
- Raises:
RuntimeError – If the number of DTensor clients is not the same as the number of JAX processes.
AssertionError – If the DTensor virtual CPU count does not match the JAX device count.
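As a rough illustration of the two failure modes listed above, the following pure-Python sketch models the parity checks with plain integers. The Topology class and check_dtensor_parity function are hypothetical and not part of Orbax; in a real program the counts would come from calls such as jax.process_count() and jax.local_device_count(), and the actual checks inside initialize_dtensor() may be implemented differently.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Topology:
    """Hypothetical snapshot of the JAX view of the cluster."""

    num_processes: int  # e.g. the value of jax.process_count()
    num_local_devices: int  # e.g. the value of jax.local_device_count()


def check_dtensor_parity(
    jax_topo: Topology, num_dtensor_clients: int, num_virtual_cpus: int
) -> None:
    """Raise the same exception types that initialize_dtensor() documents."""
    if num_dtensor_clients != jax_topo.num_processes:
        raise RuntimeError(
            f"{num_dtensor_clients} DTensor clients but "
            f"{jax_topo.num_processes} JAX processes"
        )
    # One virtual CPU per local JAX device keeps both views aligned.
    assert num_virtual_cpus == jax_topo.num_local_devices, (
        f"{num_virtual_cpus} virtual CPUs but "
        f"{jax_topo.num_local_devices} JAX devices"
    )


# Two processes with four devices each: consistent, so nothing is raised.
check_dtensor_parity(Topology(2, 4), num_dtensor_clients=2, num_virtual_cpus=4)
```

The client check runs first in this sketch, so a host with both mismatches would report the RuntimeError; that ordering is an assumption of the model.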
- orbax.export.dtensor_utils.dtensor_initialized()
Checks whether DTensor is initialized and matches the JAX device set.
This function inspects the module's internal state to confirm whether the TensorFlow DTensor environment has been configured to align with the current JAX topology.
Example
Safely initialize the DTensor environment only if it hasn’t been already:
```python
from orbax.export.dtensor_utils import dtensor_initialized, initialize_dtensor

if not dtensor_initialized():
    initialize_dtensor()
```
- Return type:
bool
- Returns:
True if the DTensor system has been successfully initialized, False otherwise.
- orbax.export.dtensor_utils.shutdown_dtensor()
Shuts down the DTensor system.
This function gracefully shuts down the TensorFlow DTensor environment and resets the internal initialization state. It must only be called after the DTensor system has been successfully initialized.
Example
Initialize and cleanly shut down the DTensor environment:
```python
from orbax.export.dtensor_utils import initialize_dtensor, shutdown_dtensor

initialize_dtensor()
# Perform distributed export operations.
shutdown_dtensor()
```
- Raises:
RuntimeError – If the DTensor system has not been initialized prior to calling this function.
- Return type:
None
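The initialize, check and shutdown contract described for initialize_dtensor(), dtensor_initialized() and shutdown_dtensor() can be modeled as a small state machine. The sketch below is hypothetical (DTensorLifecycle is not an Orbax class) and tracks only the boolean flag, not any TensorFlow state. It also shows one way to guarantee shutdown with try/finally, which is a suggested pattern rather than documented Orbax usage.

```python
class DTensorLifecycle:
    """Hypothetical model of a module-level DTensor initialization flag."""

    def __init__(self) -> None:
        self._initialized = False

    def initialize(self) -> None:
        # The real initialize_dtensor() also configures TensorFlow devices.
        self._initialized = True

    def initialized(self) -> bool:
        return self._initialized

    def shutdown(self) -> None:
        # Mirrors the documented RuntimeError for shutdown before init.
        if not self._initialized:
            raise RuntimeError("DTensor system has not been initialized")
        self._initialized = False


lifecycle = DTensorLifecycle()
if not lifecycle.initialized():
    lifecycle.initialize()
try:
    pass  # distributed export work would run here
finally:
    lifecycle.shutdown()  # runs even if the export step raises
```

With the real functions, the same shape would be `initialize_dtensor()` before the `try` and `shutdown_dtensor()` in the `finally` clause.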
- orbax.export.dtensor_utils.jax_mesh_to_dtensor_mesh(mesh)
Creates a DTensor mesh from a JAX mesh.
This function constructs a DTensor mesh that mirrors the topology and axis names of the provided JAX mesh, enabling distributed operations across both frameworks.
Example
Convert a JAX mesh to a DTensor mesh for distributed export:
```python
import jax
import numpy as np
from jax.sharding import Mesh
from orbax.export.dtensor_utils import initialize_dtensor, jax_mesh_to_dtensor_mesh

# DTensor must be initialized before creating a DTensor mesh.
initialize_dtensor(reset_context=True)

# Create a 1-D JAX mesh over all devices visible to JAX (across all processes).
jax_mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

# Convert to a DTensor mesh with the same shape and axis names.
dt_mesh = jax_mesh_to_dtensor_mesh(jax_mesh)
```
- Parameters:
mesh (Mesh) – A jax.sharding.Mesh representing the global JAX mesh.
- Return type:
Mesh
- Returns:
A DTensor host mesh of the same shape and axis names as those of the JAX mesh.
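The shape-and-axis-name mirroring described above amounts to pairing each JAX mesh axis name with its size, in order. The pure-Python sketch below models only that pairing; mirror_mesh_dims is a hypothetical helper, not the Orbax implementation. Its output has the same form as the (name, size) pairs that TensorFlow's tf.experimental.dtensor.create_mesh accepts for mesh_dims, though whether Orbax builds its mesh that way is not stated here. For a real JAX mesh, the inputs would be mesh.axis_names and mesh.devices.shape.

```python
from math import prod
from typing import List, Sequence, Tuple


def mirror_mesh_dims(
    axis_names: Sequence[str], device_grid_shape: Sequence[int]
) -> List[Tuple[str, int]]:
    """Pair each mesh axis name with its size, preserving axis order."""
    if len(axis_names) != len(device_grid_shape):
        raise ValueError("need exactly one axis name per mesh dimension")
    return list(zip(axis_names, device_grid_shape))


# A 2 x 4 JAX mesh with axes ('data', 'model') spans 8 devices.
dims = mirror_mesh_dims(("data", "model"), (2, 4))
num_devices = prod(size for _, size in dims)
```

Keeping the axis order matters because sharding specs refer to mesh axes by name and position; reordering them would silently change which devices hold which shard.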
- orbax.export.dtensor_utils.maybe_enable_dtensor_export_on(mesh)
Creates a DTensor context from a JAX mesh for Orbax Export.
This context manager temporarily pushes the provided JAX mesh and its corresponding DTensor mesh onto the global mesh stack, enabling DTensor-aware export operations within its block. If the DTensor system is not initialized or if the provided mesh is None, this context manager acts as a safe no-op.
Example
Apply the context manager to a standard JAX mesh during export:
```python
import jax
import numpy as np
from jax.sharding import Mesh
from orbax.export.dtensor_utils import (
    initialize_dtensor,
    maybe_enable_dtensor_export_on,
)

initialize_dtensor()
jax_mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

with maybe_enable_dtensor_export_on(jax_mesh):
    # Perform DTensor-aware export operations here.
    pass
```
- Parameters:
mesh (Optional[Mesh]) – A JAX pjit Mesh, or None.
- Yields:
None.
- orbax.export.dtensor_utils.get_current_dtensor_mesh()
Returns the DTensor mesh in the current context.
This function retrieves the DTensor mesh from the current context, which is determined by the active maybe_enable_dtensor_export_on context.
Example
Retrieve the active DTensor mesh inside an export context block:
```python
from orbax.export.dtensor_utils import (
    get_current_dtensor_mesh,
    maybe_enable_dtensor_export_on,
)

# Assuming jax_mesh is a valid jax.sharding.Mesh and DTensor is initialized.
with maybe_enable_dtensor_export_on(jax_mesh):
    active_dt_mesh = get_current_dtensor_mesh()
    # active_dt_mesh is now the corresponding DTensor mesh.
```
- Return type:
Optional[Mesh]
- Returns:
The active DTensor mesh if inside a valid DTensor context, otherwise None.