Multi-host Utilities

Defines exported symbols for package orbax.checkpoint.multihost.

orbax.checkpoint.multihost.broadcast_one_to_all(in_tree, is_source=None)

Broadcast data from a source host to all other hosts.
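broadcast_one_to_all is a collective operation: every host takes part, and, assuming it returns the broadcast tree on every host, all of them end up with the source host's data. The sketch below simulates that contract in a single Python process. Threads stand in for hosts, and a threading.Barrier stands in for cross-host synchronization. simulated_broadcast_one_to_all and its explicit boolean is_source are hypothetical. The sketch does not model the real function's is_source=None default or how the data is actually transported.

```python
import copy
import threading
from typing import Any, Dict, List

NUM_HOSTS = 4  # hypothetical number of simulated hosts (threads)
_barrier = threading.Barrier(NUM_HOSTS)
_mailbox: Dict[str, Any] = {}


def simulated_broadcast_one_to_all(in_tree: Any, is_source: bool) -> Any:
    """Hypothetical stand-in: every caller receives the source caller's tree."""
    if is_source:
        # Only the source publishes its value.
        _mailbox["tree"] = copy.deepcopy(in_tree)
    # No caller reads until every caller, including the source, has arrived.
    _barrier.wait()
    return copy.deepcopy(_mailbox["tree"])


results: List[Any] = [None] * NUM_HOSTS


def host(index: int) -> None:
    # Each host starts with different local data; host 0 acts as the source.
    local_tree = {"step": 100 + index, "seed": index}
    results[index] = simulated_broadcast_one_to_all(local_tree, is_source=(index == 0))


threads = [threading.Thread(target=host, args=(i,)) for i in range(NUM_HOSTS)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)  # every entry is {'step': 100, 'seed': 0}
```

This simulation is single-use. A reusable version would need a second barrier so that the source cannot overwrite the mailbox before slower hosts have read it.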

orbax.checkpoint.multihost.is_primary_host(primary_host)

orbax.checkpoint.multihost.reached_preemption(step)

Returns True if a preemption sync point has been reached.

Return type:

bool
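A natural place to call reached_preemption is once per step inside the training loop, saving a checkpoint and stopping when it returns True. The loop below is a sketch under that assumption. It also assumes that every process checks at every step so that all processes stop at the same one. The real function is injected as a parameter so the loop can be exercised without a multi-host job. train and save_checkpoint are hypothetical names.

```python
from typing import Callable, List


def train(
    num_steps: int,
    reached_preemption: Callable[[int], bool],
    save_checkpoint: Callable[[int], None],
) -> int:
    """Toy training loop; returns the last step that was executed.

    In a real job, reached_preemption would be
    orbax.checkpoint.multihost.reached_preemption (assumed here to be
    called by every process at every step).
    """
    step = -1
    for step in range(num_steps):
        # ... the actual training update for `step` would run here ...
        if reached_preemption(step):
            save_checkpoint(step)  # persist state before the job is torn down
            break
    return step


saved: List[int] = []
# Fake preemption signal that fires at step 3.
last_step = train(10, reached_preemption=lambda s: s == 3, save_checkpoint=saved.append)
print(last_step, saved)  # 3 [3]
```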

orbax.checkpoint.multihost.sync_global_processes(name, *, timeout=None, processes=None, barrier_sync_fn=None, record_event_name='/jax/checkpoint/sync_global_devices_duration_sec')

Barrier that synchronizes concurrent processes.

NOTE: The barrier name must be unique; no process should wait on the same barrier name more than once.

Parameters:
  • name (str) – Barrier name. Must be unique.

  • timeout (Optional[int]) – Timeout in seconds.

  • processes (Optional[Set[int]]) – If None, expects to wait across all processes and devices. Otherwise, creates a barrier only across devices associated with the given processes.

  • barrier_sync_fn (Optional[BarrierSyncFn]) – Implementation used for the synchronization. If not provided, a default implementation is used.

  • record_event_name (str) – Name of the event under which the duration of the synchronization is recorded.
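Because each barrier name may be used only once, one approach is to derive names from something that never repeats, such as the step number, and to fail fast on reuse. The helper below is a hypothetical sketch of that idea. sync_fn stands in for sync_global_processes and is replaced here by a recording fake so the example runs in one process. The guard only catches reuse within a single process; it cannot detect processes that pass different names to the same barrier.

```python
import threading
from typing import Callable, List, Optional, Set, Tuple


class UniqueBarrierGuard:
    """Hypothetical wrapper that refuses to reuse a barrier name.

    sync_fn is expected to behave like
    orbax.checkpoint.multihost.sync_global_processes(name, *, timeout=None).
    """

    def __init__(self, sync_fn: Callable[..., None]) -> None:
        self._sync_fn = sync_fn
        self._used: Set[str] = set()
        self._lock = threading.Lock()  # guards _used if called from several threads

    def sync(self, name: str, *, timeout: Optional[int] = None) -> None:
        with self._lock:
            if name in self._used:
                raise ValueError(f"barrier name {name!r} has already been used")
            self._used.add(name)
        self._sync_fn(name, timeout=timeout)


calls: List[Tuple[str, Optional[int]]] = []
guard = UniqueBarrierGuard(lambda name, timeout=None: calls.append((name, timeout)))

# Deriving the name from the step keeps each barrier unique.
for step in (0, 100, 200):
    guard.sync(f"checkpoint_saved_step_{step}", timeout=300)

try:
    guard.sync("checkpoint_saved_step_0")  # reused name: rejected before syncing
    reuse_rejected = False
except ValueError:
    reuse_rejected = True

print([name for name, _ in calls], reuse_rejected)
```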

orbax.checkpoint.multihost.process_index()

Customized logic for obtaining the JAX process index.

Return type:

int

class orbax.checkpoint.multihost.BarrierSyncFn(*args, **kwargs)

Protocol for a barrier synchronization callable.

__call__(*, key, timeout_ms)

Blocks on the barrier identified by key, waiting at most timeout_ms milliseconds.

Return type:

None

__init__(*args, **kwargs)

classmethod __subclasshook__(other)

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
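Since BarrierSyncFn is a Protocol, any callable that accepts the keyword-only arguments key and timeout_ms and returns None matches it structurally for static type checkers. The sketch below declares a local copy of the protocol and a hypothetical in-process implementation, ThreadBarrierSyncFn, that keeps one threading.Barrier per key. It is meant for exercising synchronization logic with threads, not as a replacement for a cross-host barrier. The BrokenBarrierError raised on timeout comes from threading.Barrier; how the real implementations report a timeout is not specified here.

```python
import threading
from typing import Dict, List, Protocol


class BarrierSyncFn(Protocol):
    """Local mirror of the documented protocol: keyword-only key and timeout_ms."""

    def __call__(self, *, key: str, timeout_ms: int) -> None:
        ...


class ThreadBarrierSyncFn:
    """Hypothetical in-process implementation: one threading.Barrier per key."""

    def __init__(self, parties: int) -> None:
        self._parties = parties
        self._barriers: Dict[str, threading.Barrier] = {}
        self._lock = threading.Lock()

    def __call__(self, *, key: str, timeout_ms: int) -> None:
        with self._lock:
            barrier = self._barriers.get(key)
            if barrier is None:
                barrier = threading.Barrier(self._parties)
                self._barriers[key] = barrier
        # threading.Barrier takes seconds and raises BrokenBarrierError on timeout.
        barrier.wait(timeout=timeout_ms / 1000)


sync: BarrierSyncFn = ThreadBarrierSyncFn(parties=3)
passed: List[int] = []


def worker(i: int) -> None:
    sync(key="checkpoint_step_1", timeout_ms=5_000)
    passed.append(i)  # only reached once all three workers hit the barrier


threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# A barrier that not every party reaches fails once timeout_ms elapses.
lonely = ThreadBarrierSyncFn(parties=2)
try:
    lonely(key="never_completed", timeout_ms=50)
    timed_out = False
except threading.BrokenBarrierError:
    timed_out = True
```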

orbax.checkpoint.multihost.get_barrier_sync_fn(*, processes=None)

Provides a barrier synchronization function for JAX processes.

Barriers with different sync keys are safe to use from independent background threads.

Parameters:

processes (Optional[Set[int]]) – If None, expects to wait across all processes and devices. Otherwise, creates a barrier only across devices associated with the given processes.

Return type:

BarrierSyncFn

Returns:

A no-op function if there is a single JAX process; otherwise, a barrier synchronization callable that accepts two keyword arguments: key (str), a unique barrier ID, and timeout_ms (int), the timeout to use when waiting on the barrier. The callable should be called from all JAX processes with the same key and blocks until either 1) all processes have reached the barrier or 2) the timeout is exceeded.