# Source code for orbax.export.export_manager
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Manage the exporting of a JAXModule."""
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Optional, cast
from etils.epy import reraise_utils
from orbax.export import config
from orbax.export import constants
from orbax.export import jax_module
from orbax.export import obm_export
from orbax.export import serving_config as osc
from orbax.export import tensorflow_export
import tensorflow as tf
# Short module-level alias for `config.config` (from `orbax.export.config`).
# NOTE(review): not referenced in the visible portion of this file --
# presumably kept for other callers; confirm before removing.
obx_export_config = config.config
# Short module-level alias for `reraise_utils.maybe_reraise` (from
# `etils.epy`). NOTE(review): likewise not referenced in the visible code.
maybe_reraise = reraise_utils.maybe_reraise
class ExportManager:
    """Exports a JAXModule with pre- and post-processors.

    This manager acts as a unified interface for exporting JAX modules. It
    handles the underlying serialization logic, dynamically routing to either
    Orbax-native export (`ObmExport`) or TensorFlow SavedModel export
    (`TensorFlowExport`) based on the configuration of the provided module.

    Routing rule: `ObmExport` is used when no module is supplied or when the
    module's `export_version` equals `ExportModelType.ORBAX_MODEL`; in every
    other case `TensorFlowExport` is used.

    Example:
      Configure and export a JAX module using a specific serving
      configuration::

        import tensorflow as tf
        from orbax.export import ExportManager
        from orbax.export import serving_config

        # Assume `my_jax_module` is a fully initialized jax_module.JaxModule

        # Define how the model should handle incoming requests
        my_config = serving_config.ServingConfig(
            signature_key="serving_default",
            input_signature=[
                tf.TensorSpec(shape=(None, 32), dtype=tf.float32)
            ],
        )

        # Initialize the manager
        export_mgr = ExportManager(
            module=my_jax_module,
            serving_configs=[my_config]
        )

        # Save the model to a directory
        export_mgr.save("/path/to/my/saved_model")
    """

    def __init__(
        self,
        module: jax_module.JaxModule | None,
        serving_configs: Sequence[osc.ServingConfig],
    ):
        """ExportManager constructor.

        Args:
          module: The `JaxModule` to be exported. Can be None in specific
            delayed initialization or native Orbax load scenarios; in that
            case the Orbax-native exporter (`ObmExport`) is used.
          serving_configs: a sequence of which each element is a
            `ServingConfig` corresponding to a serving signature of the
            exported SavedModel.
        """
        self._jax_module = module
        # Both exporters take the same (module, serving_configs) arguments;
        # only the class differs.
        if self._uses_orbax_export():
            self._serialization_functions = obm_export.ObmExport(
                self._jax_module, serving_configs
            )
        else:
            self._serialization_functions = tensorflow_export.TensorFlowExport(
                self._jax_module, serving_configs
            )

    def _uses_orbax_export(self) -> bool:
        """Returns True when export is routed to the Orbax-native exporter.

        Single source of truth for the routing rule shared by `__init__` and
        `tf_module`. The module is compared against None explicitly so that a
        module object which merely evaluates as falsy is not mistaken for a
        missing one.
        """
        module = self._jax_module
        return (
            module is None
            or module.export_version == constants.ExportModelType.ORBAX_MODEL
        )

    @property
    def tf_module(self) -> tf.Module:
        """Returns the `tf.Module` maintained by the export manager.

        Raises:
          TypeError: If the export version is `ExportModelType.ORBAX_MODEL` or
            if the module is not provided (as Orbax models do not use
            tf.Module).
        """
        if self._uses_orbax_export():
            raise TypeError(
                'tf_module is not implemented for export version'
                ' ExportModelType.ORBAX_MODEL.'
            )
        # On this path the constructor selected `TensorFlowExport`; the cast
        # only informs static type checkers and has no runtime effect.
        tf_exporter = cast(
            tensorflow_export.TensorFlowExport, self._serialization_functions
        )
        return tf_exporter.tf_export_module()

    @property
    def serving_signatures(self) -> Mapping[str, Callable[..., Any]]:
        """Returns a map of signature keys to serving functions."""
        return self._serialization_functions.serving_signatures

    def save(
        self,
        model_path: str,
        save_options: Optional[tf.saved_model.SaveOptions] = None,
        signature_overrides: Optional[Mapping[str, Callable[..., Any]]] = None,
    ) -> None:
        """Saves the JAX model to a SavedModel.

        The call is forwarded unchanged to the exporter selected in the
        constructor.

        Args:
          model_path: a directory in which to write the SavedModel.
          save_options: an optional tf.saved_model.SaveOptions for configuring
            save options.
          signature_overrides: signatures to override the self-maintained
            ones, or additional signatures to export.
        """
        self._serialization_functions.save(
            model_path=model_path,
            save_options=save_options,
            signature_overrides=signature_overrides,
        )

    def load(self, model_path: str, **kwargs: Any) -> Any:
        """Loads the exported model from disk.

        Args:
          model_path: The directory from which to load the model.
          **kwargs: Additional keyword arguments passed to the underlying
            loader.

        Returns:
          The loaded model instance, exactly as returned by the underlying
          exporter's `load`.
        """
        return self._serialization_functions.load(model_path, **kwargs)