DTensor utilities for multi-device/host export

Utilities for transforming distributed JAX arrays to DTensors.

orbax.export.dtensor_utils.initialize_dtensor(reset_context=False)

Initialize a DTensor system for Orbax Export.

This function aligns the TensorFlow DTensor environment with the existing JAX process and device topology. It configures the DTensor accelerator system to use one virtual CPU per local JAX device, so that JAX and DTensor see the same distributed layout.

Parameters:

reset_context (bool) – Whether to reset the TensorFlow context along with DTensor initialization. If True, the behavior of existing TensorFlow objects (e.g. Tensors) is undefined afterward. Use True only as an escape hatch when there is no clear way to refactor your code to call initialize_dtensor() before calling TensorFlow APIs that initialize the context. See also dtensor.initialize_accelerator_system().

Raises:
  • RuntimeError – If the number of DTensor clients is not the same as that of JAX processes.

  • AssertionError – If the DTensor virtual CPU count does not match the JAX device count.
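The two checks listed under Raises can be modeled in plain Python as follows. This is a hypothetical toy, not Orbax's implementation: check_dtensor_parity and its arguments are invented for illustration, and the four counts are passed in directly instead of being queried from JAX (for example via jax.process_count() and jax.local_device_count()) and from DTensor.

```python
def check_dtensor_parity(
    jax_process_count: int,
    dtensor_client_count: int,
    jax_local_device_count: int,
    dtensor_local_cpu_count: int,
) -> None:
    """Toy check that raises the errors documented for initialize_dtensor()."""
    # Documented: the DTensor client count must equal the JAX process count.
    if dtensor_client_count != jax_process_count:
        raise RuntimeError(
            f"{dtensor_client_count} DTensor clients vs "
            f"{jax_process_count} JAX processes"
        )
    # Documented: the DTensor virtual CPU count must equal the JAX device count.
    # Raised explicitly so the check still fires under `python -O`.
    if dtensor_local_cpu_count != jax_local_device_count:
        raise AssertionError(
            f"{dtensor_local_cpu_count} DTensor virtual CPUs vs "
            f"{jax_local_device_count} local JAX devices"
        )


# A 2-process job with 4 local devices per process passes both checks.
check_dtensor_parity(2, 2, 4, 4)
```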

orbax.export.dtensor_utils.dtensor_initialized()

Checks whether DTensor is initialized and matches the JAX device set.

This function inspects internal state to determine whether the TensorFlow DTensor environment has been configured to match the current JAX topology.

Example

Safely initialize the DTensor environment only if it hasn’t been already:

from orbax.export.dtensor_utils import dtensor_initialized, initialize_dtensor

if not dtensor_initialized():
    initialize_dtensor()

Return type:

bool

Returns:

True if the DTensor system has been successfully initialized, False otherwise.

orbax.export.dtensor_utils.shutdown_dtensor()

Shuts down the DTensor system.

This function gracefully shuts down the TensorFlow DTensor environment and resets the internal initialization state. It must only be called after the DTensor system has been successfully initialized.

Example

Initialize and cleanly shut down the DTensor environment:

from orbax.export.dtensor_utils import initialize_dtensor, shutdown_dtensor

initialize_dtensor()
# Perform distributed export operations here.
shutdown_dtensor()

Raises:

RuntimeError – If the DTensor system has not been initialized prior to calling this function.

Return type:

None
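The documented lifecycle (initialize once, check with dtensor_initialized(), shut down only after initialization) can be modeled with a small state object. ToyDTensorState is hypothetical and only mirrors the documented contract, including the RuntimeError on a premature shutdown; the try/finally pairing at the end is a suggested pattern, not something the API requires.

```python
class ToyDTensorState:
    """Hypothetical stand-in for the module's initialization flag."""

    def __init__(self) -> None:
        self._initialized = False

    def initialize(self) -> None:
        self._initialized = True

    def initialized(self) -> bool:
        return self._initialized

    def shutdown(self) -> None:
        # Mirrors the documented RuntimeError for shutdown before init.
        if not self._initialized:
            raise RuntimeError("DTensor system has not been initialized.")
        self._initialized = False


state = ToyDTensorState()
if not state.initialized():  # Same guard as the dtensor_initialized() example.
    state.initialize()
try:
    pass  # Distributed export work would go here.
finally:
    state.shutdown()  # Release DTensor even if the export step raises.
```

With the real functions, the same shape would be initialize_dtensor() followed by a try block whose finally clause calls shutdown_dtensor().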

orbax.export.dtensor_utils.jax_mesh_to_dtensor_mesh(mesh)

Creates a DTensor mesh from a JAX mesh.

This function constructs a DTensor mesh that mirrors the topology and axis names of the provided JAX mesh, enabling distributed operations across both frameworks.

Example

Convert a JAX mesh to a DTensor mesh for distributed export:

import jax
import numpy as np
from jax.sharding import Mesh
from orbax.export.dtensor_utils import jax_mesh_to_dtensor_mesh, initialize_dtensor

# DTensor must be initialized before creating a DTensor mesh.
initialize_dtensor(reset_context=True)

# Create a 1-D JAX mesh over all available devices (global, across hosts).
jax_mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

# Convert it to a DTensor mesh with the same shape and axis names.
dt_mesh = jax_mesh_to_dtensor_mesh(jax_mesh)

Parameters:

mesh (Mesh) – A jax.sharding.Mesh representing the global JAX mesh.

Return type:

Mesh

Returns:

A DTensor host (CPU) mesh with the same shape and axis names as the JAX mesh.
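One plausible way to mirror a JAX mesh is to read jax.sharding.Mesh.shape, an ordered mapping from axis name to axis size, and reuse its items as the (name, size) pairs that tf.experimental.dtensor.create_distributed_mesh accepts as mesh_dims. That jax_mesh_to_dtensor_mesh works this way internally is an assumption. The sketch below covers only the pure-Python shape bookkeeping, with a plain OrderedDict standing in for Mesh.shape and a hypothetical helper, to_dtensor_mesh_dims.

```python
from collections import OrderedDict
from math import prod
from typing import List, Mapping, Tuple


def to_dtensor_mesh_dims(jax_mesh_shape: Mapping[str, int]) -> List[Tuple[str, int]]:
    """Hypothetical helper: axis-name -> size mapping to (name, size) pairs."""
    # Keeping insertion order keeps each DTensor axis aligned with its JAX axis.
    return list(jax_mesh_shape.items())


# Stand-in for jax.sharding.Mesh.shape on a hypothetical 2 x 4 device mesh.
shape = OrderedDict([("batch", 2), ("model", 4)])
mesh_dims = to_dtensor_mesh_dims(shape)
print(mesh_dims)                       # [('batch', 2), ('model', 4)]
print(prod(size for _, size in mesh_dims))  # 8 devices in total
```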

orbax.export.dtensor_utils.maybe_enable_dtensor_export_on(mesh)

Creates a DTensor context from a JAX mesh for Orbax Export.

This context manager temporarily pushes the provided JAX mesh and its corresponding DTensor mesh onto the global mesh stack, enabling DTensor-aware export operations within its block. If the DTensor system is not initialized or if the provided mesh is None, this context manager acts as a safe no-op.

Example

Apply the context manager to a standard JAX mesh during export:

import jax
import numpy as np
from jax.sharding import Mesh
from orbax.export.dtensor_utils import initialize_dtensor, maybe_enable_dtensor_export_on

initialize_dtensor()
jax_mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

with maybe_enable_dtensor_export_on(jax_mesh):
    # Perform DTensor-aware export operations here.
    pass

Parameters:

mesh (Optional[Mesh]) – A JAX mesh (jax.sharding.Mesh), or None, in which case the context manager is a no-op.

Yields:

None.

orbax.export.dtensor_utils.get_current_dtensor_mesh()

Returns the DTensor mesh in the current context.

This function retrieves the DTensor mesh set by the enclosing maybe_enable_dtensor_export_on block, if there is one.

Example

Retrieve the active DTensor mesh inside an export context block:

from orbax.export.dtensor_utils import maybe_enable_dtensor_export_on
from orbax.export.dtensor_utils import get_current_dtensor_mesh

# Assuming jax_mesh is a valid jax.sharding.Mesh
with maybe_enable_dtensor_export_on(jax_mesh):
    active_dt_mesh = get_current_dtensor_mesh()
    # active_dt_mesh is now the DTensor mesh corresponding to jax_mesh.

Return type:

Optional[Mesh]

Returns:

The active DTensor mesh if inside a valid DTensor context, otherwise None.
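Assuming the stack semantics implied by the "global mesh stack" wording, nested contexts resolve to the innermost mesh and the lookup returns None outside any context. The toy below models only that lookup rule; toy_export_context, toy_get_current_dtensor_mesh, and _DTENSOR_MESHES are hypothetical names, and strings stand in for DTensor meshes.

```python
import contextlib
from typing import Iterator, List, Optional

_DTENSOR_MESHES: List[str] = []  # Hypothetical stack of active DTensor meshes.


@contextlib.contextmanager
def toy_export_context(mesh: str) -> Iterator[None]:
    _DTENSOR_MESHES.append(f"dtensor:{mesh}")
    try:
        yield
    finally:
        _DTENSOR_MESHES.pop()


def toy_get_current_dtensor_mesh() -> Optional[str]:
    # The innermost active context wins; outside any context there is no mesh.
    return _DTENSOR_MESHES[-1] if _DTENSOR_MESHES else None


print(toy_get_current_dtensor_mesh())          # None
with toy_export_context("outer"):
    with toy_export_context("inner"):
        print(toy_get_current_dtensor_mesh())  # dtensor:inner
    print(toy_get_current_dtensor_mesh())      # dtensor:outer
```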