DTensor utilities for multi-device/host export
Utilities for transforming distributed JAX arrays to DTensors.
- orbax.export.dtensor_utils.initialize_dtensor(reset_context=False)
Initialize a DTensor system for Orbax Export.
- Parameters:
reset_context (bool) – Reset the TensorFlow context along with DTensor initialization. After the reset, the behavior of existing TensorFlow objects (e.g. Tensors) is undefined. Set this to True only as an escape hatch, if there is no clear way to refactor your code to call initialize_dtensor() before calling TensorFlow APIs that initialize the context. See also dtensor.initialize_accelerator_system.
- Raises:
RuntimeError – if the number of DTensor clients differs from the number of JAX processes.
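The call order described above can be sketched as follows. This is a minimal sketch, not part of Orbax: `run_export`, `export_fn`, and `context_already_initialized` are hypothetical names, and the `dtensor_utils` module (normally `orbax.export.dtensor_utils`) is passed in as an argument so that the ordering can be exercised without TensorFlow installed.

```python
from typing import Any, Callable


def run_export(
    dtensor_utils: Any,
    export_fn: Callable[[], None],
    *,
    context_already_initialized: bool = False,
) -> None:
    """Hypothetical helper: initialize DTensor, then run TensorFlow-dependent code.

    `dtensor_utils` is expected to be `orbax.export.dtensor_utils`.
    """
    # Step 1: bring up DTensor before any TensorFlow API creates the context.
    # Passing True uses the reset_context escape hatch; TensorFlow objects that
    # already exist have undefined behavior afterwards.
    dtensor_utils.initialize_dtensor(reset_context=context_already_initialized)
    # Step 2: only now run the export code. If step 1 raised RuntimeError
    # (DTensor client count != JAX process count), this line is never reached.
    export_fn()
```

In a real script this would be called as `run_export(dtensor_utils, my_export_fn)` after `from orbax.export import dtensor_utils`, where `my_export_fn` stands for your own (hypothetical) function that builds and saves the TensorFlow model. Keeping every TensorFlow call behind `export_fn` makes it hard to touch the context before DTensor is initialized.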
- orbax.export.dtensor_utils.dtensor_initialized()
Checks whether DTensor is initialized and matches the JAX device set.
- Return type:
bool
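One way to use dtensor_initialized() is to guard initialization so that code reachable from several entry points initializes DTensor at most once. The sketch below assumes that pattern; `ensure_dtensor_initialized` is a hypothetical helper, the module is again passed in as `dtensor_utils`, and the sketch makes no attempt to recover from a DTensor system that was initialized against a different device set.

```python
from typing import Any


def ensure_dtensor_initialized(dtensor_utils: Any) -> bool:
    """Hypothetical helper: initialize DTensor only when needed.

    Returns True if this call performed the initialization.
    """
    # Per the docstring above, False means DTensor is either not initialized or
    # does not match the JAX device set; this sketch treats both cases alike.
    if dtensor_utils.dtensor_initialized():
        return False
    dtensor_utils.initialize_dtensor()
    return True
```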
- orbax.export.dtensor_utils.jax_mesh_to_dtensor_mesh(mesh)
Creates a DTensor mesh from a JAX mesh.
- Parameters:
mesh (Mesh) – a JAX global mesh for pjit.
- Return type:
Mesh
- Returns:
A DTensor host mesh with the same shape and axis names as the JAX mesh.
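Because the returned mesh is documented to keep the JAX mesh's shape and axis names, a caller can sanity-check the conversion. The helper below is a hypothetical sketch; it assumes the JAX mesh exposes `axis_names` and a name-to-size `shape` mapping (as `jax.sharding.Mesh` does) and that the DTensor mesh exposes `dim_names` and `dim_size(name)`.

```python
from typing import Any


def meshes_correspond(jax_mesh: Any, dtensor_mesh: Any) -> bool:
    """Hypothetical check: same axis names, in the same order, with equal sizes."""
    jax_names = tuple(jax_mesh.axis_names)
    # Compare names positionally, so reordered axes are reported as a mismatch.
    if tuple(dtensor_mesh.dim_names) != jax_names:
        return False
    # jax_mesh.shape maps axis name -> size; dim_size(name) is assumed to return
    # the size of the matching DTensor mesh dimension.
    return all(jax_mesh.shape[name] == dtensor_mesh.dim_size(name) for name in jax_names)
```

For example, after `dtensor_mesh = dtensor_utils.jax_mesh_to_dtensor_mesh(jax_mesh)`, the documented contract implies that `meshes_correspond(jax_mesh, dtensor_mesh)` is True.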