ocp.v1.arrays module

Public API for arrays package.

orbax.checkpoint.experimental.v1.arrays.to_shape_dtype_struct(arr, dtype=None, scalar_dtype=None, support_format=False)

Get a jax.ShapeDtypeStruct from an array-like object.

Parameters:
  • arr (UnionType[AbstractArrayLike, AbstractArrayLikeGlobalShape, ndarray]) – Array-like object. This can be a jax.Array, jax.ShapeDtypeStruct, ArrayRestoreArgs, or value_metadata.ArrayMetadata: anything that has shape (or global_shape), dtype, and sharding properties. It may also be a NumPy array or a scalar value.

  • dtype (UnionType[dtype, None]) – Optional dtype that overrides the dtype of arr in the result.

  • scalar_dtype (UnionType[int, float, None]) – Optional dtype to use for scalars. Useful for converting to Python scalar types.

  • support_format (bool) – Whether to include the array's layout (format) in the result.

Return type:

UnionType[ShapeDtypeStruct, int, float]

Returns:

jax.ShapeDtypeStruct or scalar value.
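The following is a minimal sketch of the behavior described above, written against NumPy only so that it runs without JAX or Orbax installed. `ShapeDtypeStructSketch` and `to_shape_dtype_struct_sketch` are hypothetical stand-ins, not Orbax APIs. Preferring `global_shape` over `shape`, and passing scalars through (optionally cast with `scalar_dtype`), are assumptions about how the real function behaves; `support_format` is omitted.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Optional, Union

import numpy as np


@dataclass(frozen=True)
class ShapeDtypeStructSketch:
    """Hypothetical stand-in for jax.ShapeDtypeStruct."""
    shape: tuple
    dtype: np.dtype
    sharding: Any = None


def to_shape_dtype_struct_sketch(
    arr: Any,
    dtype: Optional[Any] = None,
    scalar_dtype: Optional[type] = None,
) -> Union[ShapeDtypeStructSketch, int, float]:
    """Hypothetical sketch of to_shape_dtype_struct; not Orbax's implementation."""
    # Assumption: Python scalars are returned as values, optionally cast.
    if isinstance(arr, (int, float)):
        return scalar_dtype(arr) if scalar_dtype is not None else arr
    # Assumption: a global shape, when present, takes precedence over `shape`.
    shape = getattr(arr, "global_shape", None)
    if shape is None:
        shape = arr.shape
    return ShapeDtypeStructSketch(
        shape=tuple(shape),
        # An explicit `dtype` overrides the dtype of `arr`.
        dtype=np.dtype(dtype if dtype is not None else arr.dtype),
        # NumPy arrays have no `sharding` attribute, so this falls back to None.
        sharding=getattr(arr, "sharding", None),
    )


# Usage: a NumPy array, a metadata-like object, and a scalar.
from_numpy = to_shape_dtype_struct_sketch(np.zeros((4, 8), np.float32), dtype=np.int8)
meta = SimpleNamespace(global_shape=(16, 8), shape=(4, 8), dtype=np.float32,
                       sharding="hypothetical-sharding")
from_meta = to_shape_dtype_struct_sketch(meta)
scalar = to_shape_dtype_struct_sketch(3, scalar_dtype=float)
print(from_numpy.shape, from_numpy.dtype)  # (4, 8) int8
print(from_meta.shape, from_meta.sharding)  # (16, 8) hypothetical-sharding
print(scalar)  # 3.0
```

Returning a plain metadata record rather than the array itself is what makes this useful for restoration: the caller can describe the target structure without allocating any data.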

orbax.checkpoint.experimental.v1.arrays.ArrayLike = orbax.checkpoint.experimental.v1._src.arrays.types.AbstractArray | orbax.checkpoint.experimental.v1._src.arrays.types.AbstractShardedArray

Type alias (a PEP 604 union) for an array-like object: either an AbstractArray or an AbstractShardedArray.

class orbax.checkpoint.experimental.v1.arrays.AbstractArray(*args, **kwargs)

Abstract representation of an array.

This is a protocol that can be used to represent the metadata of an array, namely its shape and dtype.

shape:

Tuple of integers describing the array shape.

dtype:

Dtype of array elements.
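Since the protocol only requires `shape` and `dtype`, any object exposing those attributes conforms structurally, whether or not it holds data. A minimal sketch, using a simplified re-declaration of the protocol (the attribute types are assumptions) and a hypothetical `ArrayInfo` record and `nbytes` helper, neither of which is part of Orbax:

```python
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable

import numpy as np


# Simplified re-declaration for illustration; attribute types are assumptions.
@runtime_checkable
class AbstractArray(Protocol):
    shape: tuple[int, ...]
    dtype: Any


@dataclass
class ArrayInfo:
    """Hypothetical metadata-only record: it stores no array data."""
    shape: tuple[int, ...]
    dtype: np.dtype


def nbytes(a: AbstractArray) -> int:
    """Size in bytes of the described array, computed from metadata alone."""
    return int(np.prod(a.shape, dtype=np.int64)) * np.dtype(a.dtype).itemsize


info = ArrayInfo(shape=(1024, 1024), dtype=np.dtype("float32"))
print(isinstance(info, AbstractArray), nbytes(info))  # True 4194304
```

The same `nbytes` call works on a real NumPy array, because the function depends only on the protocol's attributes.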

class orbax.checkpoint.experimental.v1.arrays.AbstractShardedArray(*args, **kwargs)

Abstract representation of a sharded array.

This is a protocol for an abstract array that can be used to represent various metadata types such as jax.ShapeDtypeStruct and ArrayMetadata.

TODO(dnlng): All attributes are optional to support the case where an ArrayMetadata carrying only the write_shape is passed to the metadata() call. The attributes will no longer need to be optional once write_shape is refactored.

shape:

Tuple of integers describing the array shape.

dtype:

Dtype of array elements.

sharding:

Describes how the array is sharded. This can be a jax.sharding.Sharding, a Layout, or None.
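Because every attribute of this protocol may be None, an object still conforms when it carries only partial metadata, so code consuming it has to check which fields are actually set. A minimal sketch under stated assumptions: the protocol below is a simplified re-declaration, and `PartialMetadata` and `is_fully_specified` are hypothetical helpers, not Orbax APIs.

```python
from dataclasses import dataclass
from typing import Any, Optional, Protocol, runtime_checkable

import numpy as np


# Simplified re-declaration for illustration: every attribute may be None.
@runtime_checkable
class AbstractShardedArray(Protocol):
    shape: Optional[tuple[int, ...]]
    dtype: Optional[Any]
    sharding: Optional[Any]


@dataclass
class PartialMetadata:
    """Hypothetical metadata record; fields that are not set stay None."""
    shape: Optional[tuple[int, ...]] = None
    dtype: Optional[np.dtype] = None
    sharding: Optional[Any] = None


def is_fully_specified(a: AbstractShardedArray) -> bool:
    """True if shape and dtype are known; sharding is allowed to be None."""
    return a.shape is not None and a.dtype is not None


full = PartialMetadata(shape=(8, 128), dtype=np.dtype("float32"))
empty = PartialMetadata()
print(isinstance(empty, AbstractShardedArray))  # True
print(is_fully_specified(full), is_fully_specified(empty))  # True False
```

A plain NumPy array does not conform to this protocol, because it has no `sharding` attribute.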