# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ServingConfig class."""
from collections.abc import Callable, Mapping, Sequence
import dataclasses
from typing import Any, Optional, Text, Union
from absl import logging
import jax
import jaxtyping
from orbax.export import obm_configs
from orbax.export.data_processors import data_processor_base
import tensorflow as tf
@dataclasses.dataclass
class ServingConfig:
  """Configuration for constructing a serving signature for a JaxModule.

  A ServingConfig is to be bound with a JaxModule to form an end-to-end serving
  signature.

  Example:
    Create a serving configuration with pre- and post-processors::

      import tensorflow as tf
      from orbax.export import tf_data_processor
      from orbax.export.serving_config import ServingConfig

      @tf.function(input_signature=[tf.TensorSpec(
          shape=(None, 32), dtype=tf.float32
      )])
      def preprocessor(inputs):
        return {'normalized': inputs / 255.0}

      def postprocessor(outputs):
        return {'probabilities': tf.nn.softmax(outputs)}

      config = ServingConfig(
          signature_key='serving_default',
          preprocessors=[tf_data_processor.TfDataProcessor(preprocessor)],
          postprocessors=[tf_data_processor.TfDataProcessor(postprocessor)]
      )

  Attributes:
    signature_key: The key of the serving signature or a sequence of keys
      mapping to the same serving signature.
    input_signature: The input signature for `tf_preprocessor` (or the JaxModule
      method if there is no preprocessor). If not specified, this will be
      inferred from `tf_preprocessor.input_signature`.
    tf_preprocessor: Optional pre-processing function written in TF.
    tf_postprocessor: Optional post-processing function written in TF.
    preprocessors: Optional sequence of `DataProcessor`s to be applied before
      the main model function. Mutually exclusive with `tf_preprocessor`.
    postprocessors: Optional sequence of `DataProcessor`s to be applied after
      the main model function. Mutually exclusive with `tf_postprocessor`.
    data_processors: Optional sequence of `DataProcessor`s. Mutually exclusive
      with other processors. Ordered based on input/output keys via topological
      sort.
    extra_trackable_resources: A nested structure of trackable resources used in
      TF processors.
    method_key: The key of the JAX method of the `JaxModule` to be bound.
    obm_export_options: Options passed to the Orbax Model export.
    preprocess_output_passthrough_enabled: When True, allows a portion of the
      preprocessor's outputs to be directly passed to the tf_postprocessor,
      bypassing the JAX function. Requires the preprocessor to return a tuple of
      two elements: (jax_inputs, postprocessor_inputs_extra).
  """

  # The key of the serving signature or a sequence of keys mapping to the same
  # serving signature.
  signature_key: Union[str, Sequence[str]]
  # The input signature for `tf_preprocessor` (or the JaxModule method if there
  # is no `tf_preprocessor`). If not specified, this will be inferred from
  # `tf_preprocessor`, in which case `tf_preprocessor` must be a tf.function
  # with `input_signature` annotation. See
  # https://www.tensorflow.org/api_docs/python/tf/function#input_signatures.
  input_signature: Optional[Sequence[jaxtyping.PyTree]] = None
  # Optional pre-processing function written in TF.
  tf_preprocessor: Optional[Callable[..., Any]] = None
  # Optional post-processing function written in TF.
  tf_postprocessor: Optional[Callable[..., Any]] = None
  # Optional sequence of `DataProcessor`s to be applied before the main model
  # function.
  preprocessors: Sequence[data_processor_base.DataProcessor] = ()
  # Optional sequence of `DataProcessor`s to be applied after the main model
  # function.
  postprocessors: Sequence[data_processor_base.DataProcessor] = ()
  # Optional sequence of `DataProcessor`s to be applied. `DataProcessor` is a
  # new abstraction for constructing pipelines for Orbax Model export. This
  # field is mutually exclusive with `tf_preprocessor`, `preprocessors`,
  # `tf_postprocessor`, and `postprocessors`. If this field is used, the
  # `DataProcessor`s and the model function will be ordered based on their
  # input and output keys using topological sorting.
  data_processors: Sequence[data_processor_base.DataProcessor] = ()
  # A nested structure of tf.saved_model.experimental.TrackableResource that are
  # used in `tf_preprocessor` and/or `tf_postprocessor`. If a TrackableResource
  # is an attribute of the `tf_preprocessor` (or `tf_postprocessor`), and the
  # `tf_preprocessor` (or `tf_postprocessor`) is a tf.Module, the
  # TrackableResource does not need to be in `extra_trackable_resources`.
  extra_trackable_resources: Any = None
  # Specify the key of the JAX method of the `JaxModule` to be bound
  # with this serving config. If unspecified, the `JaxModule` should have
  # exactly one method which will be used.
  method_key: Optional[str] = None
  # Options passed to the Orbax Model export.
  obm_export_options: obm_configs.ObmExportOptions | None = None
  # When set to true, it allows a portion of the preprocessor's outputs to be
  # directly passed to the tf_postprocessor, bypassing the JAX function.
  #
  # The primary use case of this option is to handle preprocessing outputs
  # containing string tensors that cannot be passed to the JAX function, but
  # are required by the postprocessor.
  #
  # Pre-requisites:
  # This option requires `tf_preprocessor` to be set, and the preprocessor
  # outputs and postprocessor inputs to be structured in specific ways:
  # - The preprocessor must return two outputs, where the first will be passed
  #   as the input to the jax function and the second will be passed as the
  #   second input to the postprocessor.
  # - The JAX function must take one input and return one output. The output
  #   will be passed as the first input to the postprocessor.
  # - The postprocessor must take two inputs. The first is the output of the
  #   JAX function and the second is the second element of the preprocessor
  #   outputs.
  #
  # For example:
  #
  # def tf_preprocessor(x):
  #   return {'pre_out_to_jax': x}, {'pre_out_to_post': x}
  #
  # def jax_func(inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
  #   return {'jax_out_to_post': inputs['pre_out_to_jax']}
  #
  # def tf_postprocessor(
  #     inputs: Mapping[str, tf.Tensor],
  #     inputs_extra: Mapping[str, tf.Tensor],
  # ) -> Mapping[str, tf.Tensor]:
  #   return {
  #       'post_out_from_jax': inputs['jax_out_to_post'],
  #       'post_out_from_pre': inputs_extra['pre_out_to_post'],
  #   }
  preprocess_output_passthrough_enabled: bool = False

  def __post_init__(self):
    """Post-initialization checks for ServingConfig.

    Raises:
      ValueError: If any of the following conditions are met:
        - `signature_key` is not set.
        - `data_processors` is set along with `tf_preprocessor`,
          `preprocessors`, `tf_postprocessor`, or `postprocessors`.
        - a processor in `data_processors` does not have `input_keys` or
          `output_keys`.
        - `tf_preprocessor` and `preprocessors` are both set.
        - `tf_postprocessor` and `postprocessors` are both set.
    """
    if not self.signature_key:
      raise ValueError('`signature_key` must be set.')
    if self.data_processors:
      if (
          self.tf_preprocessor
          or self.preprocessors
          or self.tf_postprocessor
          or self.postprocessors
      ):
        raise ValueError(
            '`data_processors` cannot be set at the same time as'
            ' `tf_preprocessor`, `preprocessors`, `tf_postprocessor` or'
            ' `postprocessors`.'
        )
      for processor in self.data_processors:
        if not processor.input_keys:
          raise ValueError(
              f'Processor {processor.name} in `data_processors` must have'
              ' `input_keys`.'
          )
        if not processor.output_keys:
          raise ValueError(
              f'Processor {processor.name} in `data_processors` must have'
              ' `output_keys`.'
          )
    else:
      if self.tf_preprocessor and self.preprocessors:
        raise ValueError(
            '`tf_preprocessor` and `preprocessors` cannot be set at the same'
            ' time.'
        )
      if self.tf_postprocessor and self.postprocessors:
        raise ValueError(
            '`tf_postprocessor` and `postprocessors` cannot be set at the same'
            ' time.'
        )

  def get_signature_keys(self) -> Sequence[str]:
    """Returns `signature_key` normalized to a sequence of keys.

    A single string key is wrapped in a one-element list; a sequence of keys is
    returned as-is.
    """
    if isinstance(self.signature_key, str):
      return [self.signature_key]
    else:
      return self.signature_key

  def get_infer_step(
      self,
      infer_step_fns: Union[
          Callable[..., Any], Mapping[str, Callable[..., Any]]
      ],
  ) -> Callable[..., Any]:
    """Finds the right inference fn to be bound with the ServingConfig.

    Args:
      infer_step_fns: Either a single callable, which is returned unchanged, or
        a method_key -> infer_step mapping. Usually the user can pass
        `JaxModule.methods` here.

    Returns:
      method: the corresponding jax method of current ServingConfig.

    Raises:
      ValueError: If `method_key` is unset and the mapping does not contain
        exactly one function, or if `method_key` is set but is not a key of the
        mapping.
    """
    if callable(infer_step_fns):
      return infer_step_fns
    method_key = self.method_key
    if method_key is None:
      if not infer_step_fns:
        # Reported separately: the "more than one" message below would be
        # wrong for an empty mapping.
        raise ValueError(
            '`method_key` is not specified in ServingConfig '
            f'"{self.signature_key}" and the infer_step_fns is empty; there is'
            ' no method to bind.'
        )
      if len(infer_step_fns) > 1:
        raise ValueError(
            '`method_key` is not specified in ServingConfig '
            f'"{self.signature_key}" and the infer_step_fns has more than one '
            f'methods: {list(infer_step_fns)}. Please specify '
            '`method_key` explicitly.'
        )
      (method,) = infer_step_fns.values()  # Exactly one value at this point.
      return method
    if method_key not in infer_step_fns:
      raise ValueError(
          f'Method key "{method_key}" is not found in the infer_step_fns. '
          f'Available method keys: {list(infer_step_fns.keys())}.'
      )
    return infer_step_fns[method_key]

  def bind(
      self,
      infer_step_fns: Union[
          Callable[[jaxtyping.PyTree], jaxtyping.PyTree],
          Mapping[str, Callable[[jaxtyping.PyTree], jaxtyping.PyTree]],
      ],
      require_numpy: bool = True,
  ) -> Mapping[str, Callable[..., Mapping[Text, Any]]]:
    """Returns an e2e inference function by binding an inference step function.

    Args:
      infer_step_fns: An inference step function or a mapping of method key to
        inference step function. If it is a mapping, the function whose key
        matches the `method_key` of this ServingConfig will be used. If users
        only provide one infer_step function, all `method_key`s use the same
        infer_step function.
      require_numpy: Whether to convert TF tensors to numpy after the TF
        preprocessor and the TF postprocessor. As a rule of thumb, if
        infer_step is a JAX function, set it to True; if infer_step is a TF
        function, set it to False.

    Returns:
      func_map: The mapping of serving signature key to the inference function
        bound with the pre- and post-processors of this ServingConfig. Every
        signature key maps to the same function object.

    Raises:
      ValueError: If `preprocess_output_passthrough_enabled` is set together
        with `tf_postprocessor` but without `tf_preprocessor`, or if no
        suitable infer step can be selected (see `get_infer_step`).
    """
    if (
        self.preprocess_output_passthrough_enabled
        and self.tf_postprocessor
        and not self.tf_preprocessor
    ):
      # Without a tf_preprocessor there is nothing to produce the extra
      # postprocessor inputs; fail here instead of with an UnboundLocalError
      # on every inference call.
      raise ValueError(
          '`preprocess_output_passthrough_enabled` requires `tf_preprocessor`'
          ' to be set, so that it can provide the extra inputs of'
          ' `tf_postprocessor`.'
      )

    def make_inference_fn(infer_step):
      """Binds the preprocessor, the method and the postprocessor together."""
      preprocessor = tf.function(self.tf_preprocessor or (lambda *a: a))
      postprocessor = tf.function(self.tf_postprocessor or (lambda *a: a))

      def inference_fn(*inputs):
        # Second preprocessor output; only assigned in passthrough mode.
        postprocessor_inputs_extra = None
        if self.tf_preprocessor:
          preprocessor_outputs = preprocessor(*inputs)
          if require_numpy:
            preprocessor_outputs = jax.tree_util.tree_map(
                lambda x: x.numpy(), preprocessor_outputs
            )
          if self.preprocess_output_passthrough_enabled:
            if (
                not isinstance(preprocessor_outputs, tuple)
                or len(preprocessor_outputs) != 2
            ):
              # `len()` fails on scalars (e.g. a 0-d tensor/array); don't let
              # that TypeError mask the intended ValueError below.
              try:
                outputs_length = len(preprocessor_outputs)
              except (TypeError, ValueError):
                outputs_length = 'N/A'
              raise ValueError(
                  '`preprocess_output_passthrough_enabled` is enabled,'
                  ' requiring the preprocessor output to be a tuple of two'
                  f' elements, but got {preprocessor_outputs} with'
                  f' type={type(preprocessor_outputs)} and'
                  f' length={outputs_length}.'
              )
            jax_inputs, postprocessor_inputs_extra = preprocessor_outputs
          else:
            jax_inputs = preprocessor_outputs
        else:
          jax_inputs = inputs
          if len(jax_inputs) != 1:
            raise ValueError(
                'JaxModule only takes single arg as the input, but got'
                f' len(inputs)={len(inputs)} from the preprocessor or input'
                ' signature. Please pack all inputs into one PyTree by'
                ' modifying the `input_signature` (if no `tf_preprocessor`) or'
                ' the ServingConfig.tf_preprocessor.'
            )
          # Currently Jax Module only takes 1 input.
          jax_inputs = jax_inputs[0]
        jax_outputs = infer_step(jax_inputs)
        if logging.vlog_is_on(3) and require_numpy:
          # Debug aid: dump the lowered MLIR of the infer step.
          if hasattr(infer_step, 'lower'):
            lower = infer_step.lower
          else:
            lower = jax.jit(infer_step).lower
          mlir_module_text = lower(jax_inputs).as_text()
          logging.info(
              'Jax function infer_step mlir module: = %s', mlir_module_text
          )
        if not self.tf_postprocessor:
          return jax_outputs
        if self.preprocess_output_passthrough_enabled:
          postprocessor_outputs = postprocessor(
              jax_outputs, postprocessor_inputs_extra
          )
        else:
          postprocessor_outputs = postprocessor(jax_outputs)
        if require_numpy:
          postprocessor_outputs = jax.tree_util.tree_map(
              lambda x: x.numpy(), postprocessor_outputs
          )
        return postprocessor_outputs

      return inference_fn

    infer_fn_with_processors = make_inference_fn(
        self.get_infer_step(infer_step_fns)
    )
    return {key: infer_fn_with_processors for key in self.get_signature_keys()}