# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for transforming distributed JAX arrays to DTensors."""
import contextlib
import dataclasses
import threading
from typing import Optional
import jax
from jax.experimental import pjit
import jaxtyping
import numpy as np
import tensorflow as tf
from tensorflow.experimental import dtensor
# Readability alias: functions in this module annotate DTensor values with the
# ordinary `tf.Tensor` type.
DTensor = tf.Tensor
# Module-level flag: set to True at the end of `initialize_dtensor()` and reset
# to False by `shutdown_dtensor()`.
_DTENSOR_INITIALIZED = False
[docs]
def initialize_dtensor(reset_context: bool = False):
  """Initialize a DTensor system for Orbax Export.

  This function aligns the TensorFlow DTensor environment with the existing JAX
  process and device topology. It configures the DTensor accelerator system to
  use logical CPU devices, one per local JAX device, and then verifies that the
  DTensor client count and CPU device counts match what JAX reports.

  Args:
    reset_context: Reset the tensorflow context along DTensor initialization.
      Behaviors of existing TensorFlow objects (e.g. Tensors) are undefined. Set
      this to True as an escape hatch, if there is no clear way to refactor your
      code to call initialize_dtensor() before calling TensorFlow APIs that
      initialize the context. See also
      :py:func:`dtensor.initialize_accelerator_system`.

  Raises:
    RuntimeError: if the number of DTensor clients is not the same as that of
      JAX processes.
    AssertionError: If the DTensor virtual CPU count does not match the JAX
      device count.
  """
  global _DTENSOR_INITIALIZED

  n_jax_devices = jax.device_count()
  n_jax_local_devices = jax.local_device_count()
  n_jax_processes = jax.process_count()

  dtensor.initialize_accelerator_system(
      device_type='CPU',
      num_logical_cpu_devices=n_jax_local_devices,
      experimental_reset_context=reset_context,
  )

  n_dtensor_clients = dtensor.num_clients()
  if n_dtensor_clients != n_jax_processes:
    raise RuntimeError(
        f'The number of DTensor clients ({n_dtensor_clients}) is not equal'
        f' to the number of JAX processes ({n_jax_processes}). Did you forget'
        ' to set ``DTENSOR_JOBS`` or other DTensor env variables for all the'
        ' JAX processes?'
    )

  # Raised explicitly (rather than via `assert`) so the documented check is
  # still performed when Python runs with assertions disabled (-O).
  dtensor_cpu_counts = (
      dtensor.num_local_devices('CPU'),
      dtensor.num_global_devices('CPU'),
  )
  if dtensor_cpu_counts != (n_jax_local_devices, n_jax_devices):
    raise AssertionError(
        'DTensor virtual CPU count does not match JAX device count, this is'
        ' impossible.'
    )

  # Only mark the system as initialized once every check above has passed.
  _DTENSOR_INITIALIZED = True
[docs]
def dtensor_initialized() -> bool:
  """Reports whether `initialize_dtensor()` has completed successfully.

  The answer comes from a module-level flag that `initialize_dtensor()` sets
  after its consistency checks pass and that `shutdown_dtensor()` clears.

  Example:
    Initialize the DTensor environment only when it is not set up yet::

      from orbax.export.dtensor_utils import (
          dtensor_initialized,
          initialize_dtensor,
      )

      if not dtensor_initialized():
        initialize_dtensor()

  Returns:
    True if the DTensor system is currently marked as initialized, False
    otherwise.
  """
  return _DTENSOR_INITIALIZED
[docs]
def shutdown_dtensor() -> None:
  """Tears down the DTensor system set up by `initialize_dtensor()`.

  The DTensor accelerator system is shut down and the module-level
  initialization flag is cleared, so `dtensor_initialized()` returns False
  afterwards.

  Example:
    Initialize, use, and then shut down the DTensor environment::

      from orbax.export.dtensor_utils import (
          initialize_dtensor,
          shutdown_dtensor,
      )

      initialize_dtensor()
      # Perform distributed export operations
      shutdown_dtensor()

  Raises:
    RuntimeError: If the DTensor system has not been initialized prior to
      calling this function.
  """
  global _DTENSOR_INITIALIZED

  currently_initialized = dtensor_initialized()
  if not currently_initialized:
    raise RuntimeError('DTensor is not initialized.')

  dtensor.shutdown_accelerator_system()
  _DTENSOR_INITIALIZED = False
[docs]
def jax_mesh_to_dtensor_mesh(mesh: jax.sharding.Mesh) -> dtensor.Mesh:
  """Builds a DTensor mesh with the same shape and axis names as a JAX mesh.

  Device ids ``0..N-1`` are laid out in the shape of the JAX mesh and then
  sharded over every mesh axis, so that the shards addressable from this
  process identify the local device ids of the resulting DTensor mesh.

  Example:
    Convert a JAX mesh to a DTensor mesh for distributed export::

      from jax.sharding import Mesh
      from orbax.export.dtensor_utils import (
          initialize_dtensor,
          jax_mesh_to_dtensor_mesh,
      )

      # DTensor must be initialized before creating a DTensor mesh
      initialize_dtensor(reset_context=True)

      # Create a JAX mesh using available local devices
      jax_mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

      # Convert to DTensor
      dt_mesh = jax_mesh_to_dtensor_mesh(jax_mesh)

  Args:
    mesh: A :py:class:`jax.sharding.Mesh` representing the global JAX mesh.

  Returns:
    A DTensor mesh over the local CPU devices, with the same shape and axis
    names as those of the JAX mesh.
  """
  device_grid_shape = mesh.devices.shape
  global_device_ids = np.arange(np.prod(device_grid_shape)).reshape(
      device_grid_shape
  )

  # Partition the id grid along every mesh axis: each shard then holds exactly
  # one id, and the addressable shards are the ids local to this process.
  fully_sharded_spec = jax.sharding.PartitionSpec(*mesh.axis_names)
  with mesh:
    shard_ids = pjit.pjit(lambda ids: ids, out_shardings=fully_sharded_spec)
    sharded_device_ids = shard_ids(global_device_ids)

  local_device_ids = []
  for shard in sharded_device_ids.addressable_shards:
    local_device_ids.append(int(shard.data.item()))

  return dtensor.Mesh(
      list(mesh.shape.keys()),
      global_device_ids=global_device_ids,
      local_device_ids=local_device_ids,
      local_devices=dtensor.local_devices(device_type='CPU'),
  )
def _reshard_jax_array(
    array: jax.Array, mesh: jax.sharding.Mesh, pspec: jax.sharding.PartitionSpec
) -> jax.Array:
  """Returns `array` laid out according to `pspec` on `mesh`.

  Args:
    array: The array to reshard.
    mesh: The JAX mesh that is active while the resharding runs.
    pspec: The partition spec requested for the output.

  Returns:
    The output of an identity `pjit` whose `out_shardings` is `pspec`.
  """
  with mesh:
    # An identity function whose only effect is the requested output sharding.
    apply_sharding = pjit.pjit(lambda value: value, out_shardings=pspec)
    resharded = apply_sharding(array)
  return resharded
# TODO(b/261191533): jax.Array contains OpSharding info. Maybe we can get rid
# of jax.sharding.PartitionSpec here.
def jax_array_to_dtensor(
    arr: jax.Array,
    pspec: jax.sharding.PartitionSpec,
    dmesh: dtensor.Mesh,
    jax_mesh: Optional[jax.sharding.Mesh] = None,
    allow_multi_axis_sharding_consolidation: bool = False,
) -> DTensor:
  """Converts a jax.Array to a dtensor.

  Args:
    arr: A jax.Array.
    pspec: The partition spec of the input ``arr``. ``None`` means every
      dimension is unsharded.
    dmesh: The DTensor mesh where the output dtensor is created.
    jax_mesh: The jax mesh for the jax array and partition spec. Required when
      a multi-axis sharding has to be consolidated, because the array is then
      resharded on this mesh.
    allow_multi_axis_sharding_consolidation: Whether reducing sharding a
      dimension across multiple axis names to one is allowed or not.

  Returns:
    A DTensor sharded in the same way as the input ``arr`` when the partition
    spec does not have product-sharding (i.e., sharding a dimension across
    multiple axis names). Or a DTensor with consolidated/reduced sharding
    when the partition spec has product-sharding and
    `allow_multi_axis_sharding_consolidation` is true (e.g. input array's
    sharding is P(None, ('a', 'b')) and output dtensor sharding gets reduced
    to P(None, 'a')).

  Raises:
    TypeError: If an element of ``pspec`` is neither ``None``, a mesh axis name
      nor a tuple of mesh axis names.
    ValueError: When `allow_multi_axis_sharding_consolidation` is false and if a
      dimension of ``arr`` is product-sharded, i.e., sharded across more than
      one axes of the mesh. For example, if a mesh has two axes `'x'` and `'y'`,
      `PartitionSpec((x, y))` is considered product-sharded if the mesh size of
      both axes are greater than 1. If the mesh size of the `'x'` or `'y'` is 1,
      the spec is not considered product-sharded because it is effectively
      sharded on one axis only. Also raised when the size of a sharded
      dimension is not a multiple of the size of its mesh axis, or when
      consolidation requires a reshard but ``jax_mesh`` is ``None``.
  """
  arr_reshard_needed = False
  if pspec is None:
    dspec = [dtensor.UNSHARDED] * len(arr.shape)
  else:
    dspec = []
    for i, mesh_axis_name in enumerate(pspec):
      if not mesh_axis_name:
        dspec.append(dtensor.UNSHARDED)
        continue

      if not isinstance(mesh_axis_name, str):
        if not isinstance(mesh_axis_name, tuple):
          raise TypeError(
              'An element in a PartitionSpec must be a ``None``, a mesh'
              f' axis or a tuple of mesh axes. Got {mesh_axis_name!r}.'
          )
        if len(mesh_axis_name) > 1:
          dim_sizes = tuple(dmesh.dim_size(name) for name in mesh_axis_name)
          if dim_sizes.count(1) < len(mesh_axis_name) - 1:
            # More than one of the listed axes actually splits the dimension.
            if not allow_multi_axis_sharding_consolidation:
              raise ValueError(
                  f'Dimension {i} of the input array (shape={arr.shape}) is'
                  f' sharded across more than one axis ({mesh_axis_name},'
                  f' sizes = {dim_sizes}) of the mesh, but jax.Array to'
                  ' DTensor transform does not support partitioning of an'
                  ' array dimension across multiple mesh axes, unless there'
                  ' is at most one axis with size > 1.'
              )
            # Reduce/consolidate partition across multiple axis names
            # into one of those axis name. The selected axis name will be
            # the one with the highest dim size (the first one on ties).
            # E.g. P(None, ('a', 'b')) will be reduced to P(None, ('a'))
            # for {'a':4, 'b':2} or P(None, ('b')) for {'a':2, 'b':4}.
            largest_axis_idx = dim_sizes.index(max(dim_sizes))
            mesh_axis_name = (mesh_axis_name[largest_axis_idx],)
            arr_reshard_needed = True
          else:
            # At most one axis has size != 1: keep it, or fall back to the
            # first axis when every listed axis has size 1.
            mesh_axis_name = tuple(
                name
                for name, size in zip(mesh_axis_name, dim_sizes)
                if size != 1
            ) or (mesh_axis_name[0],)
        assert len(mesh_axis_name) == 1, mesh_axis_name
        mesh_axis_name = mesh_axis_name[0]

      mesh_dim_size = dmesh.dim_size(mesh_axis_name)
      if arr.shape[i] % mesh_dim_size != 0:
        raise ValueError(
            f'The size of the dim {i} (={arr.shape[i]}) of the input array'
            f' (shape={arr.shape}) must be a multiple of the size of'
            f' mesh axis "{mesh_axis_name}" (={mesh_dim_size}).'
        )
      dspec.append(mesh_axis_name)
    # Dimensions not mentioned by the partition spec are unsharded.
    dspec.extend([dtensor.UNSHARDED] * (len(arr.shape) - len(pspec)))

  layout = dtensor.Layout(dspec, dmesh)

  if not arr_reshard_needed:
    local_data = [s.data for s in arr.addressable_shards]
  else:
    if jax_mesh is None:
      raise ValueError(
          '`jax_mesh` must be provided when a multi-axis sharding is'
          ' consolidated, because the input array has to be resharded.'
      )
    # Reshard the array so that its local shards match the consolidated
    # layout before packing them into a DTensor.
    resharded_arr = _reshard_jax_array(
        arr,
        jax_mesh,
        jax.sharding.PartitionSpec(*[
            axis_name if axis_name != dtensor.UNSHARDED else None
            for axis_name in dspec
        ]),
    )
    local_data = [s.data for s in resharded_arr.addressable_shards]
    del resharded_arr
  return dtensor.pack(local_data, layout)
class _ThreadLocalStack(threading.local):
  """A `threading.local` whose `stack` attribute is a fresh list per thread."""

  def __init__(self):
    super().__init__()
    # `threading.local` re-runs `__init__` for each thread that touches the
    # object, so every thread starts with its own empty stack.
    self.stack = []
# Per-thread stack of active `Mesh` pairs. `maybe_enable_dtensor_export_on`
# pushes onto and pops from it; `get_current_mesh` reads its top entry.
_MESH_STACK = _ThreadLocalStack()
@dataclasses.dataclass(frozen=True)
class Mesh:
  """An immutable pair of a JAX mesh and a DTensor mesh."""

  # The JAX mesh supplied by the caller.
  jax_mesh: jax.sharding.Mesh
  # The DTensor mesh stored alongside `jax_mesh`.
  dtensor_mesh: dtensor.Mesh
[docs]
@contextlib.contextmanager
def maybe_enable_dtensor_export_on(mesh: Optional[jax.sharding.Mesh]):
  """Activates a DTensor export context for a JAX mesh, when possible.

  While the ``with`` block runs, the given JAX mesh and the DTensor mesh
  derived from it sit on top of the calling thread's mesh stack, where
  `get_current_mesh()` and `get_current_dtensor_mesh()` can find them. The
  entry is removed again when the block exits, including on error. If DTensor
  has not been initialized, or `mesh` is None, nothing is pushed and the
  context manager is a no-op.

  Example:
    Apply the context manager to a standard JAX mesh during export::

      from jax.sharding import Mesh
      from orbax.export.dtensor_utils import (
          initialize_dtensor,
          maybe_enable_dtensor_export_on,
      )

      initialize_dtensor()
      jax_mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

      with maybe_enable_dtensor_export_on(jax_mesh):
        # Perform DTensor-aware export operations here
        pass

  Args:
    mesh: a JAX pjit Mesh.

  Yields:
    None.
  """
  if not dtensor_initialized() or mesh is None:
    # Nothing to enable: behave as a plain pass-through context.
    yield
    return

  mesh_pair = Mesh(jax_mesh=mesh, dtensor_mesh=jax_mesh_to_dtensor_mesh(mesh))
  _MESH_STACK.stack.append(mesh_pair)
  try:
    yield
  finally:
    _MESH_STACK.stack.pop()
[docs]
def get_current_dtensor_mesh() -> Optional[dtensor.Mesh]:
  """Returns the DTensor mesh of the innermost active export context.

  The mesh is looked up through `get_current_mesh()`, i.e. it is the DTensor
  half of the entry most recently pushed by `maybe_enable_dtensor_export_on`
  on the calling thread.

  Example:
    Retrieve the active DTensor mesh inside an export context block::

      from orbax.export.dtensor_utils import maybe_enable_dtensor_export_on
      from orbax.export.dtensor_utils import get_current_dtensor_mesh

      # Assuming jax_mesh is a valid jax.sharding.Mesh
      with maybe_enable_dtensor_export_on(jax_mesh):
        active_dt_mesh = get_current_dtensor_mesh()
        # active_dt_mesh is now the corresponding DTensor mesh

  Returns:
    The active DTensor mesh if inside a valid DTensor context, otherwise `None`.
  """
  active_mesh = get_current_mesh()
  if not active_mesh:
    return None
  return active_mesh.dtensor_mesh
def get_current_mesh() -> Optional[Mesh]:
  """Returns the top JAX/DTensor mesh pair of this thread's stack, or None."""
  mesh_stack = _MESH_STACK.stack
  if not mesh_stack:
    return None
  return mesh_stack[-1]
def get_pspec_from_jax_arrays(
    nested_jax_arrays: jaxtyping.PyTree,
) -> jaxtyping.PyTree[jax.sharding.PartitionSpec]:
  """Extracts the partition spec of every leaf of a nested structure.

  Leaves that have no ``sharding`` attribute, or whose ``sharding`` is None,
  map to an empty `PartitionSpec`. All other leaves must carry a
  `NamedSharding`, and all of those must share the same mesh.

  Args:
    nested_jax_arrays: a nested structure of jax.Array or jax.ShapeDtypeStruct.

  Returns:
    A nested structure of jax.sharding.PartitionSpec.

  Raises:
    AssertionError: if a leaf has a sharding that is not a `NamedSharding`, or
      if the input nested structure contains jax.Array with different meshes.
  """
  # Mesh of the first NamedSharding seen; later leaves are compared to it.
  expected_mesh = None

  def _get_partition_spec(leaf):
    nonlocal expected_mesh

    sharding = getattr(leaf, 'sharding', None)
    if sharding is None:
      return jax.sharding.PartitionSpec()

    if not isinstance(sharding, jax.sharding.NamedSharding):
      raise AssertionError(
          f'Unsupported sharding type: {type(sharding)}, only support'
          ' NamedSharding'
      )

    if not expected_mesh:
      expected_mesh = sharding.mesh
    if expected_mesh != sharding.mesh:
      raise AssertionError(
          'All those NamedShardings must have the same mesh.'
          f' {expected_mesh} != {sharding.mesh}'
      )
    return sharding.spec

  return jax.tree_util.tree_map(_get_partition_spec, nested_jax_arrays)