ValidationManager
The definition of the ValidationManager class.
- class orbax.export.validate.validation_manager.ValidationManager(module, serving_configs, model_inputs)
Validate a JaxModule and the TensorFlow SavedModel exported from it.
This manager orchestrates validation by feeding identical inputs to both the original JAX functions (the baseline) and the exported TensorFlow SavedModel (the candidate). It then generates a report that compares their outputs to check numerical and structural parity.
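The baseline-versus-candidate comparison described above can be pictured with plain Python. This is a minimal sketch, not Orbax code: `validate_signatures`, the float-valued outputs, and the `atol` tolerance are hypothetical stand-ins. The real ValidationManager runs JAX functions and a TF SavedModel and returns ValidationReport objects rather than dictionaries.

```python
from typing import Callable, Dict, Mapping, Sequence

# Hypothetical stand-in: each "model output" is a flat dict of floats.
Output = Dict[str, float]


def validate_signatures(
    baseline_fns: Mapping[str, Callable[[float], Output]],
    candidate_fns: Mapping[str, Callable[[float], Output]],
    inputs: Sequence[float],
    atol: float = 1e-6,
) -> Dict[str, Dict[str, object]]:
    """Runs identical inputs through baseline and candidate for each signature key."""
    reports: Dict[str, Dict[str, object]] = {}
    for key, baseline in baseline_fns.items():
        candidate = candidate_fns[key]
        structure_ok = True
        max_abs_diff = 0.0
        for x in inputs:
            base_out, cand_out = baseline(x), candidate(x)
            # Structural parity: both sides must emit the same output names.
            if base_out.keys() != cand_out.keys():
                structure_ok = False
                break
            # Numerical parity: keep the worst absolute difference seen so far.
            for name, value in base_out.items():
                max_abs_diff = max(max_abs_diff, abs(value - cand_out[name]))
        reports[key] = {
            "structure_ok": structure_ok,
            "max_abs_diff": max_abs_diff,
            "passed": structure_ok and max_abs_diff <= atol,
        }
    return reports


# Usage: a "candidate" that computes the same values via a different expression.
baseline = {"serving_default": lambda x: {"y": 2.0 * x}}
candidate = {"serving_default": lambda x: {"y": x + x}}
reports = validate_signatures(baseline, candidate, [0.5, 1.0, 3.0])
print(reports["serving_default"])
```

The result is keyed by signature, mirroring how `validate` returns one report per signature key.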
Example
Validate an exported model against its original JAX implementation:
```python
import tensorflow as tf
from orbax.export.validate.validation_manager import ValidationManager

# Assume `my_jax_module` is your JaxModule and `my_config` is your ServingConfig.
test_inputs = [{'input_tensor': tf.ones((1, 32))}]

# Initialize the manager with the module and configurations.
validator = ValidationManager(
    module=my_jax_module,
    serving_configs=[my_config],
    model_inputs=test_inputs,
)

# Assume `loaded_tf_model` is the result of tf.saved_model.load(...).
# Run the validation to compare JAX vs. TF outputs.
reports = validator.validate(loaded_tf_model)
```
- __init__(module, serving_configs, model_inputs)
Create the ValidationManager object.
- Parameters:
  - module (Union[JaxModule, Mapping[str, Callable[[PyTree], PyTree]]]) – The JaxModule to validate, or a mapping from string keys to JAX callables.
  - serving_configs (Sequence[ServingConfig]) – The sequence of ServingConfig objects.
  - model_inputs (Union[Sequence[Any], Mapping[str, Sequence[Any]]]) – The inputs for the TF SavedModel. Two formats are supported: (1) a mapping from signature key to a sequence of batch inputs; or (2) a sequence of batch inputs used to validate all signatures.
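The two accepted formats for `model_inputs` can be reduced to a single shape: a mapping from signature key to a list of batch inputs. The helper below is a hypothetical illustration of that normalization, not a function exported by Orbax; `normalize_model_inputs` and its `signature_keys` argument (assumed to be the signature names taken from the serving configs) are invented for this sketch.

```python
import collections.abc
from typing import Any, Dict, List, Mapping, Sequence, Union


def normalize_model_inputs(
    model_inputs: Union[Sequence[Any], Mapping[str, Sequence[Any]]],
    signature_keys: Sequence[str],
) -> Dict[str, List[Any]]:
    """Hypothetical helper: expands model_inputs into one list of batches per signature."""
    if isinstance(model_inputs, collections.abc.Mapping):
        # Format (1): each listed signature key carries its own batch inputs;
        # signatures missing from the mapping receive no inputs in this sketch.
        return {key: list(batches) for key, batches in model_inputs.items()}
    # Format (2): one sequence of batch inputs is reused for every signature.
    return {key: list(model_inputs) for key in signature_keys}


# Usage: the same two batches are assigned to both signatures.
batches = [{"x": [1.0, 2.0]}, {"x": [3.0, 4.0]}]
per_signature = normalize_model_inputs(batches, ["serving_default", "predict"])
```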
- validate(loaded_model, with_xprof=False, report_option=None)
Validates the candidate (the loaded TF SavedModel) against the baseline function map (the JAX functions).
- Parameters:
  - loaded_model (Any) – The loaded TensorFlow SavedModel to validate. On CPU, this is usually tf.saved_model.load(path, ['serve']).
  - with_xprof (bool) – Whether to enable XLA profiling during the validation run.
  - report_option (Optional[ValidationReportOption]) – Optional ValidationReportOption to configure the generated report's formatting and strictness.
- Return type:
Mapping[str, ValidationReport]
- Returns:
A mapping from signature keys to ValidationReport objects containing the results of the comparison.
- classmethod check_input(inputs, batch_mode=True)
Check the model input format.
- Parameters:
  - inputs (Union[Any, Sequence[Any]]) – The model inputs. If batch_mode is True, inputs should be a sequence, such as a list.
  - batch_mode (bool) – Whether inputs is a sequence of inputs (True) or a single input (False).
- Raises:
ValueError – If batch_mode is True and inputs is not a sequence.
- Return type:
None
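The documented contract of `check_input` is small enough to mirror directly. This standalone sketch is written for illustration only; the name `check_input_sketch` is hypothetical, and the logic is based solely on the Raises entry above (in batch mode, a non-sequence raises ValueError).

```python
import collections.abc
from typing import Any


def check_input_sketch(inputs: Any, batch_mode: bool = True) -> None:
    """Hypothetical re-implementation of the documented check_input contract."""
    if batch_mode and not isinstance(inputs, collections.abc.Sequence):
        raise ValueError(
            "batch_mode=True expects a sequence of inputs, "
            f"got {type(inputs).__name__}."
        )


check_input_sketch([{"x": 1.0}, {"x": 2.0}])  # a list of inputs: accepted
check_input_sketch({"x": 1.0}, batch_mode=False)  # a single input: accepted
try:
    check_input_sketch({"x": 1.0})  # a dict is not a Sequence: rejected
except ValueError as err:
    print(err)
```

Because the check is against the abstract Sequence type, a tuple of inputs passes as well as a list.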
- classmethod check_output(baseline_result, candidate_result)
Check the model output format.
- Parameters:
  - baseline_result (ValidationSingleJobResult) – The ValidationSingleJobResult from the JAX model.
  - candidate_result (ValidationSingleJobResult) – The ValidationSingleJobResult from the TF model.
- Raises:
ValueError – If the outputs are not flat dictionaries, or if the baseline and candidate models produce a different number of output elements.
- Return type:
None
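The two conditions under which `check_output` raises can be sketched over plain dictionaries. This sketch assumes each job result has already been reduced to its output dictionary, so it does not use the real ValidationSingleJobResult type; `check_output_sketch` and `_is_flat_dict` are hypothetical names, and "flat" is assumed here to mean that no value is itself a mapping.

```python
import collections.abc
from typing import Any


def _is_flat_dict(obj: Any) -> bool:
    # Assumption: "flat" means a mapping whose values are not themselves mappings.
    return isinstance(obj, collections.abc.Mapping) and not any(
        isinstance(value, collections.abc.Mapping) for value in obj.values()
    )


def check_output_sketch(baseline_output: Any, candidate_output: Any) -> None:
    """Hypothetical mirror of the two documented check_output failure modes."""
    if not (_is_flat_dict(baseline_output) and _is_flat_dict(candidate_output)):
        raise ValueError("Baseline and candidate outputs must be flat dictionaries.")
    if len(baseline_output) != len(candidate_output):
        raise ValueError(
            f"Output count mismatch: baseline has {len(baseline_output)} elements, "
            f"candidate has {len(candidate_output)}."
        )


check_output_sketch({"logits": [0.1, 0.9]}, {"logits": [0.1, 0.9]})  # passes
```

Following the documented behavior, the sketch compares only the number of output elements, not their names or values; numerical comparison is left to the validation report.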