ServingConfig#
ServingConfig class.
- class orbax.export.serving_config.ServingConfig(signature_key, input_signature=None, tf_preprocessor=None, tf_postprocessor=None, preprocessors=(), postprocessors=(), data_processors=(), extra_trackable_resources=None, method_key=None, obm_export_options=None, preprocess_output_passthrough_enabled=False)#
Configuration for constructing a serving signature for a JaxModule.
A ServingConfig is to be bound with a JaxModule to form an end-to-end serving signature.
Example
Create a serving configuration with pre- and post-processors:
import tensorflow as tf
from orbax.export import tf_data_processor
from orbax.export.serving_config import ServingConfig

@tf.function(input_signature=[tf.TensorSpec(
    shape=(None, 32), dtype=tf.float32
)])
def preprocessor(inputs):
  return {'normalized': inputs / 255.0}

def postprocessor(outputs):
  return {'probabilities': tf.nn.softmax(outputs)}

config = ServingConfig(
    signature_key='serving_default',
    preprocessors=[tf_data_processor.TfDataProcessor(preprocessor)],
    postprocessors=[tf_data_processor.TfDataProcessor(postprocessor)]
)
- signature_key#
The key of the serving signature or a sequence of keys mapping to the same serving signature.
- Type:
str | collections.abc.Sequence[str]
- input_signature#
The input signature for tf_preprocessor (or the JaxModule method if there is no preprocessor). If not specified, this will be inferred from tf_preprocessor.input_signature.
- Type:
collections.abc.Sequence[jaxtyping.PyTree] | None
- tf_preprocessor#
Optional pre-processing function written in TF.
- Type:
collections.abc.Callable[[…], Any] | None
- tf_postprocessor#
Optional post-processing function written in TF.
- Type:
collections.abc.Callable[[…], Any] | None
- preprocessors#
Optional sequence of DataProcessor objects to be applied before the main model function. Mutually exclusive with tf_preprocessor.
- Type:
collections.abc.Sequence[orbax.export.data_processors.data_processor_base.DataProcessor]
- postprocessors#
Optional sequence of DataProcessor objects to be applied after the main model function. Mutually exclusive with tf_postprocessor.
- Type:
collections.abc.Sequence[orbax.export.data_processors.data_processor_base.DataProcessor]
- data_processors#
Optional sequence of DataProcessor objects. Mutually exclusive with all other processor arguments. The processors are ordered by a topological sort over their input and output keys.
- Type:
collections.abc.Sequence[orbax.export.data_processors.data_processor_base.DataProcessor]
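The key-based ordering described for data_processors can be illustrated with a small standalone sketch. It does not use Orbax: FakeProcessor and order_by_keys are hypothetical names introduced here, and the only assumption taken from the description is that a processor must run after whichever processor produces its input keys.

```python
from graphlib import TopologicalSorter


class FakeProcessor:
    """Hypothetical stand-in for a DataProcessor; only the key sets matter."""

    def __init__(self, name, input_keys, output_keys):
        self.name = name
        self.input_keys = tuple(input_keys)
        self.output_keys = tuple(output_keys)


def order_by_keys(processors):
    """Orders processors so each one runs after the producers of its inputs."""
    # Record which processor produces each output key.
    producer = {}
    for p in processors:
        for key in p.output_keys:
            producer[key] = p.name
    # Map each processor to the set of processors it depends on.
    graph = {
        p.name: {producer[k] for k in p.input_keys if k in producer}
        for p in processors
    }
    by_name = {p.name: p for p in processors}
    return [by_name[n] for n in TopologicalSorter(graph).static_order()]


# The processors are deliberately listed out of order.
procs = [
    FakeProcessor("post", ["logits"], ["probs"]),
    FakeProcessor("model", ["features"], ["logits"]),
    FakeProcessor("pre", ["raw"], ["features"]),
]
ordered_names = [p.name for p in order_by_keys(procs)]
```

In this sketch the list order passed by the caller does not matter; only the declared keys determine the execution order.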
- extra_trackable_resources#
A nested structure of trackable resources used in TF processors.
- Type:
Any
- method_key#
The key of the JAX method of the JaxModule to be bound.
- Type:
str | None
- obm_export_options#
Options passed to the Orbax Model export.
- Type:
orbax.export.obm_configs.ObmExportOptions | None
- preprocess_output_passthrough_enabled#
When True, allows a portion of the preprocessor’s outputs to be directly passed to the tf_postprocessor, bypassing the JAX function. Requires the preprocessor to return a tuple of two elements: (jax_inputs, postprocessor_inputs_extra).
- Type:
bool
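The passthrough data flow can be sketched in plain Python, without TF or JAX. All function names below are hypothetical, and the way the postprocessor receives the extra values (as a second positional argument) is an assumption made for illustration; only the two-element tuple returned by the preprocessor comes from the description above.

```python
def preprocessor(raw):
    # Returns (jax_inputs, postprocessor_inputs_extra), as the option requires.
    jax_inputs = [x / 255.0 for x in raw["pixels"]]
    extra = {"request_id": raw["request_id"]}
    return jax_inputs, extra


def infer_step(jax_inputs):
    # Stand-in for the JAX function; it never sees the extra values.
    return [2.0 * x for x in jax_inputs]


def postprocessor(outputs, extra):
    # Assumed calling convention: model outputs first, passthrough values second.
    return {"scores": outputs, "request_id": extra["request_id"]}


def serve(raw):
    jax_inputs, extra = preprocessor(raw)
    return postprocessor(infer_step(jax_inputs), extra)


result = serve({"pixels": [0.0, 255.0], "request_id": "abc"})
```

The point of the option is visible in serve: request_id reaches the postprocessor even though the model function only handles numeric inputs.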
- __post_init__()#
Post-initialization checks for ServingConfig.
- Raises:
ValueError – If any of the following conditions is met:
  - obm_kwargs and obm_export_options are both set.
  - signature_key is not set.
  - data_processors is set along with tf_preprocessor, preprocessors, tf_postprocessor, or postprocessors.
  - a processor in data_processors does not have input_keys or output_keys.
  - tf_preprocessor and preprocessors are both set.
  - tf_postprocessor and postprocessors are both set.
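The mutual-exclusion rules above can be mirrored in a small standalone dataclass. ConfigSketch and try_build are hypothetical names, the error messages are invented, and the obm_kwargs and input_keys/output_keys checks are omitted; this is a sketch of the documented conditions, not the Orbax implementation.

```python
import dataclasses
from typing import Any, Callable, Optional, Sequence


@dataclasses.dataclass
class ConfigSketch:
    """Hypothetical stand-in that mirrors a few of the documented checks."""

    signature_key: Any
    tf_preprocessor: Optional[Callable[..., Any]] = None
    tf_postprocessor: Optional[Callable[..., Any]] = None
    preprocessors: Sequence[Any] = ()
    postprocessors: Sequence[Any] = ()
    data_processors: Sequence[Any] = ()

    def __post_init__(self):
        if not self.signature_key:
            raise ValueError("signature_key must be set.")
        others = (self.tf_preprocessor, self.tf_postprocessor,
                  self.preprocessors, self.postprocessors)
        if self.data_processors and any(others):
            raise ValueError("data_processors excludes the other processors.")
        if self.tf_preprocessor and self.preprocessors:
            raise ValueError("Set only one of tf_preprocessor, preprocessors.")
        if self.tf_postprocessor and self.postprocessors:
            raise ValueError("Set only one of tf_postprocessor, postprocessors.")


def try_build(**kwargs):
    """Returns 'ok' or the validation message, for easy inspection."""
    try:
        ConfigSketch(**kwargs)
        return "ok"
    except ValueError as err:
        return str(err)


valid = try_build(signature_key="serving_default")
conflict = try_build(signature_key="k", tf_preprocessor=len,
                     preprocessors=[object()])
```

Because the checks run in __post_init__, an invalid combination fails at construction time rather than later during export.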
- get_input_signature(required=True)#
Gets the input signature from the explicit one or from tf_preprocessor.
- Return type:
Any
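The lookup order (explicit signature first, then the preprocessor's own input_signature) can be sketched as follows. The function and class names are hypothetical, and raising ValueError when required is True and nothing is found is an assumption about the behavior, not something this page states.

```python
def get_input_signature_sketch(input_signature, tf_preprocessor, required=True):
    """Prefers the explicit signature, then falls back to the preprocessor's."""
    if input_signature is not None:
        return input_signature
    inferred = getattr(tf_preprocessor, "input_signature", None)
    if inferred is None and required:
        # Assumed failure mode; the real method may behave differently.
        raise ValueError("No input signature could be determined.")
    return inferred


class FakePreprocessor:
    """Hypothetical stand-in for a tf.function with an input_signature."""

    input_signature = ["spec_from_preprocessor"]


explicit = get_input_signature_sketch(["explicit_spec"], FakePreprocessor())
inferred = get_input_signature_sketch(None, FakePreprocessor())
missing = get_input_signature_sketch(None, None, required=False)
```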
- get_infer_step(infer_step_fns)#
Finds the right inference fn to be bound with the ServingConfig.
- Parameters:
infer_step_fns (Union[Callable[…, Any], Mapping[str, Callable[…, Any]]]) – A single inference step function, or a mapping from method_key to inference step function. Usually the user can pass JaxModule.methods here.
- Returns:
The JAX method corresponding to this ServingConfig.
- Return type:
method
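The selection logic can be sketched in a few lines. get_infer_step_sketch is a hypothetical name, and raising ValueError for a missing key is an assumption; the page only states that the method matching this ServingConfig is returned.

```python
from collections.abc import Mapping


def get_infer_step_sketch(infer_step_fns, method_key):
    """Picks the inference function for a given method_key."""
    if isinstance(infer_step_fns, Mapping):
        if method_key not in infer_step_fns:
            # Assumed failure mode for an unknown key.
            raise ValueError(f"No method registered for key {method_key!r}.")
        return infer_step_fns[method_key]
    # A single callable is used regardless of method_key.
    return infer_step_fns


def double(x):
    return 2 * x


def negate(x):
    return -x


methods = {"double": double, "negate": negate}
picked = get_infer_step_sketch(methods, "negate")
single = get_infer_step_sketch(double, "anything")
```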
- bind(infer_step_fns, require_numpy=True)#
Returns an end-to-end inference function by binding an inference step function.
- Parameters:
infer_step_fns (Union[Callable[PyTree, PyTree], Mapping[str, Callable[PyTree, PyTree]]]) – An inference step function or a mapping of method key to inference step function. If it is a mapping, the function whose key matches the method_key of this ServingConfig is used. If only a single infer_step function is provided, every method_key uses that function.
require_numpy (bool) – Whether to convert TF tensors to NumPy arrays after TF preprocessing and TF postprocessing. As a rule of thumb, set it to True if infer_step is a JAX function and to False if infer_step is a TF function.
- Returns:
The mapping of serving signature to the inference function bound with the pre- and post-processors of this ServingConfig.
- Return type:
func_map
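What binding produces can be sketched as a plain function composition. bind_sketch is a hypothetical name; the sketch ignores require_numpy and tensor conversion entirely and only shows the preprocess, infer, postprocess chain and the mapping from each signature key to the same bound function.

```python
def bind_sketch(signature_key, infer_step, preprocessor=None, postprocessor=None):
    """Builds a {signature_key: end-to-end function} mapping."""

    def inference_fn(inputs):
        if preprocessor is not None:
            inputs = preprocessor(inputs)
        outputs = infer_step(inputs)
        if postprocessor is not None:
            outputs = postprocessor(outputs)
        return outputs

    # signature_key may be one key or a sequence of keys for the same signature.
    if isinstance(signature_key, str):
        keys = [signature_key]
    else:
        keys = list(signature_key)
    return {key: inference_fn for key in keys}


func_map = bind_sketch(
    ["serving_default", "predict"],
    infer_step=lambda x: x * 2,
    preprocessor=lambda x: x + 1,
    postprocessor=lambda y: {"out": y},
)
prediction = func_map["predict"](3)
```

Both keys in func_map refer to the same function object, which matches the description of signature_key as a sequence of keys mapping to one serving signature.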
- __eq__(other)#
Return self==value.
- __hash__ = None#
- __init__(signature_key, input_signature=None, tf_preprocessor=None, tf_postprocessor=None, preprocessors=(), postprocessors=(), data_processors=(), extra_trackable_resources=None, method_key=None, obm_export_options=None, preprocess_output_passthrough_enabled=False)#