ServingConfig
ServingConfig class.
- class orbax.export.serving_config.ServingConfig(signature_key, input_signature=None, tf_preprocessor=None, tf_postprocessor=None, extra_trackable_resources=None, method_key=None)
Configuration for constructing a serving signature for a JaxModule.
A ServingConfig is bound to a JaxModule to form an end-to-end serving signature.
- get_input_signature(required=True)
Gets the input signature, either from the explicit input_signature or from tf_preprocessor.
- Return type:
Any
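The lookup order described above can be pictured with a small pure-Python model. This is a hypothetical sketch, not the Orbax implementation: `SketchServingConfig` is an invented stand-in, and two behaviours are assumptions: that the fallback reads an `input_signature` attribute from the preprocessor (as a `tf.function` created with a signature exposes), and that `required=True` turns a missing signature into an error.

```python
from typing import Any, Callable, Optional


class SketchServingConfig:
    """Hypothetical stand-in for ServingConfig; models only signature lookup."""

    def __init__(self, input_signature: Optional[Any] = None,
                 tf_preprocessor: Optional[Callable[..., Any]] = None):
        self.input_signature = input_signature
        self.tf_preprocessor = tf_preprocessor

    def get_input_signature(self, required: bool = True) -> Any:
        # An explicitly supplied signature takes precedence.
        signature = self.input_signature
        if signature is None and self.tf_preprocessor is not None:
            # Assumption: the preprocessor carries its signature as an
            # `input_signature` attribute.
            signature = getattr(self.tf_preprocessor, "input_signature", None)
        # Assumption: `required` makes a missing signature an error.
        if required and signature is None:
            raise ValueError("No input signature: pass one explicitly or "
                             "use a tf_preprocessor that defines one.")
        return signature


def preprocess(x):
    return x


preprocess.input_signature = ("x_spec",)  # hypothetical placeholder signature
print(SketchServingConfig(tf_preprocessor=preprocess).get_input_signature())  # ('x_spec',)
```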
- get_infer_step(infer_step_fns)
Finds the inference function to bind to this ServingConfig.
- Parameters:
infer_step_fns (Union[Callable[..., Any], Mapping[str, Callable[..., Any]]]) – A single inference step function, or a dict mapping method_key to inference step function. Usually the user can pass JaxModule.methods here.
- Returns:
The JAX method corresponding to this ServingConfig.
- Return type:
method
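The selection rule can be illustrated with a minimal sketch. The free function `get_infer_step(infer_step_fns, method_key)` below is hypothetical and written only to show the idea: a single callable is returned as-is, while a mapping is indexed by the config's method_key. Requiring a method_key whenever a mapping is passed is an assumption of this sketch, not documented Orbax behaviour.

```python
from typing import Any, Callable, Mapping, Optional, Union

InferStep = Callable[..., Any]


def get_infer_step(infer_step_fns: Union[InferStep, Mapping[str, InferStep]],
                   method_key: Optional[str]) -> InferStep:
    """Hypothetical sketch of picking the inference step for one config."""
    # A single callable is shared by every config, whatever its method_key.
    if callable(infer_step_fns):
        return infer_step_fns
    # Assumption: a mapping can only be resolved through a method_key.
    if method_key is None:
        raise ValueError("method_key is required when a mapping is passed.")
    if method_key not in infer_step_fns:
        raise KeyError(f"No inference step for method_key {method_key!r}.")
    return infer_step_fns[method_key]


steps = {"predict": lambda x: x * 2, "score": lambda x: x + 100}
print(get_infer_step(steps, "score")(1))  # 101
```

Passing something like JaxModule.methods corresponds to the mapping branch; passing one function corresponds to the first branch.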
- bind(infer_step_fns, require_numpy=True)
Returns an end-to-end inference function by binding an inference step function.
- Parameters:
infer_step_fns (Union[Callable[[Any], Any], Mapping[str, Callable[[Any], Any]]]) – An inference step function or a mapping of method key to inference step function. If it is a mapping, the function whose key matches the method_key of this ServingConfig is used. If the user provides only a single infer_step function, all `method_key`s use that same function.
require_numpy (bool) – Whether to convert TF tensors to NumPy arrays after the TF preprocessor and after the TF postprocessor run. As a rule of thumb, set it to True if infer_step is a JAX function and to False if infer_step is a TF function.
- Returns:
The mapping from serving signature to the inference function bound with the pre- and post-processors of this ServingConfig.
- Return type:
func_map
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(signature_key, input_signature=None, tf_preprocessor=None, tf_postprocessor=None, extra_trackable_resources=None, method_key=None)
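Listing `__hash__ = None` next to `__eq__` is the pattern Python produces for a class that defines `__eq__` without `__hash__`, for example a non-frozen dataclass. Instances then compare field by field but cannot be hashed, so a ServingConfig cannot be used as a dict key or set member. The class below is a hypothetical stand-in with a subset of the constructor's fields, used only to show that behaviour.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class ConfigLike:
    """Hypothetical stand-in with a subset of ServingConfig's fields."""

    signature_key: str
    input_signature: Optional[Any] = None
    method_key: Optional[str] = None


a = ConfigLike("serving_default")
b = ConfigLike("serving_default")
print(a == b)               # True: dataclass __eq__ compares fields
print(ConfigLike.__hash__)  # None: eq=True without frozen=True disables hashing

try:
    hash(a)
except TypeError as err:
    print("unhashable:", err)
```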