Source code for orbax.checkpoint._src.arrays.abstract_arrays

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for dealing with abstract arrays."""

from typing import Protocol
import jax
from jax import numpy as jnp
import numpy as np
from orbax.checkpoint._src.arrays import sharding as arrays_sharding_lib
from orbax.checkpoint._src.arrays import types
from orbax.checkpoint._src.metadata import sharding as sharding_metadata


class AbstractArrayLike(Protocol):
  """Abstract representation of an array.

  Can include objects like jax.Array, jax.ShapeDtypeStruct,
  ArrayRestoreArgs, and value_metadata.ArrayMetadata.

  This is a structural (duck-typed) protocol: any object exposing the
  attributes below satisfies it for static type checking. It is not decorated
  with `@typing.runtime_checkable`, so it cannot be used with `isinstance`.
  """

  # Array shape.
  shape: types.Shape
  # Element dtype, or None.
  dtype: jnp.dtype | None
  # A concrete jax Sharding, serialized ShardingMetadata, or None.
  sharding: jax.sharding.Sharding | sharding_metadata.ShardingMetadata | None


class AbstractArrayLikeGlobalShape(Protocol):
  """Same as `AbstractArrayLike`, but with `global_shape` instead of `shape`.

  Like `AbstractArrayLike`, this is a structural protocol that is not
  `@typing.runtime_checkable`, so it is only meaningful for static typing.
  """

  # Array shape, exposed under the name `global_shape`.
  global_shape: types.Shape
  # Element dtype, or None.
  dtype: jnp.dtype | None
  # A concrete jax Sharding, serialized ShardingMetadata, or None.
  sharding: jax.sharding.Sharding | sharding_metadata.ShardingMetadata | None


def _is_scalar(arr):
  """Tells whether `arr` is a scalar value rather than an array-like object.

  NumPy numeric scalars (`np.number`) and Python values matching `ScalarType`
  (`int`/`float`; `bool` is included since it subclasses `int`) qualify.

  NOTE(review): NumPy boolean scalars (`np.bool_`) are not `np.number`
  instances, so they are not treated as scalars here.
  """
  if isinstance(arr, np.number):
    return True
  return isinstance(arr, ScalarType)


def _get_shape(
    arr: AbstractArrayLike | AbstractArrayLikeGlobalShape,
) -> types.Shape:
  """Returns the shape of `arr`.

  The `shape` attribute is preferred; `global_shape` is used only when
  `shape` is absent.

  Args:
    arr: Object exposing a `shape` or a `global_shape` attribute.

  Returns:
    The value of `arr.shape`, or of `arr.global_shape` if `shape` is missing.

  Raises:
    ValueError: If `arr` has neither a `shape` nor a `global_shape` attribute.
  """
  if hasattr(arr, 'shape'):
    return arr.shape
  if hasattr(arr, 'global_shape'):
    return arr.global_shape
  # Both attributes were checked, so name both in the error.
  raise ValueError(
      f'Object does not have a `shape` or `global_shape` property: {arr}'
  )


# Python value types that `_is_scalar` treats as scalars (in addition to
# NumPy numeric scalars).
ScalarType = int | float

# Array-like inputs: objects exposing `shape`/`global_shape`, `dtype` and
# `sharding` (see the protocols above), or plain NumPy arrays.
ArrayLike = AbstractArrayLike | AbstractArrayLikeGlobalShape | np.ndarray


def to_shape_dtype_struct(
    arr: ArrayLike,
    dtype: jnp.dtype | None = None,
    scalar_dtype: type[ScalarType] | None = None,
    support_format: bool = False,  # TODO(b/460844509) - True by default.
) -> jax.ShapeDtypeStruct | ScalarType:
  """Get ShapeDtypeStruct from array-like object.

  Args:
    arr: Array-like object. This can include jax.Array, jax.ShapeDtypeStruct,
      ArrayRestoreArgs, value_metadata.ArrayMetadata - anything that has
      `shape`/`global_shape`, `dtype`, and `sharding` properties. It may also
      be a numpy array or a scalar value.
    dtype: Optional dtype that overrides the dtype of `arr` in the result.
      Not used for scalar inputs.
    scalar_dtype: Optional type (e.g. `int` or `float`) that scalar inputs are
      converted with, by calling `scalar_dtype(arr)`. Useful for converting to
      Python scalar types. Not used for non-scalar inputs.
    support_format: Whether to support layout in the result. Only consulted
      when `arr` is neither a scalar nor a numpy array and `arr.sharding` is
      not a `ShardingMetadata`.

  Returns:
    jax.ShapeDtypeStruct or scalar value.
  """
  if isinstance(arr, jax.Array) and jax.dtypes.issubdtype(
      arr.dtype, jax.dtypes.prng_key
  ):
    # For random keys, extract the dtype and shape as a regular Jax array.
    # Stored metadata will help restoring the original random key.
    arr = jax.random.key_data(arr)

  if _is_scalar(arr):
    if scalar_dtype is not None:
      return scalar_dtype(arr)
    return arr

  if isinstance(arr, np.ndarray):
    if dtype is None:
      dtype = arr.dtype
    # No sharding is attached to the result for numpy inputs.
    return jax.ShapeDtypeStruct(_get_shape(arr), dtype)

  # The shape is resolved before the dtype so that objects missing both
  # attributes fail with `_get_shape`'s ValueError.
  shape = _get_shape(arr)
  if dtype is None:
    dtype = arr.dtype
  sharding = arr.sharding
  if isinstance(sharding, sharding_metadata.ShardingMetadata):
    # Serialized sharding metadata is turned into a jax Sharding.
    sharding = sharding.to_jax_sharding()
  else:
    sharding = arrays_sharding_lib.get_sharding_or_format(
        arr, support_format=support_format
    )
  return jax.ShapeDtypeStruct(shape, dtype, sharding=sharding)