ServingConfig#

class orbax.export.serving_config.ServingConfig(signature_key, input_signature=None, tf_preprocessor=None, tf_postprocessor=None, extra_trackable_resources=None, method_key=None)[source]#

Configuration for constructing a serving signature for a JaxModule.

A ServingConfig is to be bound with a JaxModule to form an end-to-end serving signature.

get_input_signature(required=True)[source]#

Gets the input signature, either from the explicitly provided one or inferred from tf_preprocessor.

Return type:

Any

get_infer_step(infer_step_fns)[source]#

Finds the inference function to bind with this ServingConfig.

Parameters:

infer_step_fns (Union[Callable[…, Any], Mapping[str, Callable[…, Any]]]) – the method_key/infer_step mapping. Usually the user can pass JaxModule.methods here.

Returns:

the JAX method corresponding to this ServingConfig, looked up by its method_key.

Return type:

method
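The dispatch behavior described above can be illustrated with a small, self-contained sketch. This is not the orbax implementation; `resolve_infer_step` is a hypothetical stand-in showing the assumed lookup rule: a bare callable serves every method key, while a mapping is indexed by the ServingConfig's method_key.

```python
from typing import Any, Callable, Mapping, Union

InferStep = Callable[..., Any]


def resolve_infer_step(
    method_key: str,
    infer_step_fns: Union[InferStep, Mapping[str, InferStep]],
) -> InferStep:
    """Illustrative lookup: single callable -> used for any key; mapping -> indexed by key."""
    if callable(infer_step_fns):
        return infer_step_fns
    if method_key not in infer_step_fns:
        raise ValueError(f"Method key {method_key!r} not found in infer_step_fns.")
    return infer_step_fns[method_key]


# A single function is returned as-is, regardless of the method key.
single = resolve_infer_step("serve", lambda x: x + 1)

# A mapping is dispatched on the method key.
table = {"serve": lambda x: x * 2, "debug": lambda x: x}
picked = resolve_infer_step("serve", table)
```

In the real API, passing `JaxModule.methods` corresponds to the mapping case.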

bind(infer_step_fns, require_numpy=True)[source]#

Returns an end-to-end inference function by binding an inference step function.

Parameters:
  • infer_step_fns (Union[Callable[[Any], Any], Mapping[str, Callable[[Any], Any]]]) – An inference step function, or a mapping of method key to inference step function. If it is a mapping, the function whose key matches the method_key of this ServingConfig is used. If a single function is provided, it is used for all `method_key`s.

  • require_numpy (bool) – Whether to convert TF tensors to NumPy arrays after TF pre-processing and before TF post-processing. As a rule of thumb, set it to True if infer_step is a JAX function, and to False if infer_step is a TF function.

Returns:

The mapping from serving signature key to the inference function bound with the pre- and post-processors of this ServingConfig.

Return type:

func_map
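The composition that bind performs can be sketched in plain Python. This is a simplified, hypothetical model, not the orbax code: `bind_sketch` stands in for bind, and the comment marks where the require_numpy conversion would occur in the real API.

```python
from typing import Any, Callable, Optional


def bind_sketch(
    infer_step: Callable[[Any], Any],
    preprocessor: Optional[Callable[[Any], Any]] = None,
    postprocessor: Optional[Callable[[Any], Any]] = None,
) -> Callable[[Any], Any]:
    """Composes preprocess -> infer_step -> postprocess into one end-to-end function."""

    def end_to_end(inputs: Any) -> Any:
        x = preprocessor(inputs) if preprocessor else inputs
        # In the real API, require_numpy=True would convert TF tensors to
        # NumPy arrays here, so a JAX infer_step receives NumPy inputs.
        y = infer_step(x)
        return postprocessor(y) if postprocessor else y

    return end_to_end


# Example: parse a string, run a numeric "model", format the result.
serve = bind_sketch(
    lambda x: x * 10,
    preprocessor=lambda s: int(s),
    postprocessor=str,
)
```

The returned function plays the role of one entry in the func_map described above, keyed by the serving signature.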

__eq__(other)#

Return self==value.

__hash__ = None#
__init__(signature_key, input_signature=None, tf_preprocessor=None, tf_postprocessor=None, extra_trackable_resources=None, method_key=None)#