# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define ValidationReport class here."""
import dataclasses
import pathlib
from typing import Any, Dict, Optional, Union
from absl import logging
import dataclasses_json
import jax
import numpy as np
from orbax.export.validate.validation_job import ValidationSingleJobResult
from orbax.export.validate.validation_utils import get_latency_stat
from orbax.export.validate.validation_utils import split_tf_floating_and_discrete_groups
from orbax.export.validate.validation_utils import Status
# Alias for the profiling URL strings stored in `ValidationReport.xprof_url`.
XprofURL = str
# Alias for the per-run metadata values stored in `ValidationReport.metadata`.
# It is `Any`, so no particular structure is enforced by this module.
MetaData = Any
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class ValidationReportOption:
  """Option for ValidationReport class.

  Attributes:
    floating_atol: Absolute tolerance parameter for floating-point comparisons.
      Must not be negative.
    floating_rtol: Relative tolerance parameter for floating-point comparisons.
      Must not be negative.
    max_non_floating_mismatch_ratio: The maximum allowable ratio of mismatches
      for discrete/non-floating-point tensors before the validation fails.
    output_report_path: Optional path to save the generated report.
    print_debug_info: If True, prints detailed tensor mismatches to the logs.
  """

  floating_atol: float = 1e-7
  floating_rtol: float = 1e-7
  max_non_floating_mismatch_ratio: float = 1e-2
  output_report_path: Optional[Union[str, pathlib.Path]] = None
  print_debug_info: bool = False

  def __post_init__(self):
    """Checks that the tolerance options are legal.

    Zero is an accepted tolerance (it requests an exact match); only negative
    values are rejected.

    Raises:
      OverflowError: If floating_atol < 0 or floating_rtol < 0. The exception
        type is kept for backward compatibility with existing callers.
    """
    if self.floating_atol < 0:
      raise OverflowError('floating_atol should not be negative.')
    if self.floating_rtol < 0:
      raise OverflowError('floating_rtol should not be negative.')
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class LatencyStat:
  """Latency summary for one ML job run.

  Attributes:
    num_batches: Batch count reported together with the latency figures.
    avg_in_ms: Average latency, in milliseconds.
    p90_in_ms: 90th-percentile latency, in milliseconds.
    p99_in_ms: 99th-percentile latency, in milliseconds.
  """

  num_batches: int
  avg_in_ms: float
  p90_in_ms: float
  p99_in_ms: float
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class FloatingPointDiffReport:
  """Summary of a baseline-vs-candidate floating-point comparison.

  Attributes:
    total: Number of floating-point values that were compared.
    max_diff: Largest absolute difference observed.
    max_rel_diff: Largest relative difference observed.
    all_close: Whether all values agreed within the tolerances below.
    all_close_absolute_tolerance: Absolute tolerance used for `all_close`.
    all_close_relative_tolerance: Relative tolerance used for `all_close`.
  """

  total: int
  max_diff: float
  max_rel_diff: float
  all_close: bool
  all_close_absolute_tolerance: float
  all_close_relative_tolerance: float
@dataclasses_json.dataclass_json
@dataclasses.dataclass
class NonFloatingPointDiffReport:
  """Summary of a baseline-vs-candidate non-floating-point comparison.

  Attributes:
    total_flattened_tensors: Number of non-floating-point entries compared.
    mismatches: Number of compared entries that did not match.
    mismatch_ratio: `mismatches` divided by `total_flattened_tensors` (0.0 when
      nothing was compared).
    max_non_floating_mismatch_ratio: Threshold that `mismatch_ratio` was
      checked against.
  """

  total_flattened_tensors: int
  mismatches: int
  mismatch_ratio: float
  max_non_floating_mismatch_ratio: float
@dataclasses_json.dataclass_json
@dataclasses.dataclass(init=False, eq=False)
class ValidationReport:
  """Generate validation report based on ValidationSingleJobResult.

  This class analyzes the execution results from the baseline (JAX) and
  candidate (TF SavedModel) runs, comparing their outputs based on configured
  tolerances. It computes latency percentiles and numerical divergence to
  assign a final Pass/Fail status.

  Example:
    Configure tolerances and generate a validation report::

      from orbax.export.validate.validation_report import ValidationReport
      from orbax.export.validate.validation_report import (
          ValidationReportOption)
      from orbax.export.validate.validation_utils import Status

      # Assume `baseline_result` and `candidate_result` were generated by
      # ValidationJob.

      # Loosen the floating-point tolerances for validation.
      options = ValidationReportOption(
          floating_atol=1e-5,
          floating_rtol=1e-5,
          print_debug_info=True,
      )

      report = ValidationReport(
          baseline=baseline_result,
          candidate=candidate_result,
          option=options,
      )

      if report.status == Status.Pass:
        print("Validation Successful!")

  Attributes:
    outputs: A dictionary mapping report types to their respective diff reports.
      'FloatingPointDiffReport' is present only when there are floating-point
      outputs; 'NonFloatingPointDiffReport' is always present.
    latency: A dictionary containing latency statistics for both the 'baseline'
      and 'candidate' runs.
    xprof_url: Profiling URLs mapped by 'baseline' and 'candidate'.
    metadata: Contextual metadata mapped by 'baseline' and 'candidate'.
    status: The final validation `Status` (Pass or Fail).
  """

  outputs: Dict[str, Union[FloatingPointDiffReport, NonFloatingPointDiffReport]]
  latency: Dict[str, LatencyStat]
  xprof_url: Dict[str, XprofURL]
  metadata: Dict[str, MetaData]
  status: Status

  def __init__(self,
               baseline: ValidationSingleJobResult,
               candidate: ValidationSingleJobResult,
               option: Optional[ValidationReportOption] = None):
    """Generate validation result report with users config options.

    Args:
      baseline: The baseline ValidationSingleJobResult.
      candidate: The candidate ValidationSingleJobResult. The comparison
        criteria are applied to the candidate, using the baseline as the
        reference.
      option: ValidationReport options. Defaults to `ValidationReportOption()`
        when omitted.

    Raises:
      ValueError: If the baseline and candidate result trees have different
        structures, or if the flattened floating/non-floating arrays have
        mismatched lengths.
    """
    self._option = ValidationReportOption() if option is None else option

    baseline_outputs = baseline.outputs
    candidate_outputs = candidate.outputs

    baseline_latency_stat = LatencyStat(*get_latency_stat(baseline.latencies))
    candidate_latency_stat = LatencyStat(
        *get_latency_stat(candidate.latencies))

    baseline_tree_def = jax.tree_util.tree_structure(baseline_outputs)
    candidate_tree_def = jax.tree_util.tree_structure(candidate_outputs)
    if baseline_tree_def != candidate_tree_def:
      raise ValueError(
          'baseline and candidate result have diff tree_def. '
          f'baseline tree_def = {baseline_tree_def}, '
          f'candidate tree_def = {candidate_tree_def}'
      )

    baseline_floatings, baseline_non_floatings = (
        split_tf_floating_and_discrete_groups(baseline_outputs)
    )
    candidate_floatings, candidate_non_floatings = (
        split_tf_floating_and_discrete_groups(candidate_outputs)
    )
    if baseline_floatings.size != candidate_floatings.size:
      raise ValueError(
          'baseline and candidate floating result have different length. '
          f'baseline = {baseline_floatings}, candidate = {candidate_floatings}'
      )
    if len(baseline_non_floatings) != len(candidate_non_floatings):
      raise ValueError(
          'baseline and candidate non-floating result have different length. '
          f'baseline = {baseline_non_floatings}, '
          f'candidate = {candidate_non_floatings}'
      )

    self.outputs = {}
    # Run both comparisons unconditionally so that both diff reports are
    # recorded even when the first one already fails.
    floatings_pass = self._compare_floatings(
        baseline_floatings, candidate_floatings)
    non_floatings_pass = self._compare_non_floatings(
        baseline_non_floatings, candidate_non_floatings)
    if floatings_pass and non_floatings_pass:
      self.status = Status.Pass
    else:
      self.status = Status.Fail

    self.latency = {
        'baseline': baseline_latency_stat,
        'candidate': candidate_latency_stat,
    }
    self.xprof_url = {
        'baseline': baseline.xprof_url,
        'candidate': candidate.xprof_url,
    }
    self.metadata = {
        'baseline': baseline.metadata,
        'candidate': candidate.metadata,
    }

  def _compare_floatings(self,
                         baseline_floatings: np.ndarray,
                         candidate_floatings: np.ndarray) -> bool:
    """Compares floating-point outputs and records a FloatingPointDiffReport.

    Args:
      baseline_floatings: Reference floating-point values.
      candidate_floatings: Candidate floating-point values, same size as the
        baseline.

    Returns:
      True if there is nothing to compare or all values are close within the
      configured tolerances; False otherwise.
    """
    if baseline_floatings.size == 0:
      logging.info('No floating-point outputs.')
      return True

    floating_atol = self._option.floating_atol
    floating_rtol = self._option.floating_rtol

    abs_diff = np.abs(candidate_floatings - baseline_floatings)
    max_diff = abs_diff.max()
    # The denominator is floored at 1e-6 so that near-zero baseline values do
    # not blow up the relative difference.
    max_rel_diff = (
        abs_diff / np.maximum(np.abs(baseline_floatings), 1e-6)).max()
    # np.allclose takes (a, b, rtol, atol) positionally, so the tolerances are
    # passed by keyword to avoid swapping them. The baseline is `b`, i.e. the
    # reference that the relative tolerance is scaled by.
    all_close = bool(
        np.allclose(
            candidate_floatings,
            baseline_floatings,
            rtol=floating_rtol,
            atol=floating_atol,
        ))

    if all_close:
      # %g keeps small tolerances such as 1e-7 readable (%f prints 0.000000).
      logging.info(
          'Baseline and candidate floating-point results are all close '
          '(atol=%g, rtol=%g). max_diff=%g, max_rel_diff=%g', floating_atol,
          floating_rtol, max_diff, max_rel_diff)
    else:
      logging.warning(
          'Baseline and candidate floating-point results are not all close. '
          'max_diff=%g, max_rel_diff=%g.', max_diff, max_rel_diff)
      if self._option.print_debug_info:
        logging.warning('baseline_floatings = %s', baseline_floatings)
        logging.warning('candidate_floatings = %s', candidate_floatings)

    self.outputs['FloatingPointDiffReport'] = FloatingPointDiffReport(
        total=int(baseline_floatings.size),
        max_diff=float(max_diff),
        max_rel_diff=float(max_rel_diff),
        all_close=all_close,
        all_close_absolute_tolerance=floating_atol,
        all_close_relative_tolerance=floating_rtol,
    )
    return all_close

  def _compare_non_floatings(self,
                             baseline_non_floatings: Any,
                             candidate_non_floatings: Any) -> bool:
    """Compares discrete outputs and records a NonFloatingPointDiffReport.

    Args:
      baseline_non_floatings: Reference non-floating-point entries.
      candidate_non_floatings: Candidate non-floating-point entries, same
        length as the baseline.

    Returns:
      True if the mismatch ratio does not exceed
      `max_non_floating_mismatch_ratio`; False otherwise.
    """
    max_mismatch_ratio = self._option.max_non_floating_mismatch_ratio
    total_non_floatings = len(baseline_non_floatings)
    # An entry is a mismatch as soon as ANY of its elements differs, not only
    # when all of them do.
    mismatches = sum(
        bool(np.any(expected != actual))
        for expected, actual in zip(baseline_non_floatings,
                                    candidate_non_floatings))
    mismatch_ratio = 0.0
    if total_non_floatings > 0:
      mismatch_ratio = mismatches / total_non_floatings

    within_threshold = mismatch_ratio <= max_mismatch_ratio
    if within_threshold:
      logging.info(
          '%d Baseline/Candidate mismatches over %d non-floating-point '
          'results. Mismatch ratio is %f (<= %f threshold).', mismatches,
          total_non_floatings, mismatch_ratio, max_mismatch_ratio)
    else:
      logging.warning(
          '%d Baseline/Candidate mismatches over %d non-floating-point '
          'results. Mismatch ratio is %f (> %f threshold).', mismatches,
          total_non_floatings, mismatch_ratio, max_mismatch_ratio)
      if self._option.print_debug_info:
        logging.warning('baseline_non_floatings = %s', baseline_non_floatings)
        logging.warning('candidate_non_floatings = %s',
                        candidate_non_floatings)

    self.outputs['NonFloatingPointDiffReport'] = NonFloatingPointDiffReport(
        total_flattened_tensors=int(total_non_floatings),
        mismatches=int(mismatches),
        mismatch_ratio=float(mismatch_ratio),
        max_non_floating_mismatch_ratio=float(max_mismatch_ratio),
    )
    return within_threshold