Multi-host Utilities
Defines exported symbols for package orbax.checkpoint.multihost.
- orbax.checkpoint.multihost.sync_global_processes(name, processes=None)
Barrier to synchronize concurrent processes: each caller blocks until every participating process has reached the barrier.
- Parameters:
  - name (str) – Name of the barrier.
  - processes (Optional[Set[int]]) – If None, waits across all processes and devices. Otherwise, creates a barrier only across the devices associated with the given processes.
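The barrier contract can be illustrated without a cluster. The sketch below is a single-machine simulation, not Orbax code: threads stand in for processes, and Python's threading.Barrier stands in for the cross-host barrier. The helper simulate_sync and its participants argument (playing the role of processes) are hypothetical names used only for illustration.

```python
import threading
import time
from typing import List, Optional, Set, Tuple


def simulate_sync(
    num_processes: int, participants: Optional[Set[int]] = None
) -> List[Tuple[str, int]]:
    """Hypothetical simulation of sync_global_processes semantics.

    Threads play the role of processes. Threads in `participants` wait on a
    shared threading.Barrier; with participants=None every thread joins,
    mirroring processes=None. Returns the order of 'before'/'after' events.
    """
    if participants is None:
        participants = set(range(num_processes))
    if not participants or not participants <= set(range(num_processes)):
        # An empty or out-of-range participant set would deadlock or fail.
        raise ValueError("participants must be a non-empty subset of processes")
    barrier = threading.Barrier(len(participants))
    events: List[Tuple[str, int]] = []
    lock = threading.Lock()

    def worker(pid: int) -> None:
        time.sleep(0.01 * pid)  # stagger arrivals so early threads must wait
        with lock:
            events.append(("before", pid))
        if pid in participants:
            barrier.wait()  # blocks until every participant has arrived
        with lock:
            events.append(("after", pid))

    threads = [threading.Thread(target=worker, args=(p,)) for p in range(num_processes)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return events


# Only simulated processes 0 and 2 join the barrier; 1 and 3 never block.
events = simulate_sync(4, participants={0, 2})
```

Within the participant set, no "after" event can appear before the last participant's "before" event, which is the guarantee a barrier provides; non-participants pass through freely, loosely mirroring a barrier created only for the given processes.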
- orbax.checkpoint.multihost.broadcast_one_to_all(in_tree, is_source=None)
Broadcast data from a source host to all other hosts.
- orbax.checkpoint.multihost.broadcast_one_to_some(in_tree, *, is_source=None, processes=None)
Broadcast data from a source host to some or all other hosts.
This function should be called only by participating processes, i.e. those listed in processes if it is specified, or every process if it is not.
Inspired by JAX multihost_utils.
- Parameters:
  - in_tree (Any) – pytree of arrays; each array must have the same shape across the hosts.
  - is_source (Optional[bool]) – Whether the current process is the source of the broadcast. If None, an arbitrary process within processes will be selected as the source for the broadcast.
  - processes (Optional[Set[int]]) – Set of participating processes. Assumed to be all processes if None.
- Return type: Any
- Returns: A pytree matching in_tree whose leaves all contain the data from the source host.
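The participation rules and source selection above can be modeled in a single process. This is a hypothetical sketch, not the Orbax implementation: the name simulate_broadcast_one_to_some is invented, it takes one source index instead of a per-process is_source flag, and when no source is given it picks the lowest participating index (the documentation only promises an arbitrary participant). Leaves are plain Python values rather than arrays, so the same-shape requirement is approximated by requiring identical tree structure.

```python
import copy
from typing import Any, Dict, Optional, Set


def _structure(tree: Any) -> Any:
    """Skeleton of a pytree of dicts/lists/tuples with every leaf set to None."""
    if isinstance(tree, dict):
        return {k: _structure(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_structure(v) for v in tree)
    return None


def simulate_broadcast_one_to_some(
    local_trees: Dict[int, Any],
    source: Optional[int] = None,
    processes: Optional[Set[int]] = None,
) -> Dict[int, Any]:
    """Hypothetical single-process model of broadcast_one_to_some.

    local_trees maps a simulated process index to that process's pytree.
    Returns the pytree each participating process would hold afterwards.
    """
    # processes=None means every (simulated) process participates.
    participants = set(local_trees) if processes is None else set(processes)
    missing = participants - set(local_trees)
    if missing:
        raise ValueError(f"no local data for processes {sorted(missing)}")
    if source is None:
        # The docs only promise an arbitrary participant; choosing the lowest
        # index is an assumption made here so the result is deterministic.
        source = min(participants)
    if source not in participants:
        raise ValueError("the source must be one of the participating processes")
    # Stand-in for the same-shape requirement: all structures must match.
    reference = _structure(local_trees[source])
    for pid in participants:
        if _structure(local_trees[pid]) != reference:
            raise ValueError(f"process {pid} has a mismatched pytree structure")
    # Every participant ends up holding its own copy of the source's data.
    return {pid: copy.deepcopy(local_trees[source]) for pid in participants}


trees = {
    0: {"step": 100, "params": [1.0, 2.0]},
    1: {"step": 0, "params": [0.0, 0.0]},
    2: {"step": 7, "params": [5.0, 5.0]},
}
# Only processes 1 and 2 participate; the source defaults to process 1 here.
result = simulate_broadcast_one_to_some(trees, processes={1, 2})
# result[2] == {"step": 0, "params": [0.0, 0.0]}
```

Process 0 is absent from the result, matching the rule that only participating processes take part in the call. Passing processes=None models the all-hosts case that broadcast_one_to_all covers.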