Source code for orbax.export.validate.validation_manager

# Copyright 2024 The Orbax Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
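# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.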
# See the License for the specific language governing permissions and
# limitations under the License.

"""The definition of ValidationManager class."""
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Optional, Union

from absl import logging
import jax
from orbax.export import jax_module
from orbax.export.serving_config import ServingConfig
from orbax.export.validate.validation_job import ValidationJob
from orbax.export.validate.validation_job import ValidationSingleJobResult
from orbax.export.validate.validation_report import ValidationReport
from orbax.export.validate.validation_report import ValidationReportOption
import tensorflow as tf

def _is_flat_dict(x):
  """Checks if x is flat dict."""
  if not isinstance(x, Mapping):
    return False
  return all(
      for v in x.values()

def _is_flat_sequence(x):
  """Checks if x is flat list."""
  if not isinstance(x, Sequence):
    return False
  return all(
      jax.tree_util.treedef_is_leaf(jax.tree_util.tree_structure(v)) for v in x

class ValidationManager:
  """Validate the JaxModule and its output tf saved model.

  For every signature key exposed by the serving configs, the manager runs the
  same inputs through a baseline function (bound from the JAX methods) and a
  candidate function (the matching signature of a loaded TF SavedModel), and
  wraps both results in a `ValidationReport`.
  """

  def __init__(
      self,
      module: Union[
          jax_module.JaxModule,
          Mapping[str, Callable[[jax_module.PyTree], jax_module.PyTree]],
      ],
      serving_configs: Sequence[ServingConfig],
      model_inputs: Union[Sequence[Any], Mapping[str, Sequence[Any]]],
  ):
    """Create the ValidationManager object.

    Args:
      module: the JaxModule object. A plain mapping of method name to callable
        is still accepted but deprecated (a warning is logged).
      serving_configs: the ServingConfig Sequence.
      model_inputs: The inputs for saved TF SavedModel. It support two formats:
        (1) A mapping of signature key to a sequences batch inputs; or (2) a
        sequence of batch inputs to validate all signatures.
    """
    if isinstance(module, jax_module.JaxModule):
      self._jax_methods = module.jax_methods
    else:
      # `logging.warn` is a deprecated alias in absl; use `warning`.
      logging.warning(
          'Using Mapping[str, Callable] to initialize ValidationManager is'
          ' deprecated. Use JaxModule instead.'
      )
      self._jax_methods = module
    self._serving_configs = serving_configs
    self._model_inputs = model_inputs

  def _create_baseline_fns(self) -> Mapping[str, Callable[..., Any]]:
    """Returns a map from signature keys to validation functions."""
    validation_func_map = {}
    for sc in self._serving_configs:
      # Merge whatever mapping `bind` yields for this serving config.
      validation_func_map.update(sc.bind(self._jax_methods))
    return validation_func_map

  def _create_candidate_fns(
      self, loaded_model: Any
  ) -> Mapping[str, Callable[..., Any]]:
    """Returns a map from signature keys to candidate functions.

    Args:
      loaded_model: The user should provided the loaded_model. For CPU,
        `loaded_model=tf.saved_model.load(tf_model_path, ['serve'])` for TPU,
        `loaded_model=tf.saved_model.load(tf_model_path, ['serve', 'tpu'])`

    Returns:
      A dict with one inference callable per signature key listed by the
      serving configs. Each callable takes exactly one positional argument.
    """
    loaded_model_signatures = loaded_model.signatures
    candidate_func_map = {}

    def make_candidate_inference_fn(signature_key):
      # A factory is used so each closure captures its own `signature_key`
      # rather than the loop variable of the enclosing `for`.
      def inference_fn(*inputs):
        if len(inputs) != 1:
          raise ValueError(
              'Currently does not accept multiple args, '
              f'got len(inputs)={len(inputs)}.'
          )
        real_inputs = inputs[0]
        real_inputs = jax.tree_util.tree_map(tf.convert_to_tensor, real_inputs)
        self.check_input(real_inputs, batch_mode=False)
        # Dispatch on the input container: mappings become keyword arguments,
        # sequences become positional arguments, anything else is passed as a
        # single argument.
        if isinstance(real_inputs, Mapping):
          outputs = loaded_model_signatures[signature_key](**real_inputs)
        elif isinstance(real_inputs, Sequence):
          outputs = loaded_model_signatures[signature_key](*real_inputs)
        else:
          outputs = loaded_model_signatures[signature_key](real_inputs)
        # Convert leaves that expose `.numpy()`; leave the rest untouched.
        outputs = jax.tree_util.tree_map(
            lambda x: x.numpy() if hasattr(x, 'numpy') else x, outputs
        )
        return outputs

      return inference_fn

    for sc in self._serving_configs:
      for key in sc.get_signature_keys():
        candidate_func_map[key] = make_candidate_inference_fn(key)
    return candidate_func_map

  def _create_input_map(self) -> Mapping[str, Sequence[Any]]:
    """Converts batch input into Mapping[signature_key, batch input] if need."""
    if isinstance(self._model_inputs, Mapping):
      return self._model_inputs
    # A bare sequence of batch inputs is shared by every signature key.
    model_inputs = {}
    for sc in self._serving_configs:
      for key in sc.get_signature_keys():
        model_inputs[key] = self._model_inputs
    return model_inputs

  def validate(
      self,
      loaded_model: Any,
      with_xprof: bool = False,
      report_option: Optional[ValidationReportOption] = None,
  ) -> Mapping[str, ValidationReport]:
    """Validates the baseline and candidate function map.

    Args:
      loaded_model: the loaded TF SavedModel whose `signatures` provide the
        candidate functions.
      with_xprof: forwarded unchanged to each `ValidationJob`.
      report_option: options passed to each `ValidationReport`; a default
        `ValidationReportOption()` is used when this is falsy.

    Returns:
      A dict mapping each signature key to its `ValidationReport`.

    Raises:
      ValueError: propagated from `check_output` when the baseline outputs are
        not flat dicts or the baseline/candidate leaf counts differ.
    """
    candidate_fns = self._create_candidate_fns(loaded_model)
    baseline_fns = self._create_baseline_fns()
    input_map = self._create_input_map()
    results = {}
    if not report_option:
      report_option = ValidationReportOption()

    for sc in self._serving_configs:
      for key in sc.get_signature_keys():
        validation_job = ValidationJob(
            baseline_fns[key], candidate_fns[key], input_map[key], with_xprof
        )
        baseline_result = validation_job.calc_baseline_result()
        candidate_result = validation_job.calc_candidate_result()
        # Always convert list to Dict
        baseline_result.maybe_convert_result_to_dict()
        candidate_result.maybe_convert_result_to_dict()
        self.check_output(baseline_result, candidate_result)
        results[key] = ValidationReport(
            baseline_result, candidate_result, report_option
        )
    return results

  @classmethod
  def check_input(
      cls, inputs: Union[Any, Sequence[Any]], batch_mode: bool = True
  ) -> None:
    """Check model input format.

    Only the first element is inspected in batch mode. A non-flat input is
    not an error: a warning is logged and the call returns normally.

    Args:
      inputs: model inputs. If batch_mode == True, inputs should be a list.
      batch_mode: it decide `inputs` is a list of the input or a single input.

    Raises:
      ValueError: if `batch_mode` is True and `inputs` is not a `Sequence`, or
        is an empty one.
    """
    if batch_mode:
      if not isinstance(inputs, Sequence):
        raise ValueError('Batch inputs should be a python list.')
      if not inputs:
        # Fail with a descriptive error instead of a bare IndexError below.
        raise ValueError('Batch inputs should not be empty.')
      test_inputs = inputs[0]
    else:
      test_inputs = inputs
    if not _is_flat_dict(test_inputs) and not _is_flat_sequence(test_inputs):
      logging.warning(
          (
              'Recommend Orbax validate inputs format as flat_dict or'
              ' flat_list, so it can generate the consistent tf.SavedModel'
              ' signature. For arbitrary format, we assume it as atomic data'
              ' type, it may fail. Got inputs format %s'
          ),
          type(test_inputs),
      )

  @classmethod
  def check_output(
      cls,
      baseline_result: ValidationSingleJobResult,
      candidate_result: ValidationSingleJobResult,
  ) -> None:
    """Check model output format.

    Args:
      baseline_result: result of the baseline run; its `outputs` attribute is
        read and the first element must be a flat dict.
      candidate_result: result of the candidate run; its `outputs` attribute
        is read.

    Raises:
      ValueError: if the first baseline output is not a flat dict, or if the
        baseline and candidate outputs flatten to a different number of
        leaves.
    """
    baseline_outputs = baseline_result.outputs
    candidate_outputs = candidate_result.outputs
    if not _is_flat_dict(baseline_outputs[0]):
      err_message = (
          'Currently ValidationReport only accept flat dict outputs.'
          f' But we got {type(baseline_outputs[0])}'
      )
      raise ValueError(err_message)

    # Compare the total leaf count across all outputs of both runs.
    baseline_flat = jax.tree_util.tree_leaves(baseline_outputs)
    candidate_flat = jax.tree_util.tree_leaves(candidate_outputs)
    if len(baseline_flat) != len(candidate_flat):
      raise ValueError(
          'baseline and candidate have different output lengths: '
          f'len(baseline) = {len(baseline_flat)}, '
          f'len(candidate) = {len(candidate_flat)}.'
      )