DTensor utilities for multi-device/host export

Utilities for transforming distributed JAX arrays to DTensors.

orbax.export.dtensor_utils.initialize_dtensor(reset_context=False)

Initialize a DTensor system for Orbax Export.

Parameters:

reset_context (bool) – Whether to reset the TensorFlow context during DTensor initialization. After a reset, the behavior of existing TensorFlow objects (e.g. Tensors) is undefined. Set this to True only as an escape hatch, when there is no clear way to refactor your code so that initialize_dtensor() is called before any TensorFlow API that initializes the context. See also dtensor.initialize_accelerator_system.

Raises:

RuntimeError – if the number of DTensor clients is not the same as that of JAX processes.
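
The initialization, status check, and shutdown functions on this page can be combined into one scoped lifecycle. The following is a minimal sketch, not part of Orbax: `dtensor_session` is a hypothetical helper, and its `dtensor_utils` argument is an assumption added so that a stand-in object can be passed where Orbax or accelerators are unavailable.

```python
import contextlib


@contextlib.contextmanager
def dtensor_session(reset_context=False, dtensor_utils=None):
    """Hypothetical helper: ensure DTensor is initialized inside the block."""
    if dtensor_utils is None:
        # Lazy import, so the sketch loads even where Orbax is not installed.
        from orbax.export import dtensor_utils

    # Only initialize (and later shut down) if nobody else has done so.
    owns_runtime = not dtensor_utils.dtensor_initialized()
    if owns_runtime:
        # Documented to raise RuntimeError when the number of DTensor
        # clients is not the same as the number of JAX processes.
        dtensor_utils.initialize_dtensor(reset_context=reset_context)
    try:
        yield
    finally:
        # Runs even if the body raises.
        if owns_runtime:
            dtensor_utils.shutdown_dtensor()
```

Leaving `reset_context` at its default of False follows the guidance above: resetting the TensorFlow context is an escape hatch, so the helper should ideally be entered before any other TensorFlow call.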

orbax.export.dtensor_utils.dtensor_initialized()

Checks whether DTensor is initialized and matches the JAX device set.

Return type:

bool

orbax.export.dtensor_utils.shutdown_dtensor()

Return type:

None

orbax.export.dtensor_utils.jax_mesh_to_dtensor_mesh(mesh)

Creates a DTensor mesh from a JAX mesh.

Parameters:

mesh (Mesh) – a JAX global mesh for pjit.

Return type:

Mesh

Returns:

A DTensor host mesh of the same shape and axis names as those of the JAX mesh.
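
The guarantee that the result has the same shape and axis names can be checked after conversion. The sketch below is illustrative only: `mesh_dims_match` and `convert_and_check` are hypothetical helpers, and the check assumes that the JAX mesh exposes `axis_names` and a name-to-size `shape` mapping, and that the DTensor mesh exposes `dim_names` and `dim_size(name)`. The optional `dtensor_utils` argument is likewise an assumption, added to allow a stand-in module.

```python
def mesh_dims_match(jax_mesh, dtensor_mesh):
    """Returns True if both meshes have the same axis names and sizes, in order."""
    # JAX side: `axis_names` is a tuple, `shape` maps axis name -> size.
    jax_dims = [(name, jax_mesh.shape[name]) for name in jax_mesh.axis_names]
    # DTensor side (assumed API): `dim_names` and `dim_size(name)`.
    dtensor_dims = [
        (name, dtensor_mesh.dim_size(name)) for name in dtensor_mesh.dim_names
    ]
    return jax_dims == dtensor_dims


def convert_and_check(jax_mesh, dtensor_utils=None):
    """Converts `jax_mesh` and verifies the documented shape/name guarantee."""
    if dtensor_utils is None:
        # Lazy import, so the sketch loads even where Orbax is not installed.
        from orbax.export import dtensor_utils
    dtensor_mesh = dtensor_utils.jax_mesh_to_dtensor_mesh(jax_mesh)
    if not mesh_dims_match(jax_mesh, dtensor_mesh):
        raise ValueError("DTensor mesh does not mirror the JAX mesh")
    return dtensor_mesh
```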

orbax.export.dtensor_utils.maybe_enable_dtensor_export_on(mesh)

Creates a DTensor context from a JAX mesh for Orbax Export.

If DTensor is not initialized or mesh is None, this function is a no-op.

Parameters:

mesh (Optional[Mesh]) – a JAX pjit Mesh.

Yields:

None.
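
Because the function yields, it is used as a context manager around the export step. A minimal sketch, under stated assumptions: `run_with_dtensor_export` and its `export_fn` callback are hypothetical names, and the `dtensor_utils` argument exists only so a stand-in can be supplied.

```python
def run_with_dtensor_export(mesh, export_fn, dtensor_utils=None):
    """Calls `export_fn()` inside the DTensor export context for `mesh`.

    `mesh` may be None; per the docs above, the context is then a no-op,
    as it also is when DTensor has not been initialized.
    """
    if dtensor_utils is None:
        # Lazy import, so the sketch loads even where Orbax is not installed.
        from orbax.export import dtensor_utils
    with dtensor_utils.maybe_enable_dtensor_export_on(mesh):
        return export_fn()
```

Passing the export step as a callback keeps the same call site valid for both single-device and multi-device runs, since the context degrades to a no-op when there is no mesh.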

orbax.export.dtensor_utils.get_current_dtensor_mesh()

Returns the DTensor mesh in the current context.

Return type:

Optional[Mesh]
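
Since the return type is Optional[Mesh], callers need to handle a None result. The sketch below assumes that None means no DTensor mesh is currently active; the page does not state when None is returned. `require_current_dtensor_mesh` is a hypothetical helper, and its `dtensor_utils` argument is an assumption that allows a stand-in module.

```python
def require_current_dtensor_mesh(dtensor_utils=None):
    """Returns the current DTensor mesh, or raises if there is none."""
    if dtensor_utils is None:
        # Lazy import, so the sketch loads even where Orbax is not installed.
        from orbax.export import dtensor_utils
    mesh = dtensor_utils.get_current_dtensor_mesh()
    if mesh is None:
        # Assumption: None indicates that no DTensor mesh is active here.
        raise RuntimeError("No DTensor mesh in the current context")
    return mesh
```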