# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of ValidationManager class."""
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Optional, Union
from absl import logging
import jax
from orbax.export import jax_module
from orbax.export.serving_config import ServingConfig
from orbax.export.validate.validation_job import ValidationJob
from orbax.export.validate.validation_job import ValidationSingleJobResult
from orbax.export.validate.validation_report import ValidationReport
from orbax.export.validate.validation_report import ValidationReportOption
import tensorflow as tf
def _is_flat_dict(x):
    """Returns True when `x` is a Mapping whose values are all pytree leaves."""
    if isinstance(x, Mapping):
        for value in x.values():
            value_treedef = jax.tree_util.tree_structure(value)
            # Any value that is itself a container makes the dict non-flat.
            if not jax.tree_util.treedef_is_leaf(value_treedef):
                return False
        return True
    return False
def _is_flat_sequence(x):
    """Returns True when `x` is a Sequence whose items are all pytree leaves."""
    # NOTE: `str` and `bytes` are Sequences too, so they pass this check.
    if not isinstance(x, Sequence):
        return False
    leaf_flags = (
        jax.tree_util.treedef_is_leaf(jax.tree_util.tree_structure(item))
        for item in x
    )
    return all(leaf_flags)
[docs]class ValidationManager:
"""Validate the JaxModule and its output tf saved model."""
def __init__(
    self,
    module: Union[
        jax_module.JaxModule,
        Mapping[str, Callable[[jax_module.PyTree], jax_module.PyTree]],
    ],
    serving_configs: Sequence[ServingConfig],
    model_inputs: Union[Sequence[Any], Mapping[str, Sequence[Any]]],
):
    """Creates the ValidationManager object.

    Args:
      module: A JaxModule, whose `jax_methods` are stored, or (deprecated) a
        mapping from string key to callable, which is stored as-is.
      serving_configs: The sequence of ServingConfig objects to validate.
      model_inputs: The inputs for the TF SavedModel. Two formats are
        supported: (1) a mapping of signature key to a sequence of batch
        inputs; or (2) a sequence of batch inputs used to validate all
        signatures.
    """
    if isinstance(module, jax_module.JaxModule):
        self._jax_methods = module.jax_methods
    else:
        # `logging.warn` is a deprecated alias in absl; `warning` is the
        # supported spelling and logs the same message at the same level.
        logging.warning(
            'Using Mapping[str, Callable] to initialize ValidationManager is'
            ' deprecated. Use JaxModule instead.'
        )
        self._jax_methods = module
    self._serving_configs = serving_configs
    self._model_inputs = model_inputs
def _create_baseline_fns(self) -> Mapping[str, Callable[..., Any]]:
"""Returns a map from signature keys to validation functions."""
validation_func_map = {}
for sc in self._serving_configs:
validation_func_map.update((sc.bind(self._jax_methods)))
return validation_func_map
def _create_candidate_fns(
    self, loaded_model: Any
) -> Mapping[str, Callable[..., Any]]:
    """Builds the map from signature keys to candidate functions.

    Args:
      loaded_model: The loaded model, supplied by the caller. For CPU,
        `loaded_model=tf.saved_model.load(tf_model_path, ['serve'])`; for TPU,
        `loaded_model=tf.saved_model.load(tf_model_path, ['serve', 'tpu'])`.

    Returns:
      A dict with one inference function per signature key of every serving
      config. Each function accepts exactly one positional argument and
      raises ValueError otherwise.
    """
    signatures = loaded_model.signatures

    def make_candidate_inference_fn(signature_key):
        # The factory gives every closure its own `signature_key` binding.
        def inference_fn(*inputs):
            if len(inputs) != 1:
                raise ValueError(
                    'Currently does not accept multiple args, '
                    f'got len(inputs)={len(inputs)}.'
                )
            (raw_inputs,) = inputs
            tensors = jax.tree_util.tree_map(tf.convert_to_tensor, raw_inputs)
            self.check_input(tensors, batch_mode=False)

            # Mappings are passed as keyword arguments, sequences as
            # positional arguments, anything else as a single argument.
            signature_fn = signatures[signature_key]
            if isinstance(tensors, Mapping):
                raw_outputs = signature_fn(**tensors)
            elif isinstance(tensors, Sequence):
                raw_outputs = signature_fn(*tensors)
            else:
                raw_outputs = signature_fn(tensors)

            def to_numpy(leaf):
                # Leaves without a `numpy` attribute pass through untouched.
                return leaf.numpy() if hasattr(leaf, 'numpy') else leaf

            return jax.tree_util.tree_map(to_numpy, raw_outputs)

        return inference_fn

    return {
        key: make_candidate_inference_fn(key)
        for serving_config in self._serving_configs
        for key in serving_config.get_signature_keys()
    }
def _create_input_map(self) -> Mapping[str, Sequence[Any]]:
"""Converts batch input into Mapping[signature_key, batch input] if need."""
if isinstance(self._model_inputs, Mapping):
return self._model_inputs
model_inputs = {}
for sc in self._serving_configs:
for key in sc.get_signature_keys():
model_inputs[key] = self._model_inputs
return model_inputs
def validate(
    self,
    loaded_model: Any,
    with_xprof: bool = False,
    report_option: Optional[ValidationReportOption] = None,
) -> Mapping[str, ValidationReport]:
    """Runs baseline and candidate functions and reports on each signature.

    Args:
      loaded_model: The loaded model whose `signatures` back the candidate
        functions.
      with_xprof: Forwarded to each ValidationJob.
      report_option: Forwarded to each ValidationReport; a default
        ValidationReportOption is created when this is falsy.

    Returns:
      A dict mapping each signature key to its ValidationReport.
    """
    candidate_fns = self._create_candidate_fns(loaded_model)
    baseline_fns = self._create_baseline_fns()
    input_map = self._create_input_map()
    reports = {}
    option = report_option if report_option else ValidationReportOption()
    for serving_config in self._serving_configs:
        for signature_key in serving_config.get_signature_keys():
            job = ValidationJob(
                baseline_fns[signature_key],
                candidate_fns[signature_key],
                input_map[signature_key],
                with_xprof,
            )
            baseline_result = job.calc_baseline_result()
            candidate_result = job.calc_candidate_result()
            # Always convert list results to dict form before checking.
            for result in (baseline_result, candidate_result):
                result.maybe_convert_result_to_dict()
            self.check_output(baseline_result, candidate_result)
            reports[signature_key] = ValidationReport(
                baseline_result, candidate_result, option
            )
    return reports
@classmethod
def check_output(
    cls,
    baseline_result: ValidationSingleJobResult,
    candidate_result: ValidationSingleJobResult,
) -> None:
    """Checks the format of the baseline and candidate outputs.

    Args:
      baseline_result: Result whose `outputs` are checked for the flat-dict
        format and flattened for the length comparison.
      candidate_result: Result whose `outputs` are flattened and compared
        against the baseline by number of leaves.

    Raises:
      ValueError: If the first baseline output is not a flat dict, or if the
        baseline and candidate outputs flatten to different numbers of leaves.
    """
    baseline_outputs = baseline_result.outputs
    candidate_outputs = candidate_result.outputs
    # Only the first baseline output is inspected for the flat-dict format.
    if not _is_flat_dict(baseline_outputs[0]):
        raise ValueError(
            'Currently ValidationReport only accept flat dict outputs. '
            f'But we got {type(baseline_outputs[0])}'
        )
    baseline_flat = jax.tree_util.tree_leaves(baseline_outputs)
    candidate_flat = jax.tree_util.tree_leaves(candidate_outputs)
    if len(baseline_flat) != len(candidate_flat):
        # Adjacent string literals are joined verbatim, so each piece must
        # carry its own trailing space to keep the message readable.
        raise ValueError(
            'baseline and candidate has different output length. '
            f'len(baseline) = {len(baseline_flat)}, '
            f'len(candidate) = {len(candidate_flat)}.'
        )