Arrays

Utilities for working with arrays.

orbax.checkpoint.arrays.to_shape_dtype_struct(arr, dtype=None, scalar_dtype=None, support_format=False)

Get a jax.ShapeDtypeStruct from an array-like object.

Parameters:
  • arr (UnionType[AbstractArrayLike, AbstractArrayLikeGlobalShape, ndarray]) – Array-like object. This can be a jax.Array, jax.ShapeDtypeStruct, ArrayRestoreArgs, or value_metadata.ArrayMetadata; in general, any object with shape (or global_shape), dtype, and sharding properties. It may also be a numpy array or a scalar value.

  • dtype (UnionType[dtype, None]) – Optional dtype that overrides the dtype of arr in the result.

  • scalar_dtype (UnionType[int, float, None]) – Optional dtype to use for scalars. Useful for converting to Python scalar types.

  • support_format (bool) – Whether to support layout in the result.

Return type:

UnionType[ShapeDtypeStruct, int, float]

Returns:

jax.ShapeDtypeStruct or scalar value.
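The following is a minimal sketch of the kind of conversion described above, written for illustration only; it is not the Orbax implementation. It assumes that scalars are plain Python int/float values, that global_shape takes precedence over shape when an object has both, and that a missing sharding attribute (as on a numpy array) becomes None. The name to_shape_dtype_struct_sketch is hypothetical, and the support_format parameter is not modeled.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_shape_dtype_struct_sketch(arr, dtype=None, scalar_dtype=None):
  """Hypothetical re-implementation for illustration; not the Orbax source."""
  # Assumption: scalars are plain Python int/float values.
  if isinstance(arr, (int, float)):
    # Optionally coerce the scalar to the requested Python type.
    return scalar_dtype(arr) if scalar_dtype is not None else arr
  # Assumption: prefer global_shape (e.g. on ArrayRestoreArgs) over shape.
  shape = getattr(arr, "global_shape", None)
  if shape is None:
    shape = arr.shape
  # numpy arrays have no sharding attribute, so fall back to None.
  sharding = getattr(arr, "sharding", None)
  return jax.ShapeDtypeStruct(
      tuple(shape),
      dtype if dtype is not None else arr.dtype,  # dtype overrides arr.dtype
      sharding=sharding,
  )


# Usage: a device array, with its dtype overridden to bfloat16.
x = jnp.ones((4, 8), dtype=jnp.float32)
spec = to_shape_dtype_struct_sketch(x, dtype=jnp.bfloat16)
print(spec.shape, spec.dtype)

# Usage: a host numpy array, which carries no sharding.
host_spec = to_shape_dtype_struct_sketch(np.zeros((2, 3), np.float32))
print(host_spec.shape, host_spec.dtype)
```

Because a jax.ShapeDtypeStruct holds only shape, dtype, and sharding metadata, building one this way allocates no device memory for the result.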