# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ServingConfig class."""
import dataclasses
from typing import Any, Callable, Mapping, Optional, Sequence, Text, Union
from absl import logging
import jax
from orbax.export import utils as orbax_export_utils
import tensorflow as tf
PyTree = orbax_export_utils.PyTree
[docs]@dataclasses.dataclass
class ServingConfig:
"""Configuration for constructing a serving signature for a JaxModule.
A ServingConfig is to be bound with a JaxModule to form an end-to-end serving
signature.
"""
# The key of the serving signature or a sequence of keys mapping to the same
# serving signature.
signature_key: Union[str, Sequence[str]]
# The input signature for `tf_preprocessor` (or the JaxModule method if there
# is no `tf_preprocessor`). If not specified, this will be infered from
# `tf_preprocessor`, in which case `tf_preprocessor` must be a tf.function
# with `input_signature` annotation. See
# https://www.tensorflow.org/api_docs/python/tf/function#input_signatures.
input_signature: Optional[Sequence[PyTree]] = None
# Optional pre-precessing function written in TF.
tf_preprocessor: Optional[Callable[..., Any]] = None
# Optional post-processing function written in TF.
tf_postprocessor: Optional[Callable[..., Any]] = None
# A nested structure of tf.saved_model.experimental.TrackableResource that are
# used in `tf_preprocessor` and/or `tf_postprocessor`. If a TrackableResource
# an attritute of the `tf_preprocessor` (or `tf_postprocessor`), and the
# `tf_preprocessor` (or `tf_postprocessor`) is a tf.module,
# the TrackableResource does not need to be in `extra_trackable_resources`.
extra_trackable_resources: Any = None
# Specify the key of the JAX method of the `JaxModule` to be bound
# with this serving config. If unspecified, the `JaxModule` should have
# exactly one method which will be used.
method_key: Optional[str] = None
def get_signature_keys(self) -> Sequence[str]:
if isinstance(self.signature_key, str):
return [self.signature_key]
else:
return self.signature_key
[docs] def get_infer_step(
self,
infer_step_fns: Union[
Callable[..., Any], Mapping[str, Callable[..., Any]]
],
) -> Callable[..., Any]:
"""Finds the right inference fn to be bound with the ServingConfig.
Args:
infer_step_fns: the method_key/infer_step dict. Ususally the user can pass
`JaxModule.methods` here.
Returns:
method: the corresponding jax method of current ServingConfig.
"""
if callable(infer_step_fns):
return infer_step_fns
method_key = self.method_key
if method_key is None:
if len(infer_step_fns) != 1:
raise ValueError(
'`method_key` is not specified in ServingConfig '
f'"{self.signature_key}" and the infer_step_fns has more than one '
f' methods: {list(infer_step_fns)}. Please specify '
'`method_key` explicitly.'
)
(method,) = infer_step_fns.values() # this is a tuple-destructuring
return method
else:
if method_key not in infer_step_fns:
raise ValueError(
f'Method key "{method_key}" is not found in the infer_step_fns. '
f'Available method keys: {list(infer_step_fns.keys())}.'
)
return infer_step_fns[method_key]
[docs] def bind(
self,
infer_step_fns: Union[
Callable[[PyTree], PyTree], Mapping[str, Callable[[PyTree], PyTree]]
],
require_numpy: bool = True,
) -> Mapping[str, Callable[..., Mapping[Text, Any]]]:
"""Returns an e2e inference function by binding a inference step function.
Args:
infer_step_fns: An inference step function of a mapping of method key to
inference step function. If it is a mapping, the function whose key
matches the `method_key` of this ServingConfig will be used. If Users
only provide infer_step function, all `method_key`s use same infer_step
function.
require_numpy: Decide convert tf tensor to numpy after tf preprocess and
tf postprocess. As a rule of thumb, if infer_step is jax function, set
it to True. if infer_step if tf function, set it to False.
Return:
func_map: The mapping of serving signature to the inference function
bound with the pre- and post-processors of this ServingConfig.
"""
def make_inference_fn(infer_step):
"""Bind the preprocess, method and postproess together."""
preprocessor = tf.function(self.tf_preprocessor or (lambda *a: a))
postprocessor = tf.function(self.tf_postprocessor or (lambda *a: a))
def inference_fn(*inputs):
if self.tf_preprocessor:
preprocessed_inputs = preprocessor(*inputs)
if require_numpy:
preprocessed_inputs = jax.tree_util.tree_map(
lambda x: x.numpy(), preprocessed_inputs
)
else:
preprocessed_inputs = inputs
if len(preprocessed_inputs) != 1:
raise ValueError(
'JaxModule only takes single arg as the input, but got'
f' len(inputs)={len(inputs)} from the preprocessor or input'
' signature. Please pack all inputs into one PyTree by'
' modifying the `input_signature` (if no `tf_preprocessor`) or'
' the ServingConfig.tf_preprocessor.'
)
preprocessed_inputs = preprocessed_inputs[0]
# Currently Jax Module only takes 1 input
outputs = infer_step(preprocessed_inputs)
if logging.vlog_is_on(3) and require_numpy:
if hasattr(infer_step, 'lower'):
lower = infer_step.lower
else:
lower = jax.jit(infer_step).lower
mlir_module_text = lower(
preprocessed_inputs,
).as_text()
logging.info(
'Jax function infer_step mlir module: = %s', mlir_module_text
)
if self.tf_postprocessor:
outputs = postprocessor(outputs)
if require_numpy:
outputs = jax.tree_util.tree_map(lambda x: x.numpy(), outputs)
return outputs
return inference_fn
func_map = {}
infer_fn_with_processors = make_inference_fn(
self.get_infer_step(infer_step_fns)
)
for key in self.get_signature_keys():
func_map[key] = infer_fn_with_processors
return func_map