# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Manage the exporting of a JAXModule."""

from collections.abc import Mapping, Sequence
from typing import Any, Callable, Optional

from etils.epy.reraise_utils import maybe_reraise
from orbax.export import dtensor_utils
from orbax.export import utils
from orbax.export.export_manager_base import ExportManagerBase
from orbax.export.jax_module import JaxModule
from orbax.export.serving_config import ServingConfig
import tensorflow as tf
from tensorflow.experimental import dtensor


class ExportManager(ExportManagerBase):
  """Exports a JaxModule with pre- and post-processors."""

  def __init__(
      self, module: JaxModule, serving_configs: Sequence[ServingConfig]
  ):
    """ExportManager constructor.

    Args:
      module: the `JaxModule` to be exported.
      serving_configs: a sequence of `ServingConfig`s, each corresponding to a
        serving signature of the exported SavedModel.
    """
    # Creates a new tf.Module wrapping the JaxModule and extra trackable
    # resources.
    self._module = tf.Module()
    self._module.computation_module = module
    self._serving_signatures = {}
    tf_trackable_resources = []
    for sc in serving_configs:
      with maybe_reraise(f'Failed exporting signature_key={sc.signature_key} '):
        method = sc.get_infer_step(module.methods)
        inference_fn = make_e2e_inference_fn(method, sc)
        if isinstance(sc.signature_key, str):
          keys = [sc.signature_key]
        else:
          keys = sc.signature_key
        for key in keys:
          if key in self._serving_signatures:
            raise ValueError(
                f'Duplicated key "{sc.signature_key}" in `serving_configs`.'
            )
          self._serving_signatures[key] = inference_fn
        if sc.extra_trackable_resources is not None:
          tf_trackable_resources.append(sc.extra_trackable_resources)
      if len(serving_configs) == 1:
        # Make this module callable. Once exported, it can be loaded back in
        # Python and the nested input structure will be preserved. In
        # contrast, signatures will flatten the TensorSpecs of the inputs to
        # kwargs.
        self.tf_module.__call__ = inference_fn
    self._module.tf_trackable_resources = tf_trackable_resources

  @property
  def tf_module(self) -> tf.Module:
    """Returns the tf.Module maintained by the export manager."""
    return self._module

  @property
  def serving_signatures(self) -> Mapping[str, Callable[..., Any]]:
    """Returns a map of signature keys to serving functions."""
    return self._serving_signatures

  def save(
      self,
      model_path: str,
      save_options: Optional[tf.saved_model.SaveOptions] = None,
      signature_overrides: Optional[Mapping[str, Callable[..., Any]]] = None,
  ):
    """Saves the JAX model to a SavedModel.

    Args:
      model_path: a directory in which to write the SavedModel.
      save_options: an optional `tf.saved_model.SaveOptions` for configuring
        the save.
      signature_overrides: signatures to override the self-maintained ones, or
        additional signatures to export.
    """
    save_options = save_options or tf.saved_model.SaveOptions()
    save_options.experimental_custom_gradients = (
        self._module.computation_module.with_gradient
    )
    serving_signatures = dict(self.serving_signatures)
    if signature_overrides:
      serving_signatures.update(signature_overrides)
    tf.saved_model.save(
        self.tf_module, model_path, serving_signatures, options=save_options
    )
    mesh = dtensor_utils.get_current_mesh()
    if mesh:
      # TODO(b/261191533): we can remove this once tf.saved_model.save is
      # aware of SPMD saving.
      dtensor.barrier(mesh.dtensor_mesh, 'export done')

  def load(self, model_path: str, **kwargs: Any):
    """Loads the exported SavedModel back as a trackable object."""
    loaded = tf.saved_model.load(model_path, **kwargs)
    return loaded
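

# A minimal usage sketch (not part of the library API): it builds a toy
# `JaxModule`, wraps it in an `ExportManager` with one serving signature,
# saves it, and loads it back. The parameters, model function, and output
# path below are hypothetical placeholders.
def _example_usage():
  import jax.numpy as jnp  # Imported here only for this sketch.

  params = {'w': jnp.ones((4,), jnp.float32)}
  apply_fn = lambda p, x: x * p['w']  # A toy JAX model function.

  em = ExportManager(
      JaxModule(params, apply_fn),
      [
          ServingConfig(
              'serving_default',
              input_signature=[tf.TensorSpec([None, 4], tf.float32)],
          )
      ],
  )
  em.save('/tmp/example_savedmodel')  # Hypothetical output directory.
  return em.load('/tmp/example_savedmodel')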


def make_e2e_inference_fn(
    model_fn: Callable[..., Any],
    serving_config: ServingConfig,
) -> Callable[..., Any]:
  """Creates a concrete end-to-end inference tf.function.

  Args:
    model_fn: a callable in TF context for the numeric computation.
    serving_config: a ServingConfig that defines the input signature,
      pre-processor and post-processor of the inference function.

  Returns:
    A tf.function for end-to-end inference.
  """
  infer_step_func_map = serving_config.bind(model_fn, require_numpy=False)
  signature_key = serving_config.get_signature_keys()[0]
  return utils.with_default_args(
      infer_step_func_map[signature_key],
      serving_config.get_input_signature(),
  )
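

# A minimal sketch of calling `make_e2e_inference_fn` directly (normally
# `ExportManager.__init__` does this for you). `module` and `sc` are
# hypothetical arguments: a constructed `JaxModule` and one of its
# `ServingConfig`s.
def _example_make_e2e_inference_fn(
    module: JaxModule, sc: ServingConfig
) -> Callable[..., Any]:
  # Look up the JAX method this config serves, then wrap it with the
  # config's pre-/post-processors and input signature.
  method = sc.get_infer_step(module.methods)
  return make_e2e_inference_fn(method, sc)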