General Utilities

Utility functions for Orbax.

TODO(b/266449081) Increase unit test coverage.

Async wrappers

orbax.checkpoint.utils.async_makedirs(path, *args, parents=False, exist_ok=True, **kwargs)
orbax.checkpoint.utils.async_write_bytes(path, data)
orbax.checkpoint.utils.async_exists(path)
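
The signatures above suggest asynchronous counterparts of blocking filesystem calls. As a rough illustration of that pattern only, the sketch below wraps `pathlib` operations with `asyncio.to_thread`. The `sketch_*` names are hypothetical, and nothing here is taken from the Orbax implementation, which this page does not show.

```python
import asyncio
import tempfile
from pathlib import Path


async def sketch_async_makedirs(
    path: Path, parents: bool = False, exist_ok: bool = True
) -> None:
    # Run the blocking mkdir call in a worker thread.
    await asyncio.to_thread(path.mkdir, parents=parents, exist_ok=exist_ok)


async def sketch_async_write_bytes(path: Path, data: bytes) -> int:
    # Returns the number of bytes written, as Path.write_bytes does.
    return await asyncio.to_thread(path.write_bytes, data)


async def sketch_async_exists(path: Path) -> bool:
    return await asyncio.to_thread(path.exists)


async def demo(root: Path) -> bytes:
    target_dir = root / 'a' / 'b'
    await sketch_async_makedirs(target_dir, parents=True)
    target = target_dir / 'data.bin'
    await sketch_async_write_bytes(target, b'payload')
    assert await sketch_async_exists(target)
    return target.read_bytes()


with tempfile.TemporaryDirectory() as tmp:
    written = asyncio.run(demo(Path(tmp)))
```

Running the blocking call in a thread keeps the event loop free for other coroutines while the write is in flight.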

Tree utils

orbax.checkpoint.utils.is_empty_or_leaf(x)
Return type:

bool

orbax.checkpoint.utils.get_key_name(key)

Returns the name of a JAX Key.

Return type:

Union[int, str]

orbax.checkpoint.utils.to_flat_dict(tree, sep=None, keep_empty_nodes=False)

Converts a tree into a flattened dictionary.

The nested keys are flattened to a tuple.

Example:

tree = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
to_flat_dict(tree)
# Returns:
{
  ('foo',): 1,
  ('bar', 'a'): 2,
}

Parameters:
  • tree (Any) – A PyTree to be flattened.

  • sep (Optional[str]) – If provided, keys will be returned as sep-separated strings. Otherwise, keys are returned as tuples.

  • keep_empty_nodes (bool) – If True, empty nodes are not filtered out.

Return type:

Any

Returns:

A flattened dictionary.
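
The flattening behaviour described above can be sketched for plain nested dicts as follows. This is a minimal illustration, not Orbax's implementation: the helper name `flatten_nested_dict` is hypothetical, and it handles only `dict` nodes with string keys rather than arbitrary PyTrees.

```python
from typing import Any, Dict, Optional, Tuple


def flatten_nested_dict(
    tree: Dict[str, Any],
    sep: Optional[str] = None,
    keep_empty_nodes: bool = False,
) -> Dict[Any, Any]:
    """Hypothetical stand-in for to_flat_dict, restricted to nested dicts."""
    flat: Dict[Tuple[str, ...], Any] = {}

    def visit(node: Any, path: Tuple[str, ...]) -> None:
        if isinstance(node, dict) and node:
            # Non-empty dict: recurse into each child.
            for key, value in node.items():
                visit(value, path + (key,))
        elif isinstance(node, dict):
            # Empty dict: kept only when keep_empty_nodes is True.
            if keep_empty_nodes:
                flat[path] = node
        else:
            # Leaf value.
            flat[path] = node

    visit(tree, ())
    if sep is None:
        return dict(flat)
    # With a separator, tuple paths become sep-joined strings.
    return {sep.join(path): value for path, value in flat.items()}


tree = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
flat_tuples = flatten_nested_dict(tree)
flat_strings = flatten_nested_dict(tree, sep='/')
flat_with_empty = flatten_nested_dict(tree, keep_empty_nodes=True)
```

With the default arguments the empty node `'b'` is dropped; passing `keep_empty_nodes=True` retains it under the key `('bar', 'b')`.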

orbax.checkpoint.utils.serialize_tree(tree, keep_empty_nodes=False)

Transforms a PyTree to a serializable format.

Parameters:
  • tree (Any) – The tree to serialize. If tree is empty and keep_empty_nodes is False, an error is raised because there is no valid representation.

  • keep_empty_nodes (bool) – If True, empty nodes are not filtered out.

Return type:

Any

Returns:

The serialized PyTree.

orbax.checkpoint.utils.deserialize_tree(serialized, target, keep_empty_nodes=False)

Deserializes a PyTree to the same structure as target.

Return type:

Any

orbax.checkpoint.utils.from_flat_dict(flat_dict, target=None, sep=None)

Reconstructs the original tree object from a flattened dictionary.

Parameters:
  • flat_dict (Any) – A dictionary conforming to the return value of to_flat_dict.

  • target (Optional[Any]) – A reference PyTree. The returned value will conform to this structure. If not provided, an unflattened dict is returned with the inferred structure of the original tree, without necessarily matching it exactly. Note that if target is not provided, the keys in flat_dict must match sep.

  • sep (Optional[str]) – Separator used for nested keys in flat_dict.

Return type:

Any

Returns:

A dict matching the structure of target (or the inferred structure, if target is not provided) with the values of flat_dict.
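
For the case where no target is given, the inverse operation can be sketched for plain dicts as below. The helper name `unflatten_to_nested_dict` is hypothetical, and this is an illustration of the inferred-structure behaviour rather than the library's code.

```python
from typing import Any, Dict, Optional


def unflatten_to_nested_dict(
    flat_dict: Dict[Any, Any], sep: Optional[str] = None
) -> Dict[str, Any]:
    """Hypothetical stand-in for from_flat_dict when no target is given."""
    nested: Dict[str, Any] = {}
    for flat_key, value in flat_dict.items():
        # Tuple keys are used as-is; string keys are split on sep.
        path = tuple(flat_key) if sep is None else tuple(flat_key.split(sep))
        node = nested
        for part in path[:-1]:
            # Create intermediate dicts on demand.
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return nested


from_tuples = unflatten_to_nested_dict({('foo',): 1, ('bar', 'a'): 2})
from_strings = unflatten_to_nested_dict({'foo': 1, 'bar/a': 2}, sep='/')
```

Both calls rebuild the same nested dict, which is why the key format has to agree with `sep`: string keys need the separator that produced them, and tuple keys need `sep=None`.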

orbax.checkpoint.utils.pytree_structure(directory)

Reconstructs the state dict from the saved model format in directory.

Return type:

Any

Aggregate file

orbax.checkpoint.utils.leaf_is_placeholder(leaf)

Determines if leaf represents a placeholder for a non-aggregated value.

Return type:

bool

orbax.checkpoint.utils.leaf_placeholder(name)

Constructs a value to act as a placeholder for a non-aggregated value.

Return type:

str

orbax.checkpoint.utils.name_from_leaf_placeholder(placeholder)

Gets the param name from a placeholder with the correct prefix.

Return type:

str
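
The three placeholder helpers above form a round trip: build a placeholder from a name, recognise it, and recover the name. A minimal sketch of that idea follows. The prefix value `PLACEHOLDER://` and all function names here are assumptions made for illustration; this page does not state the actual prefix.

```python
from typing import Any

# Assumed prefix, for illustration only; the real value is not given here.
PLACEHOLDER_PREFIX = 'PLACEHOLDER://'


def make_placeholder(name: str) -> str:
    """Builds a string that stands in for a non-aggregated value."""
    return PLACEHOLDER_PREFIX + name


def looks_like_placeholder(leaf: Any) -> bool:
    """True only for strings carrying the placeholder prefix."""
    return isinstance(leaf, str) and leaf.startswith(PLACEHOLDER_PREFIX)


def name_from_placeholder(placeholder: str) -> str:
    """Strips the prefix, rejecting values that are not placeholders."""
    if not looks_like_placeholder(placeholder):
        raise ValueError(f'Not a placeholder: {placeholder!r}')
    return placeholder[len(PLACEHOLDER_PREFIX):]


placeholder = make_placeholder('params/dense/kernel')
recovered = name_from_placeholder(placeholder)
```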

orbax.checkpoint.utils.is_supported_empty_aggregation_type(value)

Determines if the empty value is supported for aggregation.

Return type:

bool

orbax.checkpoint.utils.is_supported_aggregation_type(value)

Determines if the value is supported for aggregation.

Return type:

bool

Directories

orbax.checkpoint.utils.cleanup_tmp_directories(directory, primary_host=0, active_processes=None)

Cleans up steps in directory that have tmp files, as these are not finalized.

orbax.checkpoint.utils.get_tmp_directory(path)

Returns a non-deterministic tmp directory for path without creating it.

Return type:

Path
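
The tmp path form documented for is_checkpoint_finalized on this page is `<name>.orbax-checkpoint-tmp-<timestamp>`. Under that naming, computing a tmp path without creating it might look like the sketch below. The `sketch_*` names are hypothetical, and the timestamp resolution (microseconds since the epoch) is an assumption.

```python
import time
from pathlib import Path

TMP_DIR_SUFFIX = '.orbax-checkpoint-tmp-'


def sketch_tmp_directory(path: str) -> Path:
    """Returns a tmp path next to `path` without touching the filesystem."""
    final = Path(path)
    # Assumption: the timestamp is microseconds since the epoch.
    timestamp = int(time.time() * 1_000_000)
    return final.parent / f'{final.name}{TMP_DIR_SUFFIX}{timestamp}'


def sketch_is_tmp(path: Path) -> bool:
    """Checks only the directory name, not its contents."""
    return TMP_DIR_SUFFIX in path.name


tmp_path = sketch_tmp_directory('/tmp/ckpts/100')
```

Because the timestamp changes between calls, the result is non-deterministic, which matches the description above.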

orbax.checkpoint.utils.create_tmp_directory(final_dir, *, primary_host=0, active_processes=None)

Creates a non-deterministic tmp directory for saving to the given final_dir.

Also writes checkpoint metadata in the tmp directory.

Parameters:
  • final_dir (Union[str, PathLike]) – The eventual directory path where the checkpoint will be committed.

  • primary_host (Optional[int]) – Primary host ID. Defaults to 0.

  • active_processes (Optional[Set[int]]) – IDs of active processes. Defaults to None.

Return type:

Path

Returns:

The tmp directory.

Raises:

FileExistsError – If the tmp directory already exists.

orbax.checkpoint.utils.get_save_directory(step, directory, name=None, step_prefix=None, override_directory=None, step_format_fixed_length=None, step_name_format=None)

Returns the standardized path to a save directory for a single item.

Parameters:
  • step (int) – Step number.

  • directory (Union[str, PathLike]) – Top level checkpoint directory.

  • name (Optional[str]) – Item name (‘params’, ‘state’, ‘dataset’, etc.).

  • step_prefix (Optional[str]) – Prefix applied to step (e.g. ‘checkpoint’).

  • override_directory (Union[str, PathLike, None]) – If provided, step, directory, and step_prefix are ignored.

  • step_format_fixed_length (Optional[int]) – Uses a fixed number of digits with leading zeros to represent the step number. If None, there are no leading zeros.

  • step_name_format (Optional[NameFormat]) – NameFormat used to define the step name for step under the given root directory. If provided, step_prefix and step_format_fixed_length are ignored.

Return type:

Path

Returns:

The path to the save directory.
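
How these parameters combine can be sketched as follows. This is an illustration under stated assumptions, not the library's code: `sketch_save_directory` is a hypothetical name, the underscore joining prefix and step is assumed, the item name is assumed to be appended even when override_directory is set, and step_name_format is left out entirely.

```python
from pathlib import Path
from typing import Optional


def sketch_save_directory(
    step: int,
    directory: str,
    name: Optional[str] = None,
    step_prefix: Optional[str] = None,
    override_directory: Optional[str] = None,
    step_format_fixed_length: Optional[int] = None,
) -> Path:
    if override_directory is not None:
        # step, directory, and step_prefix are ignored in this case.
        result = Path(override_directory)
    else:
        if step_format_fixed_length is None:
            step_str = str(step)
        else:
            # Zero-pad the step to the requested number of digits.
            step_str = f'{step:0{step_format_fixed_length}d}'
        # Assumption: prefix and step are joined with an underscore.
        step_name = f'{step_prefix}_{step_str}' if step_prefix else step_str
        result = Path(directory) / step_name
    # Assumption: the item name, if any, is appended in both cases.
    return result / name if name is not None else result


plain = sketch_save_directory(100, '/ckpts')
padded = sketch_save_directory(
    100, '/ckpts', name='params', step_prefix='checkpoint',
    step_format_fixed_length=8,
)
overridden = sketch_save_directory(
    100, '/ckpts', name='state', override_directory='/elsewhere'
)
```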

orbax.checkpoint.utils.is_gcs_path(path)
Return type:

bool

Atomicity

orbax.checkpoint.utils.is_tmp_checkpoint(path)

Determines whether a directory is a tmp checkpoint path.

Return type:

bool

orbax.checkpoint.utils.is_checkpoint_finalized(path)

Determines if the given path represents a finalized checkpoint.

Path takes the form:

path/to/my/dir/<name>.orbax-checkpoint-tmp-<timestamp>/  # not finalized
path/to/my/dir/<name>/  # finalized

Alternatively:

gs://path/to/my/dir/<name>/  # finalized
  commit_success.txt
  ...
gs://path/to/my/dir/<name>/  # not finalized
  ...

Parameters:

path (Union[str, PathLike]) – Directory.

Return type:

bool

Returns:

True if the checkpoint is finalized.

Raises:

ValueError – If the provided path is not a directory. Valid checkpoint paths must be directories.
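
The two layouts described above lead to two different checks, sketched below against a local temporary directory. This is not the library's implementation: `sketch_is_finalized` is a hypothetical name, and the `is_gcs` flag is a stand-in for detecting a `gs://` path so that the example runs without cloud access.

```python
import tempfile
from pathlib import Path

TMP_DIR_SUFFIX = '.orbax-checkpoint-tmp-'
COMMIT_SUCCESS_FILE = 'commit_success.txt'


def sketch_is_finalized(path: Path, is_gcs: bool = False) -> bool:
    if not path.is_dir():
        raise ValueError(f'Path {path} is not a directory.')
    if is_gcs:
        # GCS layout: finalized only if the marker file is present.
        return (path / COMMIT_SUCCESS_FILE).exists()
    # Other filesystems: finalized unless the name carries the tmp suffix.
    return TMP_DIR_SUFFIX not in path.name


with tempfile.TemporaryDirectory() as root:
    root_path = Path(root)
    final_dir = root_path / '100'
    tmp_dir = root_path / '200.orbax-checkpoint-tmp-1700000000'
    final_dir.mkdir()
    tmp_dir.mkdir()

    results = {
        'final': sketch_is_finalized(final_dir),
        'tmp': sketch_is_finalized(tmp_dir),
        'gcs_without_marker': sketch_is_finalized(final_dir, is_gcs=True),
    }
    # Writing the marker file flips the GCS-style check to finalized.
    (final_dir / COMMIT_SUCCESS_FILE).write_text('done')
    results['gcs_with_marker'] = sketch_is_finalized(final_dir, is_gcs=True)
```

The layouts above show no tmp-suffixed names for `gs://` paths, so in this sketch the marker file alone decides whether such a checkpoint is finalized.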

Checkpoint steps

orbax.checkpoint.utils.step_from_checkpoint_name(name)

Returns the step from a checkpoint name. Also works for tmp checkpoints.

Return type:

int
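
A sketch of this parsing follows, under assumptions: names are taken to be either `<step>` or `<prefix>_<step>`, optionally followed by the `.orbax-checkpoint-tmp-<timestamp>` suffix shown elsewhere on this page. `sketch_step_from_name` is a hypothetical name and the accepted formats may differ from the library's.

```python
import re

TMP_DIR_SUFFIX = '.orbax-checkpoint-tmp-'


def sketch_step_from_name(name: str) -> int:
    # Drop the tmp suffix and timestamp, if present.
    base = name.split(TMP_DIR_SUFFIX)[0]
    # Assumption: the step is the trailing run of digits, optionally
    # preceded by '<prefix>_'.
    match = re.fullmatch(r'(?:.*_)?(\d+)', base)
    if match is None:
        raise ValueError(f'Unrecognized checkpoint name: {name!r}')
    return int(match.group(1))


steps = [
    sketch_step_from_name('42'),
    sketch_step_from_name('checkpoint_00000100'),
    sketch_step_from_name('7.orbax-checkpoint-tmp-1700000000'),
    sketch_step_from_name('checkpoint_7.orbax-checkpoint-tmp-1700000000'),
]
```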

orbax.checkpoint.utils.checkpoint_steps_paths(checkpoint_dir)

Returns a list of finalized checkpoint paths in the directory.

Return type:

List[Path]

orbax.checkpoint.utils.checkpoint_steps(checkpoint_dir, single_host_load_and_broadcast=False)

Returns a list of finalized checkpoint steps in the directory.

Return type:

List[int]

orbax.checkpoint.utils.any_checkpoint_step(checkpoint_dir)

Returns any finalized checkpoint step in the directory or None.

This avoids iterating over the entire directory.

Parameters:

checkpoint_dir (Union[str, PathLike]) – Checkpoint directory.

Return type:

Optional[int]

Returns:

Any finalized checkpoint step in the directory or None.

orbax.checkpoint.utils.tmp_checkpoints(checkpoint_dir)

Returns a list of tmp checkpoint dir names in checkpoint_dir.

Return type:

List[str]

orbax.checkpoint.utils.lockdir(directory)

Constructs a directory used to indicate that a checkpoint step is locked.

Return type:

Path

orbax.checkpoint.utils.is_locked(directory)

Determines whether a checkpoint step is considered locked.

Return type:

bool

orbax.checkpoint.utils.are_locked(directory, steps, step_prefix=None, step_format_fixed_length=None, step_name_format=None)

In parallel, determines whether the steps are considered locked.

Return type:

List[bool]

Sharding

orbax.checkpoint.utils.fully_replicated_host_local_array_to_global_array(arr)

Converts a host local array to a global jax.Array.

In most cases, the local array is expected to have been produced by pmap.

Parameters:

arr (Array) – Host local array.

Return type:

Array

Returns:

A global array.

Misc.

orbax.checkpoint.utils.is_scalar(x)
orbax.checkpoint.utils.record_saved_duration(checkpoint_start_time)

Records the program duration that is accounted for by this checkpoint.

For the very first checkpoint, this is the interval between program init and current checkpoint start time.

Note that we use the checkpoint start time instead of end time. The saved duration should not include parallel training duration while the async checkpoint is being written in the background.

Parameters:

checkpoint_start_time (float) – Start time of current checkpoint.
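
The accounting described above can be sketched as a small tracker. This is an illustration only: `SavedDurationTracker` is a hypothetical class, and measuring later checkpoints from the previous checkpoint's start time is an assumption inferred from the description of the first checkpoint, not something this page states.

```python
from typing import List


class SavedDurationTracker:
    """Hypothetical tracker for program time covered by each checkpoint."""

    def __init__(self, program_init_time: float):
        # Before any checkpoint, durations are measured from program init.
        self._last_start = program_init_time
        self.recorded: List[float] = []

    def record(self, checkpoint_start_time: float) -> float:
        # Use the start time, not the end time, so that training running in
        # parallel with an async write is not counted as already saved.
        duration = checkpoint_start_time - self._last_start
        # Assumption: the next interval begins at this checkpoint's start.
        self._last_start = checkpoint_start_time
        self.recorded.append(duration)
        return duration


tracker = SavedDurationTracker(program_init_time=0.0)
first = tracker.record(checkpoint_start_time=10.0)
second = tracker.record(checkpoint_start_time=25.0)
```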