# Source code for orbax.experimental.model.core.python.function
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The `Function` base class."""
import dataclasses
import enum
from typing import Any, Optional, Sequence, Tuple, TypeAlias
import numpy as np
from orbax.experimental.model.core.python import tree_util
from tensorflow.compiler.xla import xla_data_pb2 # pylint: disable=g-direct-tensorflow-import
# XLA sharding annotation, represented by the `OpSharding` proto message.
Sharding: TypeAlias = xla_data_pb2.OpSharding
# XLA memory layout, represented by the `LayoutProto` proto message.
Layout: TypeAlias = xla_data_pb2.LayoutProto
# Size of a single tensor dimension. `None` is permitted in place of an integer
# (presumably for a dimension whose size is not statically known -- TODO
# confirm).
ShloDimSize: TypeAlias = Optional[int]
# Shape of a tensor: a sequence of dimension sizes. The whole shape may also be
# `None` (presumably meaning the rank itself is unknown -- TODO confirm).
ShloShape: TypeAlias = Optional[Sequence[ShloDimSize]]
# pylint: disable=invalid-name
# Copied from /third_party/py/jax/_src/export/serialization.fbs
[docs]
@enum.unique
class ShloDType(enum.Enum):
  """Element types that a StableHLO tensor can have.

  Members `bool` through `f8_e5m2fnuz` carry the contiguous values 0-21, and
  `str` is set apart with the value 100.

  The class is decorated with `enum.unique`: if two members were ever given
  the same value, `enum.Enum` would silently turn the second one into an alias
  of the first; with `enum.unique` this raises `ValueError` at import time
  instead.

  Note that the members `bool` and `str` shadow the builtins of the same name
  inside this class body (and only there), so the builtins must not be used by
  bare name below their respective assignments.
  """

  bool = 0
  i8 = 1
  i16 = 2
  i32 = 3
  i64 = 4
  ui8 = 5
  ui16 = 6
  ui32 = 7
  ui64 = 8
  f16 = 9
  f32 = 10
  f64 = 11
  c64 = 12
  c128 = 13
  bf16 = 14
  i4 = 15
  ui4 = 16
  f8_e4m3b11fnuz = 17
  f8_e4m3fn = 18
  f8_e4m3fnuz = 19
  f8_e5m2 = 20
  f8_e5m2fnuz = 21
  str = 100
# Lookup table from NumPy dtypes to the corresponding `ShloDType` member.
#
# Only the `ShloDType` members listed here are covered; the remaining members
# (`bf16`, `i4`, `ui4`, the `f8_*` family and `str`) have no entry, so looking
# them up in either direction raises `KeyError`.
_NP_DTYPE_TO_SHLO_DTYPE: dict[np.dtype[Any], ShloDType] = {
    # Use `np.bool_` rather than `np.bool`: the `np.bool` alias was removed in
    # NumPy 1.24 and only came back in NumPy 2.0, whereas `np.bool_` exists in
    # every NumPy release and yields the same `dtype('bool')`.
    np.dtype(np.bool_): ShloDType.bool,
    np.dtype(np.int8): ShloDType.i8,
    np.dtype(np.int16): ShloDType.i16,
    np.dtype(np.int32): ShloDType.i32,
    np.dtype(np.int64): ShloDType.i64,
    np.dtype(np.uint8): ShloDType.ui8,
    np.dtype(np.uint16): ShloDType.ui16,
    np.dtype(np.uint32): ShloDType.ui32,
    np.dtype(np.uint64): ShloDType.ui64,
    np.dtype(np.float16): ShloDType.f16,
    np.dtype(np.float32): ShloDType.f32,
    np.dtype(np.float64): ShloDType.f64,
    np.dtype(np.complex64): ShloDType.c64,
    np.dtype(np.complex128): ShloDType.c128,
}

# Inverse of `_NP_DTYPE_TO_SHLO_DTYPE`. The forward table maps each NumPy dtype
# to a distinct `ShloDType`, so inverting it loses no entries.
_SHLO_DTYPE_TO_NP_DTYPE: dict[ShloDType, np.dtype[Any]] = {
    shlo_dtype: np_dtype
    for np_dtype, shlo_dtype in _NP_DTYPE_TO_SHLO_DTYPE.items()
}
def np_dtype_to_shlo_dtype(dtype: np.dtype[Any]) -> ShloDType:
  """Looks up the `ShloDType` registered for a NumPy dtype.

  Args:
    dtype: the NumPy dtype to translate.

  Returns:
    The `ShloDType` member stored for `dtype` in `_NP_DTYPE_TO_SHLO_DTYPE`.

  Raises:
    KeyError: if `dtype` has no entry in `_NP_DTYPE_TO_SHLO_DTYPE`.
  """
  shlo_dtype = _NP_DTYPE_TO_SHLO_DTYPE[dtype]
  return shlo_dtype
def shlo_dtype_to_np_dtype(dtype: ShloDType) -> np.dtype[Any]:
  """Looks up the NumPy dtype registered for a `ShloDType`.

  Args:
    dtype: the `ShloDType` member to translate.

  Returns:
    The NumPy dtype stored for `dtype` in `_SHLO_DTYPE_TO_NP_DTYPE`.

  Raises:
    KeyError: if `dtype` has no entry in `_SHLO_DTYPE_TO_NP_DTYPE`.
  """
  np_dtype = _SHLO_DTYPE_TO_NP_DTYPE[dtype]
  return np_dtype
# TODO(wangpeng): value.py needs this class, so we should move this class out
# of function.py.
[docs]
@dataclasses.dataclass
class ShloTensorSpec:
  """Describes a StableHLO tensor: shape, dtype, sharding, layout and name.

  `shape` and `dtype` are required; the remaining fields default to `None`.

  Attributes:
    shape: The tensor's shape, as a `ShloShape`.
    dtype: The tensor's element type, as a `ShloDType`.
    sharding: How the tensor is sharded. `None` means the sharding is
      unspecified.
    layout: The tensor's layout. `None` means the default layout is used.
    name: The tensor's name, if any.
  """

  shape: ShloShape
  dtype: ShloDType
  sharding: Sharding | None = None
  layout: Layout | None = None
  name: str | None = None
[docs]
@dataclasses.dataclass(kw_only=True)
class Function:
  """Base class for functions whose signatures are made of StableHLO types.

  The class is intended as an abstract base; nothing in the class itself
  enforces that. All fields are keyword-only.

  Attributes:
    input_signature: tree of `ShloTensorSpec`s describing the inputs.
    output_signature: tree of `ShloTensorSpec`s describing the outputs.
    data_names: names of the checkpoint data used by the function; defaults to
      `None`.
    signature: read-only pair `(input_signature, output_signature)`.
  """

  input_signature: tree_util.Tree[ShloTensorSpec]
  output_signature: tree_util.Tree[ShloTensorSpec]
  data_names: Sequence[str] | None = None
  # TODO(b/372084833): Add `vjp_name`.

  @property
  def signature(
      self,
  ) -> Tuple[tree_util.Tree[ShloTensorSpec], tree_util.Tree[ShloTensorSpec]]:
    """Returns the input and output signatures bundled as a 2-tuple."""
    inputs = self.input_signature
    outputs = self.output_signature
    return (inputs, outputs)