Exporting with Orbax#

Orbax Export is a library for exporting JAX models to TensorFlow SavedModel format.

Exporting#

API Overview#

Orbax Export provides three classes.

  • JaxModule wraps a JAX function and its parameters into an exportable and callable closure.

  • ServingConfig defines a serving configuration for a JaxModule, including a signature key and an input signature, and optionally pre- and post-processing functions and extra TrackableResources.

  • ExportManager builds the actual serving signatures from a JaxModule and a list of ServingConfigs, and saves them in the SavedModel format. It targets CPU; users can subclass ExportManager to create their own export managers for other hardware. Putting the three classes together looks like the sketch below.
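
A minimal sketch of the typical flow, assuming the imports from the Setup section below and placeholder params, model_fn, and output path (the full, runnable example follows in the next section):

# Minimal end-to-end flow (sketch): wrap, configure, export.
jax_module = JaxModule(params, model_fn)  # JAX function + its parameter pytree
serving_config = ServingConfig(
    'serving_default',
    input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)],
)
ExportManager(jax_module, [serving_config]).save('/tmp/saved_model_dir')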

Simple Example Usage#

Setup#

Before starting model export, users should have the JAX model function, its parameters, and the (optional) preprocessing and postprocessing functions ready.

# Import Orbax Export classes.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import numpy as np
import jax
import jax.numpy as jnp
import tensorflow as tf

# Prepare the parameters and model function to export.
example1_params = {  'a': np.array(5.0), 'b': np.array(1.1), 'c': np.array(0.55)} # A pytree of the JAX model parameters.

# model f(x) = a * sin(x) + b * x + c, here (a, b, c) are model parameters
def example1_model_fn(params, inputs):  # The JAX model function to export.
  a, b, c = params['a'], params['b'], params['c']
  return a * jnp.sin(inputs) + b * inputs + c

def example1_preprocess(inputs):  # Optional: preprocessor in TF.
  norm_inputs = tf.nest.map_structure(lambda x: x/tf.math.reduce_max(x), inputs)
  return norm_inputs

def example1_postprocess(model_fn_outputs):  # Optional: post-processor in TF.
  return {'outputs': model_fn_outputs}
inputs = tf.random.normal([16], dtype=tf.float32)

model_outputs = example1_postprocess(example1_model_fn(example1_params, np.array(example1_preprocess(inputs))))
print("model output: ", model_outputs)
model output:  {'outputs': Array([ 0.4047188 ,  1.9534669 , -0.22997212,  1.7075796 ,  0.17860335,
        4.758424  , -1.1869419 , -1.1554301 ,  0.01281381,  5.8573546 ,
        0.3630153 ,  0.15386757, -0.43379813,  4.606002  ,  1.8755991 ,
        4.102188  ], dtype=float32)}

Exporting a JAX model to a CPU SavedModel#

import tensorflow as tf

# Construct a JaxModule where JAX->TF conversion happens.
jax_module = JaxModule(example1_params, example1_model_fn)
# Export the JaxModule along with one or more serving configs.
export_mgr = ExportManager(
  jax_module, [
    ServingConfig(
      'serving_default',
      input_signature= [tf.TensorSpec(shape=[16], dtype=tf.float32)],
      tf_preprocessor=example1_preprocess,
      tf_postprocessor=example1_postprocess
    ),
])
output_dir='/tmp/example1_output_dir'
export_mgr.save(output_dir)
INFO:tensorflow:Assets written to: /tmp/example1_output_dir/assets

Load the TF SavedModel back and run it#

loaded_model = tf.saved_model.load(output_dir)
loaded_model_outputs = loaded_model(inputs)
print("loaded model output: ", loaded_model_outputs)
loaded model output:  {'outputs': <tf.Tensor: shape=(16,), dtype=float32, numpy=
array([ 0.4047188 ,  1.9534669 , -0.22997218,  1.7075796 ,  0.17860335,
        4.758424  , -1.1869419 , -1.1554301 ,  0.01281381,  5.857355  ,
        0.3630153 ,  0.1538676 , -0.43379813,  4.606002  ,  1.8755989 ,
        4.102188  ], dtype=float32)>}
np.testing.assert_allclose(model_outputs['outputs'], loaded_model_outputs['outputs'], atol=1e-5, rtol=1e-5)
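
The exported signatures can also be inspected with the standard TF SavedModel API; a quick sketch:

# Inspect the serving signatures of the exported SavedModel.
print(list(loaded_model.signatures.keys()))    # e.g. ['serving_default']
serving_fn = loaded_model.signatures['serving_default']
print(serving_fn.structured_input_signature)   # input TensorSpecs
print(serving_fn.structured_outputs)           # output structure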

Limitation#

“JaxModule only take single arg as the input”.#

This error message means the JAX function model_fn can only take a single argument as input. Orbax is designed to take a JAX module in the form of a Callable with parameters of type PyTree and model inputs of type PyTree. If your JAX function takes multiple inputs, you must pack them into a single JAX PyTree; otherwise, you will encounter this error message.

To solve this problem, you can update the ServingConfig.tf_preprocessor function to pack the inputs into a single JAX PyTree. For example, if our model takes two inputs x and y, you can define the ServingConfig.tf_preprocessor to pack them into a list [x, y].

example2_params = {} # A pytree of the JAX model parameters.

def example2_model_fn(params, inputs):
  x, y = inputs
  return x + y

def example2_preprocessor(x, y):
  # Put the normal tf_preprocessor code here.
  return [x, y]  # Pack the inputs into a single list for the JAX model function.

jax_module = JaxModule(example2_params, example2_model_fn)
export_mgr = ExportManager(
  jax_module,
  [
      ServingConfig(
          'serving_default',
          input_signature=[tf.TensorSpec([16]), tf.TensorSpec([16])],
          tf_preprocessor=example2_preprocessor,
      )
  ],
)
output_dir='/tmp/example2_output_dir'
export_mgr.save(output_dir)

loaded_model = tf.saved_model.load(output_dir)
loaded_model_outputs = loaded_model(tf.random.normal([16]), tf.random.normal([16]))
print("loaded model output: ", loaded_model_outputs)
INFO:tensorflow:Assets written to: /tmp/example2_output_dir/assets
loaded model output:  tf.Tensor(
[ 0.19923565 -2.748554   -1.6485482  -0.92307144 -0.11171305 -1.4577065
  1.1756551  -1.3872656  -1.8684382   1.5276146  -0.77312386 -0.14387262
 -2.8860745  -0.45430046  1.456396    0.15503347], shape=(16,), dtype=float32)
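
The single pytree does not have to be a list; any pytree, such as a dict, works as well. A minimal sketch of that variant (the example3_* names and output path are hypothetical):

example3_params = {}  # A pytree of the JAX model parameters (empty here).

def example3_model_fn(params, inputs):
  # The packed inputs arrive as a dict pytree.
  return inputs['x'] + inputs['y']

def example3_preprocessor(x, y):
  return {'x': x, 'y': y}  # Pack the two inputs into a single dict pytree.

jax_module = JaxModule(example3_params, example3_model_fn)
export_mgr = ExportManager(
    jax_module,
    [
        ServingConfig(
            'serving_default',
            input_signature=[tf.TensorSpec([16]), tf.TensorSpec([16])],
            tf_preprocessor=example3_preprocessor,
        )
    ],
)
export_mgr.save('/tmp/example3_output_dir')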

Validating#

API Overview#

orbax.export.validate is a library that can be used to validate a JAX model against its exported TF SavedModel.

Users must finish exporting the JAX model first; the model can be exported with orbax.export or manually.

orbax.export.validate provides these classes:

  • ValidationJob takes the model and data as input, then outputs the result.

  • ValidationReport compares the JAX model and TF SavedModel results, then generates a formatted report.

  • ValidationManager takes a JaxModule as input and wraps the end-to-end validation flow.

Simple Example Usage#

Here we reuse the same example as in the ExportManager section above.

from orbax.export.validate import ValidationManager
from orbax.export import JaxModule
from orbax.export import ServingConfig

jax_module = JaxModule(example1_params, example1_model_fn)
batch_inputs = [inputs] * 16

serving_configs = [
  ServingConfig(
      'serving_default',
      input_signature= [tf.TensorSpec(shape=[16], dtype=tf.float32)],
      tf_preprocessor=example1_preprocess,
      tf_postprocessor=example1_postprocess
    ),
]
# Provide computation method for the baseline.
validation_mgr = ValidationManager(jax_module, serving_configs,
                                       batch_inputs)

tf_saved_model_path = "/tmp/example1_output_dir"
loaded_model = tf.saved_model.load(tf_saved_model_path)

# Provide the computation method for the candidate.
validation_reports = validation_mgr.validate(loaded_model)

# `validation_reports` is a python dict and the key is TF SavedModel serving_key.
for key in validation_reports:
  assert(validation_reports[key].status.name == 'Pass')
  # Users can also save the converted json to file.
  print(validation_reports[key].to_json(indent=2))
{
  "outputs": {
    "FloatingPointDiffReport": {
      "total": 256,
      "max_diff": 4.76837158203125e-07,
      "max_rel_diff": 2.591820305042347e-07,
      "all_close": true,
      "all_close_absolute_tolerance": 1e-07,
      "all_close_relative_tolerance": 1e-07
    },
    "NonFloatingPointDiffReport": {
      "total_flattened_tensors": 0,
      "mismatches": 0,
      "mismatch_ratio": 0.0,
      "max_non_floating_mismatch_ratio": 0.01
    }
  },
  "latency": {
    "baseline": {
      "num_batches": 16,
      "avg_in_ms": 1.1198371648788452,
      "p90_in_ms": 1.272439956665039,
      "p99_in_ms": 1.3556122779846191
    },
    "candidate": {
      "num_batches": 16,
      "avg_in_ms": 2.1341294050216675,
      "p90_in_ms": 2.560734748840332,
      "p99_in_ms": 2.9134511947631836
    }
  },
  "xprof_url": {
    "baseline": "N/A",
    "candidate": "N/A"
  },
  "metadata": {
    "baseline": {},
    "candidate": {}
  },
  "status": 1
}
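
As the comment above notes, each report can also be written to disk. A minimal sketch (the output path is illustrative):

import os

report_dir = '/tmp/example1_validation_reports'
os.makedirs(report_dir, exist_ok=True)
for key, report in validation_reports.items():
  # One JSON file per serving key.
  with open(os.path.join(report_dir, f'{key}.json'), 'w') as f:
    f.write(report.to_json(indent=2))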

Limitation#

Here we list the limitations of the orbax.export.validate module.

  • The object returned by a TF SavedModel is always a map. If the JAX model output is a sequence, the TF SavedModel will convert it to a map, and the tensor names are fairly generic, such as output_0. To help the ValidationReport module do an apples-to-apples comparison between the JAX model and TF model results, we suggest users return the model output as a dictionary, as sketched below.
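
A minimal sketch of that suggestion, for a model that would otherwise return a tuple (the function and key names below are illustrative):

# Instead of returning a sequence from the model...
def tuple_model_fn(params, inputs):
  return inputs + 1.0, inputs - 1.0  # outputs would get generic names like output_0

# ...post-process the outputs into a dictionary with stable names.
def dict_postprocessor(model_fn_outputs):
  plus_one, minus_one = model_fn_outputs
  return {'plus_one': plus_one, 'minus_one': minus_one}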

Examples#

Check-out the examples directory for a number of examples using Orbax Export.