# Exporting with Orbax

Orbax Export is a library for exporting JAX models to the TensorFlow SavedModel format.

## Exporting

### API Overview
Orbax Export provides three classes:

- {py:class}`JaxModule <orbax.export.jax_module.JaxModule>` wraps a JAX function and its parameters into an exportable, callable closure.
- {py:class}`ServingConfig <orbax.export.serving_config.ServingConfig>` defines a serving configuration for a `JaxModule`. This includes a signature key and an input signature, and optionally pre- and post-processing functions and extra TrackableResources.
- {py:class}`ExportManager <orbax.export.export_manager.ExportManager>` builds the actual serving signatures from a `JaxModule` and a list of `ServingConfig`s, and saves them in the SavedModel format. It targets CPU. Users can subclass `ExportManager` to create their own export manager for other hardware.
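The idea of a "callable closure" can be illustrated by binding the parameters to the function once, so that callers supply only the model inputs. The sketch below shows only that idea, using `functools.partial` and a hypothetical toy model (`linear_model_fn`). It is an analogy and not how `JaxModule` is implemented; `JaxModule` also performs the JAX-to-TF conversion.

```python
import functools

import numpy as np


def linear_model_fn(params, inputs):
    # Hypothetical toy model with the same (params, inputs) calling convention.
    return params['w'] * inputs + params['b']


linear_params = {'w': 2.0, 'b': 0.5}

# Bind the parameters once; the result is a callable of the inputs only.
bound_model = functools.partial(linear_model_fn, linear_params)

x = np.array([0.0, 1.0, 2.0])
print(bound_model(x))  # [0.5 2.5 4.5]
```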
### Simple Example Usage

#### Setup

Before exporting, prepare the JAX model function and its parameters, plus the optional preprocessing and postprocessing functions.
```python
# Import Orbax Export classes.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import numpy as np
import jax
import jax.numpy as jnp
import tensorflow as tf

# Prepare the parameters and model function to export.
# A pytree of the JAX model parameters.
example1_params = {'a': np.array(5.0), 'b': np.array(1.1), 'c': np.array(0.55)}


# The JAX model function to export:
# f(x) = a * sin(x) + b * x + c, where (a, b, c) are the model parameters.
def example1_model_fn(params, inputs):
    a, b, c = params['a'], params['b'], params['c']
    return a * jnp.sin(inputs) + b * inputs + c


# Optional: preprocessor in TF. Normalizes each input tensor by its maximum value.
def example1_preprocess(inputs):
    norm_inputs = tf.nest.map_structure(lambda x: x / tf.math.reduce_max(x), inputs)
    return norm_inputs


# Optional: postprocessor in TF. Wraps the model output in a dict.
def example1_postprocess(model_fn_outputs):
    return {'outputs': model_fn_outputs}
```
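The sketch below uses NumPy only to show the preprocess → model → postprocess chain that the exported signature is expected to compute. It assumes a single, non-nested input array, so `tf.nest.map_structure` reduces to one division by the maximum. The `numpy_*` functions are hypothetical stand-ins and are not part of Orbax.

```python
import numpy as np

# Same values as example1_params.
np_params = {'a': 5.0, 'b': 1.1, 'c': 0.55}


def numpy_preprocess(x):
    # Mirrors example1_preprocess for a single (non-nested) array.
    return x / np.max(x)


def numpy_model_fn(params, x):
    # f(x) = a * sin(x) + b * x + c
    return params['a'] * np.sin(x) + params['b'] * x + params['c']


def numpy_postprocess(y):
    # Mirrors example1_postprocess: wrap the output in a dict.
    return {'outputs': y}


x = np.array([1.0, 2.0, 4.0], dtype=np.float32)
reference = numpy_postprocess(numpy_model_fn(np_params, numpy_preprocess(x)))
print(reference['outputs'])
```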
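Run the pipeline once before exporting to get reference outputs:

```python
# Random example inputs: a float32 tensor of shape [16].
inputs = tf.random.normal([16], dtype=tf.float32)

# Reference run: TF preprocessing -> JAX model -> TF postprocessing.
model_outputs = example1_postprocess(
    example1_model_fn(example1_params, np.array(example1_preprocess(inputs))))
print("model output: ", model_outputs)
```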
#### Exporting a JAX model to a CPU SavedModel
```python
import tensorflow as tf

# Construct a JaxModule where JAX->TF conversion happens.
jax_module = JaxModule(example1_params, example1_model_fn)

# Export the JaxModule along with one or more serving configs.
export_mgr = ExportManager(
    jax_module,
    [
        ServingConfig(
            'serving_default',
            input_signature=[tf.TensorSpec(shape=[16], dtype=tf.float32)],
            tf_preprocessor=example1_preprocess,
            tf_postprocessor=example1_postprocess,
        ),
    ],
)
output_dir = '/tmp/example1_output_dir'
export_mgr.save(output_dir)
```
#### Load the TF SavedModel back and run it
```python
loaded_model = tf.saved_model.load(output_dir)
loaded_model_outputs = loaded_model(inputs)
print("loaded model output: ", loaded_model_outputs)
```
Check that the JAX model and the reloaded SavedModel produce matching outputs:

```python
np.testing.assert_allclose(
    model_outputs['outputs'], loaded_model_outputs['outputs'], atol=1e-5, rtol=1e-5)
```
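The tolerances matter because the SavedModel and the JAX function may not produce bit-identical float32 results. `np.testing.assert_allclose` passes when every element satisfies |actual - desired| <= atol + rtol * |desired|. The sketch below reimplements that criterion in a hypothetical `within_tolerance` helper to show what `atol=1e-5, rtol=1e-5` accepts and rejects.

```python
import numpy as np


def within_tolerance(actual, desired, atol=1e-5, rtol=1e-5):
    # Elementwise criterion used by np.testing.assert_allclose:
    #   |actual - desired| <= atol + rtol * |desired|
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    return bool(np.all(np.abs(actual - desired) <= atol + rtol * np.abs(desired)))


desired = np.array([1.0, 10.0])
close = desired + np.array([5e-6, 5e-5])  # within 2e-5 and 1.1e-4, respectively
far = desired + np.array([1e-3, 0.0])     # 1e-3 exceeds 2e-5 for the first element

print(within_tolerance(close, desired))  # True
print(within_tolerance(far, desired))    # False

# NumPy's own check passes silently here; it would raise AssertionError for `far`.
np.testing.assert_allclose(close, desired, atol=1e-5, rtol=1e-5)
```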
### Limitation

#### “JaxModule only take single arg as the input”.
This error means that the JAX function `model_fn` can take only a single argument as its input.

Orbax expects the JAX model to be a callable whose parameters are a PyTree and whose model inputs are a PyTree. If your JAX function takes multiple inputs, you must pack them into a single JAX PyTree. Otherwise, you will see this error.

To fix this, update the `ServingConfig.tf_preprocessor` function so that it packs the inputs into a single JAX PyTree. For example, if the model takes two inputs `x` and `y`, define a `ServingConfig.tf_preprocessor` that packs them into the list `[x, y]`.
```python
# A pytree of the JAX model parameters (this model has none).
example2_params = {}


# The JAX model takes a single PyTree `inputs` and unpacks it into x and y.
def example2_model_fn(params, inputs):
    x, y = inputs
    return x + y


def example2_preprocessor(x, y):
    # Put the usual tf_preprocessor code here.
    return [x, y]  # Pack the inputs into a single list for the JAX model function.


jax_module = JaxModule(example2_params, example2_model_fn)
export_mgr = ExportManager(
    jax_module,
    [
        ServingConfig(
            'serving_default',
            # Two inputs of shape [16]; tf.TensorSpec defaults to float32.
            input_signature=[tf.TensorSpec([16]), tf.TensorSpec([16])],
            tf_preprocessor=example2_preprocessor,
        )
    ],
)
output_dir = '/tmp/example2_output_dir'
export_mgr.save(output_dir)

loaded_model = tf.saved_model.load(output_dir)
loaded_model_outputs = loaded_model(tf.random.normal([16]), tf.random.normal([16]))
print("loaded model output: ", loaded_model_outputs)
```
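If your model function already takes its inputs as separate arguments, one option is to keep it and wrap it so that it accepts a single packed PyTree. The sketch below uses a hypothetical `pack_inputs` helper that is not part of Orbax Export, and it uses plain Python numbers in place of tensors. You would pass the wrapped function to `JaxModule`, together with a preprocessor such as `example2_preprocessor` that returns `[x, y]`.

```python
from typing import Any, Callable


def pack_inputs(multi_arg_fn: Callable[..., Any]) -> Callable[[Any, Any], Any]:
    """Adapts fn(params, x, y, ...) into fn(params, inputs) with inputs = [x, y, ...].

    Hypothetical helper for illustration; not part of Orbax Export.
    """
    def single_arg_fn(params, inputs):
        # Unpack the single PyTree (here a list) back into positional arguments.
        return multi_arg_fn(params, *inputs)
    return single_arg_fn


def add_model(params, x, y):
    # Takes two separate model inputs, so it cannot be passed to JaxModule as-is.
    return x + y


packed_model_fn = pack_inputs(add_model)


def packing_preprocessor(x, y):
    # Packs both inputs into one list, like example2_preprocessor.
    return [x, y]


print(packed_model_fn({}, packing_preprocessor(1.0, 2.0)))  # 3.0
```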
## Validating

### API Overview
`orbax.export.validate` is a library for validating a JAX model against the TF SavedModel exported from it.
The model must be exported first, either with `orbax.export` or manually.

`orbax.export.validate` provides the following classes:

- `ValidationJob` takes the model and data as input and outputs the results.
- `ValidationReport` compares the JAX model and TF SavedModel results and generates a formatted report.
- `ValidationManager` takes a `JaxModule` as input and wraps the end-to-end validation flow.
### Simple Example Usage

Here we use the same example as in the `ExportManager` section.
```python
from orbax.export.validate import ValidationManager
from orbax.export import JaxModule
from orbax.export import ServingConfig

jax_module = JaxModule(example1_params, example1_model_fn)
batch_inputs = [inputs] * 16
serving_configs = [
    ServingConfig(
        'serving_default',
        input_signature=[tf.TensorSpec(shape=[16], dtype=tf.float32)],
        tf_preprocessor=example1_preprocess,
        tf_postprocessor=example1_postprocess,
    ),
]
# Provide the computation method for the baseline.
validation_mgr = ValidationManager(jax_module, serving_configs, batch_inputs)

tf_saved_model_path = "/tmp/example1_output_dir"
loaded_model = tf.saved_model.load(tf_saved_model_path)
# Provide the computation method for the candidate.
validation_reports = validation_mgr.validate(loaded_model)

# `validation_reports` is a Python dict keyed by the TF SavedModel serving key.
for key in validation_reports:
    assert validation_reports[key].status.name == 'Pass'
    # The report can also be converted to JSON and saved to a file.
    print(validation_reports[key].to_json(indent=2))
```
### Limitation

This section lists the limitations of the `orbax.export.validate` module.
The object returned by a TF SavedModel is always a map. If the JAX model returns a sequence, the TF SavedModel converts it into a map with fairly generic tensor names, such as `output_0`. We suggest having the model return a dictionary, so that the `ValidationReport` module can make an apples-to-apples comparison between the JAX model and TF model results.
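One way to follow this suggestion without rewriting the model body is to wrap a sequence-returning model function so that it returns a dictionary with descriptive keys. The `with_named_outputs` helper and the two-output model below are hypothetical illustrations and not part of Orbax Export. The wrapped function is what you would pass to `JaxModule`.

```python
from typing import Callable, Sequence


def with_named_outputs(model_fn: Callable, names: Sequence[str]) -> Callable:
    """Wraps a model_fn that returns a tuple/list so it returns a dict instead.

    Hypothetical helper for illustration; not part of Orbax Export.
    """
    def named_model_fn(params, inputs):
        outputs = model_fn(params, inputs)
        if len(outputs) != len(names):
            raise ValueError(f'Expected {len(names)} outputs, got {len(outputs)}.')
        # Pair each output with its descriptive name, preserving order.
        return dict(zip(names, outputs))
    return named_model_fn


def two_head_model_fn(params, inputs):
    # Returns a tuple; a SavedModel would expose it under generic keys.
    return inputs * params['scale'], inputs + params['shift']


named_fn = with_named_outputs(two_head_model_fn, ['scaled', 'shifted'])
print(named_fn({'scale': 2.0, 'shift': 1.0}, 3.0))  # {'scaled': 6.0, 'shifted': 4.0}
```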
## Examples

Check out the examples directory for several examples that use Orbax Export.