Source code for orbax.export.dtensor_utils

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for transforming distributed JAX arrays to DTensors."""

import contextlib
import dataclasses
import threading
from typing import Optional

import jax
from jax.experimental import pjit
import jaxtyping
import numpy as np
import tensorflow as tf
from tensorflow.experimental import dtensor


# Type alias used in this module to annotate tensors produced by
# `dtensor.pack`; at the Python level it is simply `tf.Tensor`.
DTensor = tf.Tensor


# Module-level flag recording whether `initialize_dtensor()` has completed.
# Set to True by `initialize_dtensor()` and back to False by
# `shutdown_dtensor()`; read through `dtensor_initialized()`.
_DTENSOR_INITIALIZED = False


def initialize_dtensor(reset_context: bool = False):
  """Initialize a DTensor system for Orbax Export.

  This function aligns the TensorFlow DTensor environment with the existing
  JAX process and device topology. It configures the DTensor accelerator
  system to use CPUs matching the local JAX device count, ensuring parity
  between JAX and DTensor distributed processing.

  Args:
    reset_context: Reset the tensorflow context along DTensor initialization.
      Behaviors of existing TensorFlow objects (e.g. Tensors) are undefined.
      Set this to True as an escape hatch, if there is no clear way to
      refactor your code to call initialize_dtensor() before calling
      TensorFlow APIs that initialize the context. See also
      :py:func:`dtensor.initialize_accelerator_system`.

  Raises:
    RuntimeError: if the number of DTensor clients is not the same as that of
      JAX processes.
    AssertionError: If the DTensor virtual CPU count does not match the JAX
      device count.
  """
  global _DTENSOR_INITIALIZED

  n_jax_devices = jax.device_count()
  n_jax_local_devices = jax.local_device_count()
  n_jax_processes = jax.process_count()

  # One logical CPU device per local JAX device, so both runtimes see the
  # same number of devices on this host.
  dtensor.initialize_accelerator_system(
      device_type='CPU',
      num_logical_cpu_devices=n_jax_local_devices,
      experimental_reset_context=reset_context,
  )

  n_dtensor_clients = dtensor.num_clients()
  if n_dtensor_clients != n_jax_processes:
    raise RuntimeError(
        f'The number of DTensor clients ({n_dtensor_clients}) is not equal'
        f' to the number of JAX processes ({n_jax_processes}). Did you forget'
        ' to set ``DTENSOR_JOBS`` or other DTensor env variables for all the'
        ' JAX processes?'
    )

  # Raised explicitly (rather than via an ``assert`` statement) so that the
  # documented check is not stripped when Python runs with ``-O``.
  dtensor_cpu_counts = (
      dtensor.num_local_devices('CPU'),
      dtensor.num_global_devices('CPU'),
  )
  if dtensor_cpu_counts != (n_jax_local_devices, n_jax_devices):
    raise AssertionError(
        'DTensor virtual CPU count does not match JAX device count, this is'
        ' impossible.'
    )

  _DTENSOR_INITIALIZED = True
def dtensor_initialized() -> bool:
  """Reports whether the DTensor system has been initialized by this module.

  The answer comes from a module-level flag that `initialize_dtensor()` sets
  once its JAX/DTensor consistency checks have passed and that
  `shutdown_dtensor()` clears again.

  Example:
    Initialize the DTensor environment only when it has not been set up yet::

      from orbax.export.dtensor_utils import dtensor_initialized, initialize_dtensor

      if not dtensor_initialized():
        initialize_dtensor()

  Returns:
    True if the DTensor system has been successfully initialized, False
    otherwise.
  """
  is_initialized = _DTENSOR_INITIALIZED
  return is_initialized
def shutdown_dtensor() -> None:
  """Tears down the DTensor system and clears the initialization flag.

  Must only be called after the DTensor system has been successfully
  initialized through `initialize_dtensor()`.

  Example:
    Initialize, use, and then shut down the DTensor environment::

      from orbax.export.dtensor_utils import initialize_dtensor, shutdown_dtensor

      initialize_dtensor()
      # Perform distributed export operations
      shutdown_dtensor()

  Raises:
    RuntimeError: If the DTensor system has not been initialized prior to
      calling this function.
  """
  global _DTENSOR_INITIALIZED

  if dtensor_initialized():
    dtensor.shutdown_accelerator_system()
    # Only reached when the shutdown call returned without raising.
    _DTENSOR_INITIALIZED = False
    return

  raise RuntimeError('DTensor is not initialized.')
def jax_mesh_to_dtensor_mesh(mesh: jax.sharding.Mesh) -> dtensor.Mesh:
  """Builds a DTensor mesh with the same shape and axis names as a JAX mesh.

  Global device ids are laid out as ``0..N-1`` in the shape of
  ``mesh.devices``. The ids local to this process are found by sharding that
  id array over every mesh axis and reading back the addressable shards.

  Example:
    Convert a JAX mesh to a DTensor mesh for distributed export::

      from jax.sharding import Mesh
      from orbax.export.dtensor_utils import jax_mesh_to_dtensor_mesh, initialize_dtensor

      # DTensor must be initialized before creating a DTensor mesh
      initialize_dtensor(reset_context=True)

      # Create a JAX mesh from the available JAX devices
      jax_mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

      # Convert to DTensor
      dt_mesh = jax_mesh_to_dtensor_mesh(jax_mesh)

  Args:
    mesh: A :py:class:`jax.sharding.Mesh` representing the global JAX mesh.

  Returns:
    A DTensor host mesh of the same shape and axis names as those of the JAX
    mesh.
  """
  device_grid_shape = mesh.devices.shape
  n_devices = np.prod(device_grid_shape)
  global_device_ids = np.arange(0, n_devices).reshape(device_grid_shape)

  with mesh:
    # Shard the global device ids so that each process gets the local device
    # ids for the corresponding DTensor mesh.
    shard_over_all_axes = pjit.pjit(
        lambda ids: ids,
        out_shardings=jax.sharding.PartitionSpec(*mesh.axis_names),
    )
    sharded_device_ids = shard_over_all_axes(global_device_ids)

  local_device_ids = []
  for shard in sharded_device_ids.addressable_shards:
    local_device_ids.append(int(shard.data.item()))

  axis_names = list(mesh.shape)
  return dtensor.Mesh(
      axis_names,
      global_device_ids=global_device_ids,
      local_device_ids=local_device_ids,
      local_devices=dtensor.local_devices(device_type='CPU'),
  )
def _reshard_jax_array(
    array: jax.Array, mesh: jax.sharding.Mesh, pspec: jax.sharding.PartitionSpec
) -> jax.Array:
  """Reshards a jax.Array to ``pspec`` via an identity pjit under ``mesh``."""
  with mesh:
    return pjit.pjit(lambda x: x, out_shardings=pspec)(array)


# TODO(b/261191533): jax.Array contains OpSharding info. Maybe we can get rid
# of jax.sharding.PartitionSpec here.
def jax_array_to_dtensor(
    arr: jax.Array,
    pspec: jax.sharding.PartitionSpec,
    dmesh: dtensor.Mesh,
    jax_mesh: Optional[jax.sharding.Mesh] = None,
    allow_multi_axis_sharding_consolidation: bool = False,
) -> DTensor:
  """Converts a jax.Array to a dtensor.

  Args:
    arr: A jax.Array.
    pspec: The partition spec of the input ``arr``. ``None`` means fully
      unsharded.
    dmesh: The DTensor mesh where the output dtensor is created.
    jax_mesh: The jax mesh for the jax array and partition spec. Required when
      a product-sharded dimension actually has to be consolidated, because the
      array is then resharded under this mesh.
    allow_multi_axis_sharding_consolidation: Whether reducing sharding a
      dimension across multiple axis names to one is allowed or not.

  Returns:
    A DTensor sharded in the same way as the input ``arr`` when the partition
    spec does not have product-sharding (i.e., sharding a dimension across
    multiple axis names). Or a DTensor with consolidated/reduced sharding when
    the partition spec has product-sharding and
    `allow_multi_axis_sharding_consolidation` is true (e.g. input array's
    sharding is P(None, ('a','b')) and output dtensor sharding gets reduced to
    P(None, 'a')).

  Raises:
    TypeError: If an element of ``pspec`` is neither ``None``, a mesh axis
      name, nor a tuple of mesh axis names.
    ValueError: When `allow_multi_axis_sharding_consolidation` is false and
      if a dimension of ``arr`` is product-sharded, i.e., sharded across more
      than one axes of the mesh. For example, if a mesh has two axes `'x'` and
      `'y'`, `PartitionSpec((x, y))` is considered product-sharded if the mesh
      size of both axes are greater than 1. If the mesh size of the `'x'` or
      `'y'` is 1, the spec is not considered product-sharded because it is
      effectively sharded on one axis only. Also raised when a sharded
      dimension is not a multiple of the mesh axis size, or when consolidation
      is needed but ``jax_mesh`` is ``None``.
  """
  arr_reshard_needed = False
  if pspec is None:
    dspec = [dtensor.UNSHARDED] * len(arr.shape)
  else:
    dspec = []
    for i, mesh_axis_name in enumerate(pspec):
      if not mesh_axis_name:
        # A falsy entry (``None`` or an empty tuple) leaves this dimension
        # unsharded.
        dspec.append(dtensor.UNSHARDED)
        continue

      if not isinstance(mesh_axis_name, str):
        if not isinstance(mesh_axis_name, tuple):
          raise TypeError(
              'An element in a PartitionSpec must be a ``None``, a mesh'
              f' axis or a tuple of mesh axes. Got {mesh_axis_name!r}.'
          )
        if len(mesh_axis_name) > 1:
          dim_sizes = tuple(dmesh.dim_size(name) for name in mesh_axis_name)
          if dim_sizes.count(1) < len(mesh_axis_name) - 1:
            # More than one of the listed axes has size > 1: this dimension
            # is genuinely product-sharded.
            if not allow_multi_axis_sharding_consolidation:
              raise ValueError(
                  f'Dimension {i} of the input array (shape={arr.shape}) is'
                  f' sharded across more than one axis ({mesh_axis_name},'
                  f' sizes = {dim_sizes}) of the mesh, but jax.Array to'
                  ' DTensor transform does not support partitioning of an'
                  ' array dimension across multiple mesh axes, unless there'
                  ' is at most one axis with size > 1.'
              )
            # Reduce/consolidate partition across multiple axis names
            # into one of those axis name. The selected axis name will be
            # the one with the highest dim size. E.g. P(None, ('a', 'b'))
            # will be reduced to P(None, ('a')) for {'a':4, 'b':2}
            # or P(None, ('b')) for {'a':2, 'b':4}.
            max_dim_size_idx = dim_sizes.index(max(dim_sizes))
            mesh_axis_name = (mesh_axis_name[max_dim_size_idx],)
            arr_reshard_needed = True
          else:
            # At most one axis has size > 1; keep that one, or the first axis
            # when every listed axis has size 1.
            mesh_axis_name = tuple(
                name
                for name, size in zip(mesh_axis_name, dim_sizes)
                if size != 1
            ) or (mesh_axis_name[0],)
        assert len(mesh_axis_name) == 1, mesh_axis_name
        mesh_axis_name = mesh_axis_name[0]

      mesh_dim_size = dmesh.dim_size(mesh_axis_name)
      if arr.shape[i] % mesh_dim_size != 0:
        raise ValueError(
            f'The size of the dim {i} (={arr.shape[i]}) of the input array'
            f' (shape={arr.shape}) must be a multiple of the size of'
            f' mesh axis "{mesh_axis_name}" (={mesh_dim_size}).'
        )
      dspec.append(mesh_axis_name)

    # Trailing dimensions not mentioned in the partition spec are unsharded.
    dspec.extend([dtensor.UNSHARDED] * (len(arr.shape) - len(pspec)))

  layout = dtensor.Layout(dspec, dmesh)
  if not arr_reshard_needed:
    local_data = [s.data for s in arr.addressable_shards]
  else:
    if jax_mesh is None:
      raise ValueError(
          '`jax_mesh` must be provided when a dimension sharded across'
          ' multiple mesh axes has to be consolidated, because the array is'
          ' resharded under that mesh.'
      )
    # Reshard the array so that its local shards match the consolidated
    # layout before packing them.
    resharded_arr = _reshard_jax_array(
        arr,
        jax_mesh,
        jax.sharding.PartitionSpec(*[
            axis_name if axis_name != dtensor.UNSHARDED else None
            for axis_name in dspec
        ]),
    )
    local_data = [s.data for s in resharded_arr.addressable_shards]
    del resharded_arr
  return dtensor.pack(local_data, layout)


class _ThreadLocalStack(threading.local):
  """Thread-local object holding a list (``stack``) used as a stack."""

  def __init__(self):
    super().__init__()
    self.stack = []


# Stack of active `Mesh` pairs, pushed and popped by
# `maybe_enable_dtensor_export_on`.
_MESH_STACK = _ThreadLocalStack()


@dataclasses.dataclass(frozen=True)
class Mesh:
  """Immutable pair of a JAX mesh and its DTensor counterpart.

  Attributes:
    jax_mesh: The JAX mesh.
    dtensor_mesh: The DTensor mesh paired with ``jax_mesh``.
  """

  jax_mesh: jax.sharding.Mesh
  dtensor_mesh: dtensor.Mesh
@contextlib.contextmanager
def maybe_enable_dtensor_export_on(mesh: Optional[jax.sharding.Mesh]):
  """Enters a DTensor export context derived from a JAX mesh, when possible.

  While the context is active, the given JAX mesh and the DTensor mesh built
  from it sit on top of the thread-local mesh stack, and are popped again on
  exit (including when the body raises). When the DTensor system is not
  initialized, or `mesh` is None, the context manager does nothing.

  Example:
    Wrap export code in the context for a standard JAX mesh::

      from jax.sharding import Mesh
      from orbax.export.dtensor_utils import initialize_dtensor, maybe_enable_dtensor_export_on

      initialize_dtensor()
      jax_mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

      with maybe_enable_dtensor_export_on(jax_mesh):
        # Perform DTensor-aware export operations here
        pass

  Args:
    mesh: a JAX pjit Mesh.

  Yields:
    None.
  """
  dtensor_export_enabled = dtensor_initialized() and mesh is not None
  if dtensor_export_enabled:
    # Build the pair first so nothing is pushed if the conversion fails.
    mesh_pair = Mesh(
        jax_mesh=mesh, dtensor_mesh=jax_mesh_to_dtensor_mesh(mesh)
    )
    _MESH_STACK.stack.append(mesh_pair)
    try:
      yield
    finally:
      _MESH_STACK.stack.pop()
  else:
    yield
def get_current_dtensor_mesh() -> Optional[dtensor.Mesh]:
  """Returns the DTensor mesh of the innermost active export context.

  The mesh is looked up through `get_current_mesh()`, i.e. it is the one
  pushed by the innermost active `maybe_enable_dtensor_export_on` context.

  Example:
    Retrieve the active DTensor mesh inside an export context block::

      from orbax.export.dtensor_utils import maybe_enable_dtensor_export_on
      from orbax.export.dtensor_utils import get_current_dtensor_mesh

      # Assuming jax_mesh is a valid jax.sharding.Mesh
      with maybe_enable_dtensor_export_on(jax_mesh):
        active_dt_mesh = get_current_dtensor_mesh()
        # active_dt_mesh is now the corresponding DTensor mesh

  Returns:
    The active DTensor mesh if inside a valid DTensor context, otherwise
    `None`.
  """
  current = get_current_mesh()
  if not current:
    return None
  return current.dtensor_mesh
def get_current_mesh() -> Optional[Mesh]:
  """Returns the Jax and DTensor mesh in the current context.

  Returns:
    The `Mesh` pair on top of the thread-local mesh stack, or `None` when the
    stack is empty.
  """
  return _MESH_STACK.stack[-1] if _MESH_STACK.stack else None


def get_pspec_from_jax_arrays(
    nested_jax_arrays: jaxtyping.PyTree,
) -> jaxtyping.PyTree[jax.sharding.PartitionSpec]:
  """Get the partition spec of a nested jax.Array or jax.ShapeDtypeStruct.

  Leaves without a ``sharding`` attribute, or whose ``sharding`` is ``None``,
  map to an empty ``PartitionSpec``.

  Args:
    nested_jax_arrays: a nested structure of jax.Array or
      jax.ShapeDtypeStruct.

  Returns:
    A nested structure of jax.sharding.PartitionSpec.

  Raises:
    AssertionError: if a leaf has a sharding that is not a `NamedSharding`, or
      if the input nested structure contains jax.Array with different meshes.
  """
  # Mesh of the first NamedSharding encountered; every later one must match.
  expected_mesh = None

  def _get_partition_spec(jax_arr):
    nonlocal expected_mesh
    sharding = getattr(jax_arr, 'sharding', None)
    if sharding is None:
      return jax.sharding.PartitionSpec()
    if not isinstance(sharding, jax.sharding.NamedSharding):
      raise AssertionError(
          f'Unsupported sharding type: {type(sharding)}, only support'
          ' NamedSharding'
      )
    # Compare against ``None`` explicitly so the "not seen yet" state never
    # depends on the truthiness of a mesh object.
    if expected_mesh is None:
      expected_mesh = sharding.mesh
    elif expected_mesh != sharding.mesh:
      raise AssertionError(
          'All those NamedShardings must have the same mesh.'
          f' {expected_mesh} != {sharding.mesh}'
      )
    return sharding.spec

  return jax.tree_util.tree_map(_get_partition_spec, nested_jax_arrays)