# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Orbax export."""
from collections.abc import Mapping, Sequence
import dataclasses
import functools
import inspect
import os
from typing import Any, Callable, List, Optional, Tuple, Union
from absl import logging
import jax
from jax import export as jax_export
from jax import tree_util
import jax.numpy as jnp
import jaxtyping
import numpy as np
from orbax.export import serving_config as osc
import tensorflow as tf
# Loose aliases typed as `Any`; within this module they are used only as
# annotations (see `CallableSignatures`).
ConfigProto = Any
# Shorthand for jaxtyping's PyTree annotation.
PyTree = jaxtyping.PyTree
SignatureDef = Any
# File extension (without the leading dot) of the per-method files written by
# `save_jax_exported_map` and read back by `load_jax_exported_map`.
_FILE_TYPE = 'jax_exported'
@dataclasses.dataclass
class TensorSpecWithDefault:
  """A `tf.TensorSpec` bundled with a default value.

  On construction the default value is converted to a `tf.Tensor` (using the
  dtype of `tensor_spec`) unless it already is one. The spec derived from that
  tensor must then be a subtype of `tensor_spec`, otherwise construction fails.

  Constraints due to Python function calling conventions:
    - For a Python function parameter, either every tensor spec in its
      structure is a `TensorSpecWithDefault`, or none of them is.
    - Parameters with default values must come after parameters without
      defaults.

  Example:
    Create a TensorSpec with a default fallback value::

      import tensorflow as tf
      from orbax.export.utils import TensorSpecWithDefault

      # Define a spec for a 1D tensor
      spec = tf.TensorSpec(shape=(2,), dtype=tf.float32, name="input_a")

      # Create the extended spec with a default list
      # The list is automatically converted to a tf.Tensor upon initialization
      default_spec = TensorSpecWithDefault(
          tensor_spec=spec,
          default_val=[1.0, 2.0]
      )

  Attributes:
    tensor_spec: The underlying `tf.TensorSpec` defining the expected shape and
      dtype.
    default_val: The default value to use. Upon initialization, this is
      converted to a `tf.Tensor` using the dtype from `tensor_spec` (unless it
      is already a `tf.Tensor`).
    is_primary: Whether this tensor is a primary input tensor. A primary input
      tensor is a tensor whose batch size is already or will be tiled to match
      the batch size of all other primary input tensors. A non-primary input
      tensor must have a batch size of 1, or the same as the primary batch size.
      Used primarily by `orbax.export.utils.make_auto_batching_function`.

  Raises:
    ValueError: if `default_val` is None, or if the default value is not
      compatible with `tensor_spec`.
  """

  tensor_spec: tf.TensorSpec
  default_val: Any
  # Marks this tensor as a primary input. All primary inputs share one batch
  # size; a non-primary input must have batch size 1 or the primary batch
  # size. Consumed by `orbax.export.utils.make_auto_batching_function`, which
  # documents the full set of constraints.
  is_primary: bool = False

  def __post_init__(self):
    """Normalizes `default_val` to a tensor and validates it against the spec."""
    if self.default_val is None:
      raise ValueError('Use TensorSpec if no defaults are needed.')

    # TF1-style signatures can only carry defaults that are Tensors.
    default = self.default_val
    if not isinstance(default, tf.Tensor):
      default = tf.convert_to_tensor(default, dtype=self.tensor_spec.dtype)
      self.default_val = default

    default_spec = tf.TensorSpec.from_tensor(
        default,
        name=self.tensor_spec.name,
    )
    if default_spec.is_subtype_of(self.tensor_spec):
      return
    raise ValueError(
        f'TensorSpec {self.tensor_spec} is not compatible with'
        f' the default value {self.default_val}'
    )
# Annotation alias: a PyTree whose leaves are either plain `tf.TensorSpec`
# objects or `TensorSpecWithDefault` objects.
NestedTfTensorSpec = jaxtyping.PyTree[
Union[tf.TensorSpec, TensorSpecWithDefault]
]
def remove_signature_defaults(input_signature: PyTree) -> PyTree:
  """Replaces every `TensorSpecWithDefault` leaf with its plain tensor spec.

  Args:
    input_signature: a PyTree of tensor specs, some of which may be
      `TensorSpecWithDefault`.

  Returns:
    A PyTree with the same structure in which each `TensorSpecWithDefault`
    leaf has been swapped for its `tensor_spec`; all other leaves are returned
    untouched.
  """
  return jax.tree_util.tree_map(
      lambda leaf: (
          leaf.tensor_spec if isinstance(leaf, TensorSpecWithDefault) else leaf
      ),
      input_signature,
  )
def _get_defaults(input_signature: Sequence[PyTree]) -> list[PyTree]:
  """Collects the default value (if any) for each positional parameter.

  Args:
    input_signature: one PyTree of tensor specs per Python parameter.

  Returns:
    A list aligned with `input_signature`. An entry is
    `inspect.Parameter.empty` when the parameter has no
    `TensorSpecWithDefault` leaf, otherwise a PyTree of default tensors with
    the same structure as the parameter.

  Raises:
    ValueError: if a parameter mixes `TensorSpecWithDefault` leaves with plain
      `tf.TensorSpec` leaves.
  """

  def _defaults_for(parameter: PyTree) -> PyTree:
    specs = jax.tree_util.tree_leaves(parameter)
    if all(not isinstance(spec, TensorSpecWithDefault) for spec in specs):
      return inspect.Parameter.empty
    # At least one leaf carries a default, so every leaf has to.
    if any(isinstance(spec, tf.TensorSpec) for spec in specs):
      raise ValueError(
          'TensorSpecWithDefault must be defined for each tensor in the'
          ' structure for the Python arg.'
      )
    return jax.tree_util.tree_map(lambda spec: spec.default_val, parameter)

  return [_defaults_for(parameter) for parameter in input_signature]
def with_default_args(
    tf_fn: Callable[..., Any],
    input_signature: Sequence[PyTree],
) -> tf.types.experimental.PolymorphicFunction:
  """Creates a TF function with default args specified in `input_signature`.

  The function is wrapped in a `tf.function` whose TF input signature is
  `input_signature` with the defaults stripped. When at least one parameter
  carries defaults, the Python signature of the wrapped function is rewritten
  so that those parameters become optional.

  Example:
    Create a TensorFlow function where the second argument has a default value::

      import tensorflow as tf
      from orbax.export.utils import TensorSpecWithDefault, with_default_args

      def add_tensors(x, y):
        return x + y

      # Define the signature: `x` is required, `y` has a default value
      signature = [
          tf.TensorSpec(shape=(2,), dtype=tf.float32, name='x'),
          TensorSpecWithDefault(
              tensor_spec=tf.TensorSpec(shape=(2,), dtype=tf.float32, name='y'),
              default_val=[1.0, 1.0]
          )
      ]

      # Create the new tf.function with defaults applied
      fn_with_defaults = with_default_args(add_tensors, signature)

      # The function can now be executed with only the required 'x' argument
      result = fn_with_defaults(tf.constant([2.0, 2.0]))

  Args:
    tf_fn: the TF function.
    input_signature: the input signature. Every leaf is a tf.TensorSpec, or a
      orbax.export.TensorSpecWithDefault if the default value is specified.

  Returns:
    A tf function with default arguments.
  """
  plain_signature = remove_signature_defaults(input_signature)

  def _as_tf_function(fn):
    return tf.function(
        fn,
        input_signature=plain_signature,
        jit_compile=False,
        autograph=False,
    )

  plain_tf_fn = _as_tf_function(tf_fn)
  defaults = _get_defaults(input_signature)
  if not any(d is not inspect.Parameter.empty for d in defaults):
    # Nothing to inject: the plain tf.function is already what we want.
    return plain_tf_fn

  # Rebuild the Python signature, attaching a default to each parameter.
  traced_parameters = (
      plain_tf_fn.function_spec.function_type.parameters.values()
  )
  signature_with_defaults = inspect.Signature([
      inspect.Parameter(param.name, param.kind, default=default)
      for param, default in zip(traced_parameters, defaults)
  ])

  # A thin wrapper that fills in missing arguments before calling `tf_fn`.
  def fn_with_defaults(*args, **kwargs):
    call_args = signature_with_defaults.bind(*args, **kwargs)
    call_args.apply_defaults()
    return tf_fn(*call_args.args, **call_args.kwargs)

  fn_with_defaults.__signature__ = signature_with_defaults
  return _as_tf_function(fn_with_defaults)
def _runtime_batch_size(x: tf.Tensor) -> tf.Tensor:
  """Returns the size of the leading dimension of `x`, read via `tf.shape`."""
  dynamic_shape = tf.shape(x)
  return dynamic_shape[0]
@tf.function(autograph=True)
def _repeat_to_batch(
    x: tf.Tensor,
    primary_batch_size: tf.Tensor,
    tensor_name: Optional[str] = None,
) -> tf.Tensor:
  """Tiles a batch-size-1 tensor along axis 0 up to the primary batch size.

  Args:
    x: the tensor to expand; its leading dimension is treated as the batch.
    primary_batch_size: the batch size to match.
    tensor_name: name used in the assertion message only.

  Returns:
    `x` repeated `primary_batch_size` times along axis 0 when its batch size
    is 1; otherwise `x` itself, after asserting that its batch size equals
    `primary_batch_size`.
  """
  # The branch condition is a tensor; the decorator enables AutoGraph so the
  # Python `if` below can depend on it.
  own_batch_size = _runtime_batch_size(x)
  if own_batch_size == 1:
    x = tf.repeat(x, primary_batch_size, axis=0)
  else:
    tf.assert_equal(
        own_batch_size,
        primary_batch_size,
        f'The batch size of a non-primary input tensor (name={tensor_name})'
        ' must be 1 or the same as that of the primary tensors.',
    )
  return x
def make_auto_batching_function(
    input_signature: Sequence[PyTree],
) -> Callable[..., Any]:
  """Creates an auto-batching function from input signature.

  An auto-batching function is a function whose input tensors can have either
  a batch size of "b" or 1, and whose output tensors have a batch size of "b",
  where "b" is the batch size of the primary input tensors.

  Requirements:
    - All input tensors must have a leading batch dimension.
    - There must be at least one primary tensor. A primary tensor is a tensor
      whose tensor spec is either a `tf.TensorSpec` or a
      `TensorSpecWithDefault` whose `is_primary` attribute is True.
    - All primary tensors must have the same batch size.
    - All non-primary tensors must have a batch size of 1, or the same as the
      primary batch size.

  Example:
    >>> input_signature = (
    >>>     tf.TensorSpec([None], tf.int32, name='primary'),
    >>>     TensorSpecWithDefault(
    >>>         tf.TensorSpec([None], tf.int32, name='optional'), [1]
    >>>     ),
    >>> )
    >>> batching_fn = utils.make_auto_batching_function(input_signature)
    >>> batching_fn(tf.constant([0, 0]), tf.constant([1]))
    (<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 0], dtype=int32)>,
     <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)>)

  Args:
    input_signature: a sequence of PyTrees whose leaf node is `tf.TensorSpec`
      or `TensorSpecWithDefault`.

  Returns:
    A TF function whose output tensors all have the same batch size.

  Raises:
    ValueError: if the signature contains no primary tensor.
  """
  flat_sig, sig_treedef = jax.tree_util.tree_flatten(tuple(input_signature))

  # Per-leaf bookkeeping, aligned with `flat_sig`.
  is_primary_tensor = []
  tensor_names = []
  for spec in flat_sig:
    if isinstance(spec, tf.TensorSpec):
      # A plain TensorSpec is always primary.
      is_primary_tensor.append(True)
      tensor_names.append(spec.name)
    else:
      is_primary_tensor.append(spec.is_primary)
      tensor_names.append(spec.tensor_spec.name)

  if not any(is_primary_tensor):
    raise ValueError(
        'No primary input tensors. A primary tensor is a tensor whose tensor'
        ' spec is either a `tf.TensorSpec` or a `TensorSpecWithDefault` whose'
        ' `is_primary` attribute is True. Got'
        f' input_signature={input_signature}`'
    )

  def auto_batching_fn(*args):
    flat_args, arg_treedef = jax.tree_util.tree_flatten(args)
    assert arg_treedef == sig_treedef, (arg_treedef, sig_treedef)

    # The first primary tensor defines the reference batch size.
    primary_batch_size = next(
        (
            _runtime_batch_size(tensor)
            for tensor, primary in zip(flat_args, is_primary_tensor)
            if primary
        ),
        None,
    )
    assert primary_batch_size is not None

    def _batched(tensor, primary, name):
      if not primary:
        return _repeat_to_batch(tensor, primary_batch_size, name)
      tf.assert_equal(
          primary_batch_size,
          _runtime_batch_size(tensor),
          'All primary input tensors must have the same batch size.',
      )
      return tensor

    batched = [
        _batched(tensor, primary, name)
        for tensor, primary, name in zip(
            flat_args, is_primary_tensor, tensor_names
        )
    ]
    return jax.tree_util.tree_unflatten(arg_treedef, batched)

  return with_default_args(auto_batching_fn, input_signature)
class CallableSignatures:
  """Holds TF SignatureDefs as python callables."""

  def __init__(
      self,
      sess: tf.compat.v1.Session,
      signature_defs: Mapping[str, SignatureDef],
  ):
    """Wraps each SignatureDef in a callable that runs it in `sess`.

    Args:
      sess: the session whose graph contains the tensors named by the
        signature defs.
      signature_defs: mapping from signature name to SignatureDef.
    """

    def call(signature_def, **inputs):
      # Resolve tensors by name at call time, feed the given inputs and fetch
      # every declared output.
      output_keys = list(signature_def.outputs.keys())
      feeds = {}
      for key, value in inputs.items():
        placeholder = sess.graph.get_tensor_by_name(
            signature_def.inputs[key].name
        )
        feeds[placeholder] = (
            value.numpy() if isinstance(value, tf.Tensor) else value
        )
      fetches = [
          sess.graph.get_tensor_by_name(signature_def.outputs[key].name)
          for key in output_keys
      ]
      results = sess.run(fetches, feeds)
      return dict(zip(output_keys, results))

    self._sess = sess
    # `functools.partial` binds each signature_def eagerly.
    self._signatures = {
        name: functools.partial(call, signature_def)
        for name, signature_def in signature_defs.items()
    }

  @classmethod
  def from_saved_model(
      cls, model_dir: str, tags: list[str], sess_config: ConfigProto = None
  ):
    """Loads a SavedModel and reconstructs its signatures as python callables.

    The signatures of the object loaded by the ``tf.saved_model.load`` API
    don't support default values, hence one can use this class to load the
    model in TF1 and reconstruct the signatures. Example:

    >>> loaded = CallableSignatures.from_saved_model(model_dir, ['serve'])
    >>> outputs = loaded.signatures['serving_default'](**inputs)

    The TF2 version of this example is

    >>> loaded_tf2 = tf.saved_model.load(model_dir, ['serve'])
    >>> outputs = loaded_tf2.signatures['serving_default'](**inputs)

    But the callables in `loaded_tf2.signatures` don't have any default
    inputs.

    Args:
      model_dir: SavedModel directory.
      tags: Tags to identify the metagraph to load. Same as the `tags` argument
        in tf.saved_model.load.
      sess_config: (Optional.) A
        [`ConfigProto`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto)
        protocol buffer with configuration options for the session.

    Returns:
      A `CallableSignatures` whose `signatures` property maps signature names
      to the callables.
    """
    with tf.Graph().as_default():
      sess = tf.compat.v1.Session(config=sess_config)
      meta_graph_def = tf.compat.v1.saved_model.loader.load(
          sess, tags, model_dir
      )
      return cls(sess, meta_graph_def.signature_def)

  @property
  def signatures(self):
    """Returns a mapping for signature names to python callables."""
    return self._signatures
def _save_jax_exported_to_disk(
    exp: jax_export.Exported,
    bin_file_path: str,
    *,
    vjp_order: int = 0,
) -> None:
  """Serializes `exp` and writes the bytes to a new file.

  Args:
    exp: the exported function to serialize.
    bin_file_path: destination path; must not exist yet.
    vjp_order: forwarded to `exp.serialize`.

  Raises:
    ValueError: if `bin_file_path` already exists.
  """
  already_there = tf.io.gfile.exists(bin_file_path)
  if already_there:
    raise ValueError(f'File {bin_file_path} already exists.')
  with tf.io.gfile.GFile(bin_file_path, 'wb') as out_file:
    out_file.write(exp.serialize(vjp_order=vjp_order))
def _load_jax_exported_from_disk(bin_file_path: str) -> jax_export.Exported:
  """Reads a serialized `jax.export.Exported` from disk and deserializes it.

  Args:
    bin_file_path: path of the serialized file.

  Returns:
    The deserialized `Exported` object.

  Raises:
    ValueError: if `bin_file_path` does not exist.
  """
  if not tf.io.gfile.exists(bin_file_path):
    raise ValueError(f'File {bin_file_path} does not exist.')
  with tf.io.gfile.GFile(bin_file_path, 'rb') as in_file:
    serialized = bytearray(in_file.read())
  return jax_export.deserialize(serialized)
def save_jax_exported_map(
    dir_path: str,
    jax_exported_map: Mapping[str, jax_export.Exported],
    *,
    vjp_order: int = 0,
) -> None:
  """Saves the orbax.export JaxExported Map to disk.

  A new directory `dir_path` is created and each entry of the map is written
  to `<dir_path>/<method_key>.<_FILE_TYPE>`.

  Args:
    dir_path: directory to create; must not exist yet.
    jax_exported_map: mapping from method key to the `Exported` to serialize.
    vjp_order: forwarded to the serialization of every entry.

  Raises:
    ValueError: if `dir_path` already exists.
  """
  if tf.io.gfile.exists(dir_path):
    raise ValueError(f'Directory {dir_path} already exists.')
  tf.io.gfile.makedirs(dir_path)
  for method_key, jax_exported in jax_exported_map.items():
    # Join with `dir_path` exactly once. Joining the already-joined path with
    # `dir_path` again yields `dir_path/dir_path/...` for relative paths,
    # which points into a directory that was never created.
    file_path = os.path.join(dir_path, f'{method_key}.{_FILE_TYPE}')
    _save_jax_exported_to_disk(jax_exported, file_path, vjp_order=vjp_order)
  logging.info('Saved JaxExported Map to %s successfully.', dir_path)
def load_jax_exported_map(dir_path: str) -> Mapping[str, jax_export.Exported]:
  """Loads the orbax.export ApplyFn JaxExported Map from disk.

  Only directory entries whose name ends with `.<_FILE_TYPE>` are loaded; the
  method key is the file name with that suffix removed.

  Args:
    dir_path: The directory path to load the ApplyFn Map.

  Returns:
    A map of method_key to JaxExported object.

  Raises:
    ValueError: if `dir_path` does not exist, or contains no matching file.
  """
  jax_exported_map = {}
  if not tf.io.gfile.exists(dir_path):
    raise ValueError(f'Directory {dir_path} does not exist.')

  suffix = f'.{_FILE_TYPE}'
  for file_name in tf.io.gfile.listdir(dir_path):
    if file_name.endswith(suffix):
      loaded = _load_jax_exported_from_disk(os.path.join(dir_path, file_name))
      jax_exported_map[file_name[: -len(suffix)]] = loaded

  if not jax_exported_map:
    raise ValueError(f'No .{_FILE_TYPE} files found in {dir_path}.')
  logging.info('Loaded ApplyFn JaxExported Map from %s successfully.', dir_path)
  return jax_exported_map
def get_key_name(key: Any) -> Union[int, str]:
  """Returns the name of a JAX Key.

  Args:
    key: one entry of a JAX key path.

  Returns:
    The index for a `SequenceKey`, the stringified key for a `DictKey`, the
    attribute name for a `GetAttrKey`, or the key of a `FlattenedIndexKey`.

  Raises:
    ValueError: if `key` is none of the supported key types.
  """
  tree_keys = jax.tree_util
  if isinstance(key, tree_keys.SequenceKey):
    return key.idx
  if isinstance(key, tree_keys.DictKey):
    return str(key.key)
  if isinstance(key, tree_keys.GetAttrKey):
    return key.name
  if isinstance(key, tree_keys.FlattenedIndexKey):
    return key.key
  raise ValueError(f'Unsupported KeyEntry: {type(key)}: "{key}"')
def get_param_names(params: PyTree) -> PyTree:
  """Gets parameter names for PyTree elements.

  Each leaf is named by joining the entries of its key path with '.', with
  every '~' replaced by '_'.

  Args:
    params: the parameter PyTree.

  Returns:
    A PyTree of names. If the tree of path-derived names does not have the
    same structure as `params`, a warning is logged and leaves are named
    `jax_param_<index>` in flattening order instead.
  """

  def _name_for(keypath: Tuple[Any, ...], _leaf: Any) -> str:
    dotted = '.'.join(str(get_key_name(entry)) for entry in keypath)
    # '~' is not allowed in variable names but are used by dm-haiku. See
    # https://github.com/google/orbax/issues/420
    return dotted.replace('~', '_')

  names = jax.tree_util.tree_map_with_path(_name_for, params)
  params_structure = jax.tree_util.tree_structure(params)
  names_structure = jax.tree_util.tree_structure(names)
  if params_structure != names_structure:
    logging.warning(
        'Cannot construct variable names for JAX parameters, which means'
        ' the parameters tree contains customized nodes not registered with'
        ' ``jax.tree_util.register_pytree_with_keys``. Variables will be'
        ' named to `jax_param_<index>` instead. PyTreeDef of params=%s.',
        jax.tree_util.tree_structure(params),
    )
    leaves, treedef = jax.tree_util.tree_flatten(params)
    fallback_names = [f'jax_param_{index}' for index, _ in enumerate(leaves)]
    names = jax.tree_util.tree_unflatten(treedef, fallback_names)
  return names
def get_variable_tree(
    var_treedef: tree_util.PyTreeDef, var_leaves: list[Any]
) -> PyTree:
  """Rebuilds the variable PyTree described by `var_treedef`.

  Args:
    var_treedef: the tree structure to rebuild.
    var_leaves: the leaves (e.g. tf.Variables or obm.Variables) to place into
      that structure.

  Returns:
    The PyTree obtained by unflattening `var_leaves` with `var_treedef`.
  """
  variable_tree = jax.tree_util.tree_unflatten(var_treedef, var_leaves)
  return variable_tree
def make_e2e_inference_fn(
    model_fn: Callable[..., Any],
    serving_config: osc.ServingConfig,
) -> Callable[..., Any]:
  """Creates a concrete end-to-end inference tf.function.

  Only the first signature key of `serving_config` is used to pick the bound
  inference function.

  Args:
    model_fn: a callable in TF context for the numeric computation.
    serving_config: a ServingConfig that defines the input signature,
      pre-processor and post-processor of the inference function.

  Returns:
    A tf.function for end-to-end inference.
  """
  bound_fns = serving_config.bind(model_fn, require_numpy=False)
  first_key = serving_config.get_signature_keys()[0]
  inference_fn = bound_fns[first_key]
  return with_default_args(inference_fn, serving_config.get_input_signature())
def get_lowering_platforms(
native_serialization_platforms: Sequence[str] | str | None,
) -> Sequence[str] | None:
"""Returns a Sequence of lowering platforms provided by the user.
Args:
native_serialization_platforms: A platform string or a sequence of platform
strings for native serialization (e.g., 'tpu', 'cpu'), or None.
Returns:
A Sequence of lowering platforms provided by the user, or None.
"""
if native_serialization_platforms is None:
return None
if isinstance(native_serialization_platforms, str):
native_serialization_platforms = [native_serialization_platforms]
allowed_lower_platforms = set(p.lower() for p in manifest_pb2.Platform.keys())
lowered_native_serialization_platforms = [
p.lower() for p in native_serialization_platforms
]
for p in lowered_native_serialization_platforms:
if p not in allowed_lower_platforms:
raise ValueError(
'native_serialization_platforms must be a sequence'
' and should be a Platform enum type.'
)
return lowered_native_serialization_platforms
def to_bfloat16(x: Any) -> Any:
  """Helper to convert leaves of a pytree to bfloat16.

  It handles `jax.ShapeDtypeStruct` and other array-like objects with a
  floating point `dtype` (any width, not just float32). Leaves without a
  `dtype` attribute (e.g. Python scalars), leaves with a non-floating dtype
  and `tf.string` leaves are returned unchanged.

  Args:
    x: The input pytree to convert.

  Returns:
    The input `x` with floating point values converted to `jnp.bfloat16`.
  """

  def _to_bfloat16_leaf(leaf: Any) -> Any:
    if isinstance(leaf, jax.ShapeDtypeStruct):
      # Abstract values: rebuild the struct with the new dtype, keeping the
      # shape and sharding.
      if jnp.issubdtype(leaf.dtype, jnp.floating):
        return jax.ShapeDtypeStruct(
            leaf.shape,
            jnp.bfloat16,
            sharding=leaf.sharding,
        )
      return leaf
    if hasattr(leaf, 'dtype'):
      if leaf.dtype == tf.string:
        return leaf
      # Compare against the abstract `jnp.floating` so float16/float64 leaves
      # are converted as well; comparing against `jnp.float32` only matches
      # float32.
      if jnp.issubdtype(leaf.dtype, jnp.floating):
        return leaf.astype(jnp.bfloat16)
    return leaf

  return jax.tree_util.tree_map(_to_bfloat16_leaf, x)