ocp.v1.arrays module
Public API for the arrays package.
- orbax.checkpoint.experimental.v1.arrays.to_shape_dtype_struct(arr, dtype=None, scalar_dtype=None, support_format=False)
Get a jax.ShapeDtypeStruct from an array-like object.
- Parameters:
arr (AbstractArrayLike | AbstractArrayLikeGlobalShape | ndarray) – Array-like object. This can be a jax.Array, jax.ShapeDtypeStruct, ArrayRestoreArgs, or value_metadata.ArrayMetadata: anything that has shape (or global_shape), dtype, and sharding properties. It may also be a numpy array or a scalar value.
dtype (dtype | None) – Optional dtype that overrides the dtype of arr in the result.
scalar_dtype (int | float | None) – Optional dtype to use for scalars. Useful for converting to Python scalar types.
support_format (bool) – Whether to support layout in the result.
- Return type:
ShapeDtypeStruct | int | float
- Returns:
jax.ShapeDtypeStruct or scalar value.
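The conversion described above can be sketched in plain Python. This is a minimal, hypothetical sketch, not the Orbax implementation: ShapeDtypeStructSketch stands in for jax.ShapeDtypeStruct so the example runs without JAX, support_format is omitted, and two behaviors are assumptions: global_shape is preferred over shape when an object defines both, and scalars are converted by calling scalar_dtype on them.

```python
from dataclasses import dataclass
from typing import Any, Optional, Union


@dataclass(frozen=True)
class ShapeDtypeStructSketch:
    """Hypothetical stand-in for jax.ShapeDtypeStruct."""
    shape: tuple[int, ...]
    dtype: Any
    sharding: Any = None


def to_shape_dtype_struct_sketch(
    arr: Any,
    dtype: Any = None,
    scalar_dtype: Optional[type] = None,
) -> Union[ShapeDtypeStructSketch, int, float]:
    """Hypothetical sketch of to_shape_dtype_struct (support_format omitted)."""
    # Python scalars are returned as values, optionally converted via scalar_dtype
    # (assumption: conversion is a plain call such as int(x) or float(x)).
    if isinstance(arr, (int, float)):
        return scalar_dtype(arr) if scalar_dtype is not None else arr
    # Assumption: a global_shape attribute (e.g. on restore args) takes precedence.
    shape = getattr(arr, "global_shape", None)
    if shape is None:
        shape = arr.shape
    return ShapeDtypeStructSketch(
        shape=tuple(shape),
        # The dtype argument, when given, overrides the input's own dtype.
        dtype=dtype if dtype is not None else arr.dtype,
        # Objects without sharding (e.g. numpy arrays) yield sharding=None.
        sharding=getattr(arr, "sharding", None),
    )
```

Because the sketch only reads attributes, any object exposing shape (or global_shape) and dtype works as input, which mirrors the duck-typed parameter description above.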
- orbax.checkpoint.experimental.v1.arrays.ArrayLike = orbax.checkpoint.experimental.v1._src.arrays.types.AbstractArray | orbax.checkpoint.experimental.v1._src.arrays.types.AbstractShardedArray
Type alias (a PEP 604 union) for an array-like value: either an AbstractArray or an AbstractShardedArray.
- class orbax.checkpoint.experimental.v1.arrays.AbstractArray(*args, **kwargs)
Abstract representation of an array.
This is a protocol for an abstract array that can be used to represent the metadata belonging to an array.
- shape:
Tuple of integers describing the array shape.
- dtype:
Dtype of array elements.
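Because AbstractArray is a protocol, conformance is structural: any object exposing shape and dtype satisfies it without inheriting from it. The sketch below illustrates this with hypothetical names (AbstractArraySketch, ArrayInfoSketch, num_elements) and a plain string for dtype; it is not the Orbax definition.

```python
import math
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class AbstractArraySketch(Protocol):
    """Hypothetical stand-in for AbstractArray: only shape and dtype."""
    shape: tuple[int, ...]
    dtype: str


@dataclass(frozen=True)
class ArrayInfoSketch:
    """Hypothetical metadata record; it does not inherit from the protocol."""
    shape: tuple[int, ...]
    dtype: str


def num_elements(arr: AbstractArraySketch) -> int:
    """Element count computed purely from metadata; no array data is needed."""
    # math.prod(()) == 1, so a scalar (shape ()) has one element.
    return math.prod(arr.shape)
```

Code typed against the protocol can therefore accept real arrays and lightweight metadata records interchangeably.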
- class orbax.checkpoint.experimental.v1.arrays.AbstractShardedArray(*args, **kwargs)
Abstract representation of an array.
This is a protocol for an abstract array that can be used to represent various metadata types, such as jax.ShapeDtypeStruct and ArrayMetadata.
TODO(dnlng): All attributes are made optional to support the case where an ArrayMetadata is passed into the metadata() call carrying only the write_shape. The optional attributes will no longer be needed once write_shape is refactored.
- shape:
Tuple of integers describing the array shape.
- dtype:
Dtype of array elements.
- sharding:
Indicates how the array is sharded. This can be a JAX Sharding, a Layout, or None.
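Since every attribute of AbstractShardedArray is optional, consumers should check for None before using a field. The following is a hypothetical sketch (AbstractShardedArraySketch, PartialMetadataSketch, and missing_fields are illustrative names, and the Optional typing is an assumption based on the TODO note) of a partially populated metadata object that still satisfies the protocol.

```python
from dataclasses import dataclass
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class AbstractShardedArraySketch(Protocol):
    """Hypothetical stand-in for AbstractShardedArray; every field may be None."""
    shape: Optional[tuple[int, ...]]
    dtype: Optional[Any]
    sharding: Optional[Any]


@dataclass
class PartialMetadataSketch:
    """Hypothetical metadata object that may carry only some fields."""
    shape: Optional[tuple[int, ...]] = None
    dtype: Optional[Any] = None
    sharding: Optional[Any] = None


def missing_fields(arr: AbstractShardedArraySketch) -> list[str]:
    """Names of protocol attributes that are currently unset (None)."""
    return [name for name in ("shape", "dtype", "sharding") if getattr(arr, name) is None]
```

A None sharding here simply means "not specified"; whether a caller then treats the array as unsharded or fills the field from another source is a design decision left to the consumer.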