ServingConfig

ServingConfig class.

class orbax.export.serving_config.ServingConfig(signature_key, input_signature=None, tf_preprocessor=None, tf_postprocessor=None, preprocessors=(), postprocessors=(), data_processors=(), extra_trackable_resources=None, method_key=None, obm_export_options=None, preprocess_output_passthrough_enabled=False)

Configuration for constructing a serving signature for a JaxModule.

A ServingConfig is meant to be bound with a JaxModule to form an end-to-end serving signature.

Example

Create a serving configuration with pre- and post-processors:

import tensorflow as tf
from orbax.export import tf_data_processor
from orbax.export.serving_config import ServingConfig

@tf.function(input_signature=[tf.TensorSpec(
    shape=(None, 32), dtype=tf.float32
)])
def preprocessor(inputs):
  return {'normalized': inputs / 255.0}

def postprocessor(outputs):
  return {'probabilities': tf.nn.softmax(outputs)}

config = ServingConfig(
    signature_key='serving_default',
    preprocessors=[tf_data_processor.TfDataProcessor(preprocessor)],
    postprocessors=[tf_data_processor.TfDataProcessor(postprocessor)]
)

signature_key

The key of the serving signature or a sequence of keys mapping to the same serving signature.

Type:

str | collections.abc.Sequence[str]

input_signature

The input signature for tf_preprocessor (or the JaxModule method if there is no preprocessor). If not specified, this will be inferred from tf_preprocessor.input_signature.

Type:

collections.abc.Sequence[jaxtyping.PyTree] | None

tf_preprocessor

Optional pre-processing function written in TF.

Type:

collections.abc.Callable[[…], Any] | None

tf_postprocessor

Optional post-processing function written in TF.

Type:

collections.abc.Callable[[…], Any] | None

preprocessors

Optional sequence of DataProcessors to be applied before the main model function. Mutually exclusive with tf_preprocessor.

Type:

collections.abc.Sequence[orbax.export.data_processors.data_processor_base.DataProcessor]

postprocessors

Optional sequence of DataProcessors to be applied after the main model function. Mutually exclusive with tf_postprocessor.

Type:

collections.abc.Sequence[orbax.export.data_processors.data_processor_base.DataProcessor]

data_processors

Optional sequence of DataProcessors. Mutually exclusive with tf_preprocessor, tf_postprocessor, preprocessors, and postprocessors. The processors are ordered by a topological sort over their input and output keys.

Type:

collections.abc.Sequence[orbax.export.data_processors.data_processor_base.DataProcessor]
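The key-based ordering described above can be sketched with Python's standard graphlib module. This is an illustration under assumptions, not the Orbax implementation: FakeProcessor is a hypothetical stand-in that exposes only input_keys and output_keys, and the real ordering may differ in tie-breaking and error reporting.

```python
from __future__ import annotations

import dataclasses
import graphlib


@dataclasses.dataclass(frozen=True)
class FakeProcessor:
  """Hypothetical stand-in for a DataProcessor: only names its inputs/outputs."""
  name: str
  input_keys: tuple[str, ...]
  output_keys: tuple[str, ...]


def order_processors(processors: list[FakeProcessor]) -> list[FakeProcessor]:
  """Returns processors ordered so that producers run before their consumers."""
  producer_of: dict[str, str] = {}
  for proc in processors:
    if not proc.input_keys or not proc.output_keys:
      raise ValueError(f'{proc.name} must define input_keys and output_keys.')
    for key in proc.output_keys:
      producer_of[key] = proc.name

  sorter = graphlib.TopologicalSorter()
  for proc in processors:
    # A processor depends on whichever processors produce its input keys.
    # Input keys with no producer are treated as inputs of the signature.
    deps = {producer_of[k] for k in proc.input_keys if k in producer_of}
    sorter.add(proc.name, *deps)

  by_name = {proc.name: proc for proc in processors}
  # static_order() raises graphlib.CycleError if the keys form a cycle.
  return [by_name[name] for name in sorter.static_order()]


processors = [
    FakeProcessor('pool', ('embeddings',), ('pooled',)),
    FakeProcessor('embed', ('tokens',), ('embeddings',)),
    FakeProcessor('tokenize', ('text',), ('tokens',)),
]
print([p.name for p in order_processors(processors)])
# ['tokenize', 'embed', 'pool']
```

Because each processor in this toy chain depends on the previous one, only one valid order exists, regardless of the order in which the processors were listed.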

extra_trackable_resources

A nested structure of trackable resources used in TF processors.

Type:

Any

method_key

The key of the JAX method of the JaxModule to be bound.

Type:

str | None

obm_export_options

Options passed to the Orbax Model export.

Type:

orbax.export.obm_configs.ObmExportOptions | None

preprocess_output_passthrough_enabled

When True, allows a portion of the preprocessor’s outputs to be directly passed to the tf_postprocessor, bypassing the JAX function. Requires the preprocessor to return a tuple of two elements: (jax_inputs, postprocessor_inputs_extra).

Type:

bool
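A minimal sketch of the data flow when passthrough is enabled, written in plain Python so it runs without TF or JAX. run_with_passthrough and the toy preprocess, model_fn and postprocess functions are hypothetical, and passing the extra inputs to the postprocessor as a second positional argument is an assumption rather than documented behavior.

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Any


def run_with_passthrough(
    preprocess: Callable[[Any], Any],
    model_fn: Callable[[Any], Any],
    postprocess: Callable[[Any, Any], Any],
    inputs: Any,
) -> Any:
  """Hypothetical end-to-end flow with output passthrough enabled."""
  preprocessed = preprocess(inputs)
  # With passthrough enabled, the preprocessor must return exactly two parts.
  if not (isinstance(preprocessed, tuple) and len(preprocessed) == 2):
    raise ValueError(
        'Expected the preprocessor to return '
        '(jax_inputs, postprocessor_inputs_extra).')
  jax_inputs, extra = preprocessed
  # Only jax_inputs reach the model; extra bypasses it.
  model_outputs = model_fn(jax_inputs)
  # Assumption: the postprocessor receives the model outputs and the extras.
  return postprocess(model_outputs, extra)


def preprocess(batch):
  return {'x': [v * 2 for v in batch['values']]}, {'ids': batch['ids']}


def model_fn(jax_inputs):
  return [v + 1 for v in jax_inputs['x']]


def postprocess(model_outputs, extra):
  # Re-attach the ids that never went through model_fn.
  return dict(zip(extra['ids'], model_outputs))


print(run_with_passthrough(preprocess, model_fn, postprocess,
                           {'values': [1, 2], 'ids': ['a', 'b']}))
# {'a': 3, 'b': 5}
```

The ids are typical passthrough data: the postprocessor needs them, but the model function has no use for them.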

__post_init__()

Post-initialization checks for ServingConfig.

Raises:

ValueError – If any of the following conditions are met:
  • obm_kwargs and obm_export_options are both set.
  • signature_key is not set.
  • data_processors is set along with tf_preprocessor, preprocessors, tf_postprocessor, or postprocessors.
  • a processor in data_processors does not have input_keys or output_keys.
  • tf_preprocessor and preprocessors are both set.
  • tf_postprocessor and postprocessors are both set.
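The conditions above can be mirrored in a small dataclass. ConfigChecksSketch is hypothetical and only illustrates the documented rules; it omits the obm_kwargs check and the remaining ServingConfig fields, and its error messages are invented.

```python
from __future__ import annotations

import dataclasses
from collections.abc import Callable, Sequence
from typing import Any


@dataclasses.dataclass
class ConfigChecksSketch:
  """Hypothetical, trimmed-down mirror of the documented __post_init__ checks."""
  signature_key: str | Sequence[str] = ''
  tf_preprocessor: Callable[..., Any] | None = None
  tf_postprocessor: Callable[..., Any] | None = None
  preprocessors: Sequence[Any] = ()
  postprocessors: Sequence[Any] = ()
  data_processors: Sequence[Any] = ()

  def __post_init__(self):
    if not self.signature_key:
      raise ValueError('signature_key must be set.')
    uses_other_processors = (
        self.tf_preprocessor is not None
        or self.tf_postprocessor is not None
        or bool(self.preprocessors)
        or bool(self.postprocessors))
    if self.data_processors and uses_other_processors:
      raise ValueError('data_processors cannot be combined with other processors.')
    for proc in self.data_processors:
      # Every data processor needs keys so that it can be topologically sorted.
      if not getattr(proc, 'input_keys', None) or not getattr(proc, 'output_keys', None):
        raise ValueError('Each data processor needs input_keys and output_keys.')
    if self.tf_preprocessor is not None and self.preprocessors:
      raise ValueError('tf_preprocessor and preprocessors are mutually exclusive.')
    if self.tf_postprocessor is not None and self.postprocessors:
      raise ValueError('tf_postprocessor and postprocessors are mutually exclusive.')


try:
  ConfigChecksSketch(signature_key='serving_default',
                     tf_preprocessor=lambda x: x,
                     preprocessors=[object()])
except ValueError as err:
  print(err)  # tf_preprocessor and preprocessors are mutually exclusive.
```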

get_input_signature(required=True)

Gets the input signature, either the explicit one or the one inferred from tf_preprocessor.

Return type:

Any
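A sketch of the documented fallback, assuming that a missing signature is an error only when required is True (the behavior for required=False is not stated above). get_input_signature_sketch is a hypothetical helper; a types.SimpleNamespace stands in for a tf.function, which exposes its input_signature as an attribute.

```python
from __future__ import annotations

import types
from typing import Any


def get_input_signature_sketch(input_signature: Any,
                               tf_preprocessor: Any,
                               required: bool = True) -> Any:
  """Hypothetical: prefer the explicit signature, else the preprocessor's."""
  if input_signature is not None:
    return input_signature
  # A tf.function created with input_signature=... exposes it as an attribute.
  inferred = getattr(tf_preprocessor, 'input_signature', None)
  if inferred is None and required:
    # Assumption: a missing signature is an error only when required=True.
    raise ValueError('No input_signature given and none found on tf_preprocessor.')
  return inferred


# A SimpleNamespace stands in for a tf.function so the sketch runs without TF.
fake_preprocessor = types.SimpleNamespace(input_signature=('spec_a',))
print(get_input_signature_sketch(None, fake_preprocessor))  # ('spec_a',)
print(get_input_signature_sketch(None, None, required=False))  # None
```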

get_infer_step(infer_step_fns)

Finds the right inference fn to be bound with the ServingConfig.

Parameters:

infer_step_fns (Union[Callable[…, Any], Mapping[str, Callable[…, Any]]]) – the method_key/infer_step dict. Usually the user can pass JaxModule.methods here.

Returns:

The corresponding JAX method for this ServingConfig.

Return type:

method
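The lookup can be sketched as follows. get_infer_step_sketch is hypothetical, and raising ValueError for an unknown method_key is an assumption; only the rule that a mapping is indexed by method_key while a single function serves every key comes from the descriptions on this page.

```python
from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import Any


def get_infer_step_sketch(
    method_key: str | None,
    infer_step_fns: Callable[..., Any] | Mapping[str, Callable[..., Any]],
) -> Callable[..., Any]:
  """Hypothetical lookup of the inference function for one config."""
  if not isinstance(infer_step_fns, Mapping):
    # A single function serves every method_key.
    return infer_step_fns
  if method_key not in infer_step_fns:
    # Assumption: the real method may report this case differently.
    raise ValueError(f'No inference function for method_key={method_key!r}.')
  return infer_step_fns[method_key]


methods = {'classify': lambda x: x + 1, 'embed': lambda x: x * 2}
print(get_infer_step_sketch('embed', methods)(5))  # 10
print(get_infer_step_sketch('anything', lambda x: x - 1)(5))  # 4
```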

bind(infer_step_fns, require_numpy=True)

Returns an end-to-end inference function by binding an inference step function.

Parameters:
  • infer_step_fns (Union[Callable[PyTree, PyTree], Mapping[str, Callable[PyTree, PyTree]]]) – An inference step function or a mapping of method key to inference step function. If it is a mapping, the function whose key matches the method_key of this ServingConfig is used. If only a single inference step function is provided, every method_key uses that same function.

  • require_numpy (bool) – Whether to convert TF tensors to NumPy arrays after TF preprocessing and TF postprocessing. As a rule of thumb, set it to True if infer_step is a JAX function and to False if infer_step is a TF function.

Returns:

The mapping of serving signature to the inference function bound with the pre- and post-processors of this ServingConfig.

Return type:

func_map
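A plain-Python sketch of what binding produces, based on the parameter descriptions above. bind_sketch and its preprocessor, postprocessor and to_numpy arguments are hypothetical simplifications: the real method works with the processors stored on the ServingConfig and converts TF tensors to NumPy, whereas here to_numpy=list merely stands in for that conversion.

```python
from __future__ import annotations

from collections.abc import Callable, Mapping, Sequence
from typing import Any


def bind_sketch(
    signature_key: str | Sequence[str],
    method_key: str | None,
    infer_step_fns: Callable[..., Any] | Mapping[str, Callable[..., Any]],
    preprocessor: Callable[[Any], Any] | None = None,
    postprocessor: Callable[[Any], Any] | None = None,
    require_numpy: bool = True,
    to_numpy: Callable[[Any], Any] = lambda x: x,
) -> dict[str, Callable[[Any], Any]]:
  """Hypothetical composition of preprocessing, inference and postprocessing."""
  # Pick the inference step: a mapping is indexed by method_key.
  infer_step = (infer_step_fns[method_key]
                if isinstance(infer_step_fns, Mapping) else infer_step_fns)
  # The conversion hook only runs when require_numpy is True.
  convert = to_numpy if require_numpy else (lambda x: x)

  def e2e(inputs: Any) -> Any:
    if preprocessor is not None:
      inputs = convert(preprocessor(inputs))
    outputs = infer_step(inputs)
    if postprocessor is not None:
      outputs = convert(postprocessor(outputs))
    return outputs

  keys = [signature_key] if isinstance(signature_key, str) else list(signature_key)
  # Every signature key maps to the same end-to-end function.
  return {key: e2e for key in keys}


# to_numpy=list stands in for a tensor-to-NumPy conversion in this toy example.
func_map = bind_sketch(
    ['serving_default', 'predict'], None, lambda xs: [x * 10 for x in xs],
    preprocessor=lambda xs: tuple(x + 1 for x in xs), to_numpy=list)
print(func_map['serving_default']((1, 2)))  # [20, 30]
```

Both keys point at the same composed function, which mirrors how a sequence of signature keys shares one serving signature.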

__eq__(other)

Return self==value.

__hash__ = None
__init__(signature_key, input_signature=None, tf_preprocessor=None, tf_postprocessor=None, preprocessors=(), postprocessors=(), data_processors=(), extra_trackable_resources=None, method_key=None, obm_export_options=None, preprocess_output_passthrough_enabled=False)