# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Orbax utils related to multihost_utils functionality."""
import threading
import time
from typing import Collection, Optional
from absl import logging
import jax
from jax.experimental import multihost_utils
from orbax.checkpoint.experimental.v1._src.synchronization import signaling_client
# Default timeout in seconds.
_DEFAULT_BARRIER_TIMEOUT = 300
[docs]
def is_pathways_backend() -> bool:
# Pathways is single-host.
return (
hasattr(jax.devices()[0].client, 'pathways')
or jax.devices()[0].client.runtime_type == 'pathways'
or jax.devices()[0].client.runtime_type == 'proxy/pathways'
)
def coordination_timeout() -> int:
"""Returns the coordination timeout in seconds."""
return _DEFAULT_BARRIER_TIMEOUT
def should_skip_process_sync(processes: Collection[int] | None = None) -> bool:
if processes and len(processes) == 1 and process_index() in processes:
return True
if jax.process_count() == 1:
return True
if is_pathways_backend():
return True
return False
def _unique_barrier_key(key: str) -> str:
"""Function that can be overridden for testing purposes."""
return key
def unique_barrier_key(
key: str,
*,
prefix: str | None = None,
suffix: str | None = None,
) -> str:
"""Constructs a key given an optional prefix and suffix."""
if prefix is not None:
key = f'{prefix}_{key}'
if suffix is not None:
key = f'{key}.{suffix}'
return key
[docs]
async def sync_global_processes(
key: str,
*,
operation_id: str,
timeout: int | None = None,
processes: Collection[int] | None = None,
record_event_name: str = '/jax/checkpoint/sync_global_devices_duration_sec',
):
"""Barrier to sync concurrent processes.
NOTE: The barrier name must be unique, i.e. no process should wait on the
same barrier name multiple times.
Args:
key: barrier name. Must be unique.
operation_id: The barrier name will be prefixed with the
operation id.
timeout: timeout in seconds.
processes: If None, expects to wait across all processes and devices.
Otherwise, creates a barrier only across devices associated with the given
processes.
record_event_name: The name of the event to record the duration of the
synchronization.
"""
key = f'[op={operation_id}] {key}'
if should_skip_process_sync(processes):
logging.vlog(
1,
'[process=%s][thread=%s] Skipping global process sync, barrier'
' name: %s',
process_index(),
threading.current_thread().name,
key,
)
return
sync_start_time = time.time()
logging.vlog(
1,
'[process=%s][thread=%s] Waiting at barrier: %s with processes: %s',
process_index(),
threading.current_thread().name,
key,
processes
)
timeout = timeout or coordination_timeout()
if timeout <= 0:
raise ValueError(f'Timeout must be positive, but got {timeout} seconds.')
client = signaling_client.get_signaling_client()
key = _unique_barrier_key(key)
if processes is not None:
if process_index() not in processes:
raise ValueError(
'Attempted to create a barrier across a subset of processes, but the'
f' current process: {process_index()} was not present in the provided'
f' list of processes: {processes}.'
)
processes = list(processes)
await client.wait_at_barrier(key, timeout_secs=timeout, process_ids=processes)
duration = time.time() - sync_start_time
logging.vlog(
1,
'[process=%s][thread=%s] Done waiting at barrier: %s, took %s s',
process_index(),
threading.current_thread().name,
key,
duration,
)
# This may end up just being too noisy given how many barriers there are, but
# it does represent how long different processes waited around waiting for
# other processes to reach a barrier.
jax.monitoring.record_event_duration_secs(
record_event_name,
duration,
)
[docs]
def is_primary_host(primary_host: int | None):
if primary_host is None or primary_host == process_index():
return True
return False
[docs]
def process_count() -> int:
return jax.process_count()
[docs]
def process_index() -> int:
if is_pathways_backend():
return jax.process_index()
# Note that jax.process_index() does not return the same thing as
# global_state.process_id. We rely on the latter to work with barriers over a
# subset of processes.
return jax._src.distributed.global_state.process_id # pylint: disable=protected-access
def broadcast_one_to_all(in_tree, is_source: Optional[bool] = None):
"""Broadcast data from a source host to all other hosts."""
if is_pathways_backend():
return in_tree
if is_source is None:
is_source = process_index() == 0
return multihost_utils.broadcast_one_to_all(in_tree, is_source=is_source)
def process_allgather(in_tree, tiled: bool | None = None):
"""All-gather data from all hosts."""
if is_pathways_backend():
return in_tree
return multihost_utils.process_allgather(in_tree, tiled=tiled)