# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of ValidationManager class."""
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Optional, Union
from absl import logging
import jax
from orbax.export import jax_module
from orbax.export.serving_config import ServingConfig
from orbax.export.validate.validation_job import ValidationJob
from orbax.export.validate.validation_job import ValidationSingleJobResult
from orbax.export.validate.validation_report import ValidationReport
from orbax.export.validate.validation_report import ValidationReportOption
import tensorflow as tf
def _is_flat_dict(x):
"""Checks if x is flat dict."""
if not isinstance(x, Mapping):
return False
return all(
jax.tree_util.treedef_is_leaf(jax.tree_util.tree_structure(v))
for v in x.values()
)
def _is_flat_sequence(x):
"""Checks if x is flat list."""
if not isinstance(x, Sequence):
return False
return all(
jax.tree_util.treedef_is_leaf(jax.tree_util.tree_structure(v)) for v in x
)
[docs]
class ValidationManager:
"""Validate the JaxModule and its output tf saved model.
This manager orchestrates the validation process by feeding identical inputs
into both the original JAX mathematical functions (the baseline) and the
exported TensorFlow SavedModel (the candidate). It then generates a report
comparing their outputs to ensure numerical and structural parity.
Example:
Validate an exported model against its original JAX implementation::
import tensorflow as tf
from orbax.export.validate.validation_manager import ValidationManager
# Assume `my_jax_module` is your JaxModule and `my_config` is your
# ServingConfig.
test_inputs = [{'input_tensor': tf.ones((1, 32))}]
# Initialize the manager with the module and configurations
validator = ValidationManager(
module=my_jax_module,
serving_configs=[my_config],
model_inputs=test_inputs
)
# Assume `loaded_tf_model` is the result of tf.saved_model.load(...)
# Run the validation to compare JAX vs TF outputs
reports = validator.validate(loaded_tf_model)
"""
def __init__(
    self,
    module: Union[
        jax_module.JaxModule,
        Mapping[str, Callable[[jax_module.PyTree], jax_module.PyTree]],
    ],
    serving_configs: Sequence[ServingConfig],
    model_inputs: Union[Sequence[Any], Mapping[str, Sequence[Any]]],
):
  """Creates the ValidationManager object.

  Args:
    module: The JaxModule whose `jax_methods` provide the baseline
      functions. A mapping from method name to JAX callable is also accepted
      but is deprecated and logs a warning.
    serving_configs: The ServingConfigs whose signatures are validated.
    model_inputs: The inputs fed to both the JAX baseline and the TF
      SavedModel. Two formats are supported: (1) a mapping from signature key
      to a sequence of batch inputs for that signature; or (2) a single
      sequence of batch inputs used to validate every signature.
  """
  if isinstance(module, jax_module.JaxModule):
    self._jax_methods = module.jax_methods
  else:
    # `absl.logging.warn` is a deprecated alias of `warning`; use the
    # supported spelling (same log level).
    logging.warning(
        'Using Mapping[str, Callable] to initialize ValidationManager is'
        ' deprecated. Use JaxModule instead.'
    )
    self._jax_methods = module
  self._serving_configs = serving_configs
  self._model_inputs = model_inputs
def _create_baseline_fns(self) -> Mapping[str, Callable[..., Any]]:
  """Binds every serving config to the JAX methods.

  Returns:
    A dict from signature key to the JAX baseline function produced by
    `ServingConfig.bind`. When several serving configs yield the same key,
    the entry from the later config wins.
  """
  baseline_fns = {}
  for serving_config in self._serving_configs:
    bound_fns = serving_config.bind(self._jax_methods)
    baseline_fns.update(bound_fns)
  return baseline_fns
def _create_candidate_fns(
    self, loaded_model: Any
) -> Mapping[str, Callable[..., Any]]:
  """Builds candidate inference functions backed by the loaded SavedModel.

  Args:
    loaded_model: A TF SavedModel already loaded by the caller, e.g.
      `tf.saved_model.load(tf_model_path, ['serve'])` on CPU or
      `tf.saved_model.load(tf_model_path, ['serve', 'tpu'])` on TPU.

  Returns:
    A dict mapping every signature key of every serving config to a function
    that takes exactly one positional argument (the model input). The
    function converts every leaf of the input to a tensor, validates it with
    `self.check_input(..., batch_mode=False)`, runs the matching SavedModel
    signature, and returns the outputs with `.numpy()` applied to every leaf
    that has it. Passing more than one positional argument raises ValueError.
  """
  signatures = loaded_model.signatures

  def build_fn(signature_key):
    # Factory so each returned closure captures its own signature key.
    def candidate_fn(*inputs):
      if len(inputs) != 1:
        raise ValueError(
            'Currently does not accept multiple args, '
            f'got len(inputs)={len(inputs)}.'
        )
      tensors = jax.tree_util.tree_map(tf.convert_to_tensor, inputs[0])
      self.check_input(tensors, batch_mode=False)
      signature_fn = signatures[signature_key]
      # Mappings become keyword arguments, sequences positional arguments,
      # and any other value is passed as the single positional argument.
      if isinstance(tensors, Mapping):
        raw_outputs = signature_fn(**tensors)
      elif isinstance(tensors, Sequence):
        raw_outputs = signature_fn(*tensors)
      else:
        raw_outputs = signature_fn(tensors)
      return jax.tree_util.tree_map(
          lambda leaf: leaf.numpy() if hasattr(leaf, 'numpy') else leaf,
          raw_outputs,
      )

    return candidate_fn

  return {
      key: build_fn(key)
      for sc in self._serving_configs
      for key in sc.get_signature_keys()
  }
def _create_input_map(self) -> Mapping[str, Sequence[Any]]:
  """Returns the model inputs keyed by signature key.

  A Mapping passed as `model_inputs` is returned unchanged. A plain sequence
  of batch inputs is shared by every signature key of every serving config.
  """
  inputs = self._model_inputs
  if not isinstance(inputs, Mapping):
    return {
        key: inputs
        for sc in self._serving_configs
        for key in sc.get_signature_keys()
    }
  return inputs
def validate(
    self,
    loaded_model: Any,
    with_xprof: bool = False,
    report_option: Optional[ValidationReportOption] = None,
) -> Mapping[str, ValidationReport]:
  """Validates the JAX baseline functions against the TF SavedModel.

  For every signature key of every serving config, the same inputs are run
  through the bound JAX function (baseline) and the matching SavedModel
  signature (candidate), and the two results are compared in a
  `ValidationReport`.

  Args:
    loaded_model: The loaded TensorFlow SavedModel to validate against. For
      CPU, this is usually `tf.saved_model.load(path, ['serve'])`.
    with_xprof: Whether to enable XLA profiling during the validation run.
    report_option: Optional `ValidationReportOption` to configure the
      generated report's formatting and strictness. When None, a default
      `ValidationReportOption()` is used.

  Returns:
    A mapping of signature keys to `ValidationReport` objects containing the
    results of the comparison.

  Raises:
    KeyError: If `model_inputs` was given as a mapping without an entry for
      one of the serving configs' signature keys.
    ValueError: If `check_output` rejects the baseline/candidate outputs.
  """
  candidate_fns = self._create_candidate_fns(loaded_model)
  baseline_fns = self._create_baseline_fns()
  input_map = self._create_input_map()
  # Compare against None (not truthiness) so an explicitly passed option is
  # always honored.
  if report_option is None:
    report_option = ValidationReportOption()
  results = {}
  for sc in self._serving_configs:
    for key in sc.get_signature_keys():
      if key not in input_map:
        raise KeyError(
            f'No model inputs were provided for signature key {key!r}; '
            f'available keys: {list(input_map)}.'
        )
      validation_job = ValidationJob(
          baseline_fns[key], candidate_fns[key], input_map[key], with_xprof
      )
      baseline_result = validation_job.calc_baseline_result()
      candidate_result = validation_job.calc_candidate_result()
      # Always convert list outputs to dicts so both sides compare by key.
      baseline_result.maybe_convert_result_to_dict()
      candidate_result.maybe_convert_result_to_dict()
      self.check_output(baseline_result, candidate_result)
      results[key] = ValidationReport(
          baseline_result, candidate_result, report_option
      )
  return results
@classmethod
def check_output(
    cls,
    baseline_result: ValidationSingleJobResult,
    candidate_result: ValidationSingleJobResult,
) -> None:
  """Checks that baseline and candidate outputs have a comparable format.

  Args:
    baseline_result: The `ValidationSingleJobResult` from the JAX model.
    candidate_result: The `ValidationSingleJobResult` from the TF model.

  Raises:
    ValueError: If the baseline produced no outputs, if its first output is
      not a flat dict, or if the baseline and candidate outputs contain a
      different number of leaves.
  """
  baseline_outputs = baseline_result.outputs
  candidate_outputs = candidate_result.outputs
  # Guard before indexing so an empty run gets a clear error rather than an
  # IndexError from `baseline_outputs[0]`.
  if not baseline_outputs:
    raise ValueError(
        'Baseline produced no outputs to validate; the model inputs for this'
        ' signature may be empty.'
    )
  if not _is_flat_dict(baseline_outputs[0]):
    raise ValueError(
        'Currently ValidationReport only accept flat dict outputs. '
        f'But we got {type(baseline_outputs[0])}'
    )
  baseline_flat = jax.tree_util.tree_leaves(baseline_outputs)
  candidate_flat = jax.tree_util.tree_leaves(candidate_outputs)
  if len(baseline_flat) != len(candidate_flat):
    raise ValueError(
        'baseline and candidate has different output length. '
        f'len(baseline) = {len(baseline_flat)}, '
        f'len(candidate) = {len(candidate_flat)}.'
    )