Multi-host Utilities

Defines exported symbols for package orbax.checkpoint.multihost.

orbax.checkpoint.multihost.sync_global_processes(name, processes=None)

Barrier that synchronizes concurrent processes: each caller blocks until every participating process has reached the barrier.

Parameters:
  • name (str) – barrier name; every participating process must pass the same name.

  • processes (Optional[Set[int]]) – If None, waits across all processes and devices. Otherwise, creates a barrier only across the devices associated with the given processes.
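To make the processes semantics concrete, here is a single-machine simulation in plain Python in which threads stand in for hosts. It is a sketch under stated assumptions, not Orbax code: simulated_sync, demo and NUM_PROCESSES are hypothetical names, and raising an error when a non-participant calls the barrier is this sketch's own choice. It only models the ordering guarantee a barrier provides, not the JAX distributed runtime that real hosts use.

```python
import threading

NUM_PROCESSES = 4  # hypothetical cluster size for this simulation

_registry_lock = threading.Lock()
_barriers = {}  # (name, participants) -> shared threading.Barrier


def simulated_sync(name, pid, processes=None):
    """Hypothetical named barrier; pid is the calling "process" index."""
    if processes is None:
        processes = set(range(NUM_PROCESSES))  # None means all processes
    if pid not in processes:
        # Sketch-only choice: non-participants are rejected outright.
        raise ValueError(f"process {pid} does not participate in {name!r}")
    key = (name, frozenset(processes))
    with _registry_lock:
        barrier = _barriers.get(key)
        if barrier is None:
            # Callers using the same name and participant set share one
            # barrier sized to the number of participants.
            barrier = _barriers[key] = threading.Barrier(len(processes))
    barrier.wait()  # block until every participant has arrived


def demo(processes=None):
    """Run one thread per participant and record the order of events."""
    participants = set(range(NUM_PROCESSES)) if processes is None else processes
    log, log_lock = [], threading.Lock()

    def worker(pid):
        with log_lock:
            log.append(("before", pid))
        simulated_sync("after_write", pid, processes)
        with log_lock:
            log.append(("after", pid))

    threads = [threading.Thread(target=worker, args=(p,)) for p in sorted(participants)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


log = demo(processes={0, 2})
print([phase for phase, _ in log])  # ['before', 'before', 'after', 'after']
```

Because no participant can pass the barrier until all of them have logged "before", every "before" entry precedes every "after" entry regardless of thread scheduling; processes 1 and 3 never take part in the {0, 2} barrier.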

orbax.checkpoint.multihost.broadcast_one_to_all(in_tree, is_source=None)

Broadcast data from a source host to all other hosts.
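In current JAX sources, jax.experimental.multihost_utils.broadcast_one_to_all performs a broadcast as a cross-process sum in which every non-source process contributes zeros; whether Orbax uses exactly the same mechanism is an assumption here. The pure-Python sketch below simulates that technique with one list entry per host. tree_map, add_leaves and simulated_broadcast_one_to_all are hypothetical helpers, pytrees are modeled as nested dicts whose leaves are lists of numbers, and the source is given as a host index rather than a per-process is_source flag.

```python
# Hypothetical simulation of a zero-and-sum broadcast. Pytrees are modeled as
# nested dicts whose leaves are lists of numbers standing in for arrays.


def tree_map(fn, *trees):
    """Apply fn leaf-wise across nested dicts that share one structure."""
    first = trees[0]
    if isinstance(first, dict):
        return {key: tree_map(fn, *(tree[key] for tree in trees)) for key in first}
    return fn(*trees)


def add_leaves(*leaves):
    """Element-wise sum; the leaf shapes must agree on every host."""
    if len({len(leaf) for leaf in leaves}) != 1:
        raise ValueError("leaf shapes differ across hosts")
    return [sum(values) for values in zip(*leaves)]


def simulated_broadcast_one_to_all(host_trees, source):
    """Return the tree each host holds after broadcasting from `source`."""
    # Step 1: every non-source host swaps its leaves for zeros of equal length.
    contributions = [
        tree if pid == source else tree_map(lambda leaf: [0] * len(leaf), tree)
        for pid, tree in enumerate(host_trees)
    ]
    # Step 2: an all-reduce sum; only the source adds non-zero values.
    summed = tree_map(add_leaves, *contributions)
    # Step 3: every host receives its own copy of the result.
    return [tree_map(list, summed) for _ in host_trees]


hosts = [
    {"step": [0], "params": {"w": [0.0, 0.0]}},
    {"step": [7], "params": {"w": [1.5, -2.0]}},  # host 1 acts as the source
    {"step": [3], "params": {"w": [9.0, 9.0]}},
]
result = simulated_broadcast_one_to_all(hosts, source=1)
print(result[0])  # {'step': [7], 'params': {'w': [1.5, -2.0]}}
```

The sum formulation also explains why leaves must have the same shape on every host: element-wise addition is only defined when every contribution lines up, which is why add_leaves rejects mismatched lengths.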

orbax.checkpoint.multihost.broadcast_one_to_some(in_tree, *, is_source=None, processes=None)

Broadcast data from a source host to some or all other hosts.

This function should be called only by participating processes, i.e. those listed in processes if it is specified, or all processes if it is not.

Inspired by JAX's jax.experimental.multihost_utils.

Parameters:
  • in_tree (Any) – pytree of arrays; each array must have the same shape on every host.

  • is_source (Optional[bool]) – Whether the current process is the source of the broadcast. If None, an arbitrary process within processes will be selected as the source for the broadcast.

  • processes (Optional[Set[int]]) – Set of participating processes. Assumed to be all processes if None.

Return type:

Any

Returns:

A pytree matching in_tree whose leaves all contain the data from the source host.
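The participation rules and the interplay of is_source and processes can be modeled in plain Python. In this hypothetical sketch, simulated_broadcast_one_to_some and _shapes are illustrative names, host data lives in a dict keyed by process index, and is_source is a dict of per-process flags rather than a single bool evaluated on each process. The docstring only promises "an arbitrary process" as the source when is_source is None; resolving that as the lowest participating index is this sketch's assumption, not documented Orbax behavior. It models who sends and who receives, not how the data travels.

```python
import copy


def _shapes(tree):
    """Nested dict of leaf lengths: a stand-in for array shapes."""
    if isinstance(tree, dict):
        return {key: _shapes(value) for key, value in tree.items()}
    return len(tree)


def simulated_broadcast_one_to_some(host_trees, is_source=None, processes=None):
    """Hypothetical model: returns {participant index: received tree}."""
    if processes is None:
        processes = set(host_trees)  # None means every process participates
    if is_source is None:
        source = min(processes)  # assumption: "arbitrary" resolved as lowest index
    else:
        sources = [p for p in sorted(processes) if is_source[p]]
        if len(sources) != 1:
            raise ValueError(f"expected exactly one source, got {sources}")
        source = sources[0]
    # Each array must have the same shape on every participating host.
    expected = _shapes(host_trees[source])
    for p in processes:
        if _shapes(host_trees[p]) != expected:
            raise ValueError(f"process {p} has mismatched leaf shapes")
    # Only participants receive the source's data; other hosts are untouched.
    return {p: copy.deepcopy(host_trees[source]) for p in sorted(processes)}


trees = {
    0: {"w": [1.0, 2.0]},
    1: {"w": [5.0, 6.0]},
    2: {"w": [0.0, 0.0]},
    3: {"w": [9.0, 9.0]},
}
out = simulated_broadcast_one_to_some(
    trees, is_source={1: True, 2: False}, processes={1, 2}
)
print(out)  # {1: {'w': [5.0, 6.0]}, 2: {'w': [5.0, 6.0]}}
```

Processes 0 and 3 are absent from the result because they are not in processes; in the real API they would simply not call the function.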