ExportManager

Manages the export of a JaxModule.

class orbax.export.export_manager.ExportManager(module, serving_configs)

Exports a JaxModule with pre- and post-processors.

This manager acts as a unified interface for exporting JAX modules. It handles the underlying serialization logic, dynamically routing to either Orbax-native export (ObmExport) or TensorFlow SavedModel export (TensorFlowExport) based on the configuration of the provided module.
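The routing described above can be pictured as a dispatch on the export type configured on the module. The sketch below is hypothetical: `ExportKind`, `ModuleStub`, `ObmExportStub`, `TensorFlowExportStub`, and `choose_exporter` are illustrative stand-ins, not Orbax APIs. Only the backend names ObmExport and TensorFlowExport come from this page, and the real manager may inspect the module differently.

```python
from dataclasses import dataclass
from enum import Enum, auto


class ExportKind(Enum):
    """Hypothetical stand-in for the export type configured on a module."""

    TF_SAVEDMODEL = auto()
    ORBAX_MODEL = auto()


@dataclass
class ModuleStub:
    """Hypothetical module that only records which export path it requests."""

    export_kind: ExportKind


class TensorFlowExportStub:
    """Placeholder for the TensorFlow SavedModel backend."""

    name = "TensorFlowExport"


class ObmExportStub:
    """Placeholder for the Orbax-native backend."""

    name = "ObmExport"


def choose_exporter(module: ModuleStub):
    """Return a backend object based on the module's configured export type."""
    if module.export_kind is ExportKind.ORBAX_MODEL:
        return ObmExportStub()
    return TensorFlowExportStub()


print(choose_exporter(ModuleStub(ExportKind.ORBAX_MODEL)).name)    # ObmExport
print(choose_exporter(ModuleStub(ExportKind.TF_SAVEDMODEL)).name)  # TensorFlowExport
```

If the selection happens once at construction, later calls such as save and load can delegate to a single backend object without re-inspecting the module each time.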

Example

Configure and export a JAX module using a specific serving configuration:

import tensorflow as tf
from orbax.export import ExportManager
from orbax.export import serving_config

# Assume `my_jax_module` is a fully initialized jax_module.JaxModule
# Define how the model should handle incoming requests
my_config = serving_config.ServingConfig(
    signature_key="serving_default",
    input_signature=[tf.TensorSpec(shape=(None, 32), dtype=tf.float32)],
)

# Initialize the manager
export_mgr = ExportManager(
    module=my_jax_module,
    serving_configs=[my_config]
)

# Save the model to a directory
export_mgr.save("/path/to/my/saved_model")

__init__(module, serving_configs)

ExportManager constructor.

Parameters:
  • module (Optional[JaxModule]) – The JaxModule to export. May be None in certain scenarios, such as delayed initialization or loading a native Orbax model.

  • serving_configs (Sequence[ServingConfig]) – A sequence of ServingConfig objects, each corresponding to one serving signature of the exported SavedModel.

property tf_module: Module

Returns the tf.Module maintained by the export manager.

Raises:

TypeError – If the export version is ExportModelType.ORBAX_MODEL (Orbax models do not use tf.Module) or if no module was provided.

Return type:

Module
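Because tf_module raises TypeError in two situations, callers that handle both export paths may want to guard the access. The following is a minimal sketch, assuming only the behavior documented above; `ExportKind`, `ManagerSketch`, and `tf_module_or_none` are hypothetical names, not part of Orbax.

```python
from enum import Enum, auto
from typing import Optional


class ExportKind(Enum):
    """Hypothetical stand-in for the export type (cf. ExportModelType)."""

    TF_SAVEDMODEL = auto()
    ORBAX_MODEL = auto()


class ManagerSketch:
    """Hypothetical manager that mirrors the documented tf_module errors."""

    def __init__(self, module: Optional[object], kind: ExportKind) -> None:
        self._module = module
        self._kind = kind

    @property
    def tf_module(self) -> object:
        # Orbax-native exports have no tf.Module to return.
        if self._kind is ExportKind.ORBAX_MODEL:
            raise TypeError("Orbax models do not use tf.Module.")
        # A TensorFlow export still needs a module to wrap.
        if self._module is None:
            raise TypeError("No module was provided to the manager.")
        return self._module


def tf_module_or_none(manager: ManagerSketch) -> Optional[object]:
    """Caller-side guard: return the tf module, or None if unavailable."""
    try:
        return manager.tf_module
    except TypeError:
        return None


print(tf_module_or_none(ManagerSketch("my_tf_module", ExportKind.TF_SAVEDMODEL)))  # my_tf_module
print(tf_module_or_none(ManagerSketch("my_tf_module", ExportKind.ORBAX_MODEL)))    # None
```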

property serving_signatures: Mapping[str, Callable[..., Any]]

Returns a map of signature keys to serving functions.

Return type:

Mapping[str, Callable[..., Any]]
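The returned mapping is keyed by signature key (for example "serving_default", as in the example above), and each value is a callable. A minimal sketch of looking up and invoking one entry follows; the `signatures` dict is a hypothetical stand-in for `export_mgr.serving_signatures`, and its plain Python functions over lists replace the real serving functions, which would typically take tensors.

```python
from typing import Any, Callable, Mapping


def call_signature(
    signatures: Mapping[str, Callable[..., Any]], key: str, *args: Any
) -> Any:
    """Look up a serving function by its signature key and invoke it."""
    if key not in signatures:
        raise KeyError(f"unknown signature {key!r}; available: {sorted(signatures)}")
    return signatures[key](*args)


# Hypothetical stand-in for export_mgr.serving_signatures: plain Python
# functions over lists instead of serving functions over tensors.
signatures = {
    "serving_default": lambda xs: [2 * x for x in xs],
    "total": lambda xs: sum(xs),
}

print(sorted(signatures))                                        # ['serving_default', 'total']
print(call_signature(signatures, "serving_default", [1, 2, 3]))  # [2, 4, 6]
```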

save(model_path, save_options=None, signature_overrides=None)

Saves the JAX model to a SavedModel.

Parameters:
  • model_path (str) – The directory in which to write the SavedModel.

  • save_options (Optional[SaveOptions]) – An optional tf.saved_model.SaveOptions object for configuring the save.

  • signature_overrides (Optional[Mapping[str, Callable[..., Any]]]) – Signatures that override those maintained by the manager, or additional signatures to export.
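Assuming signature_overrides is applied like a dictionary update, which its description suggests but this page does not spell out, an override with an existing key replaces the maintained signature and a new key adds an extra one. The sketch below illustrates that assumed behavior; `merge_signatures` and the three small functions are hypothetical, not Orbax internals.

```python
from typing import Any, Callable, Dict, Mapping, Optional

SignatureMap = Mapping[str, Callable[..., Any]]


def merge_signatures(
    maintained: SignatureMap, overrides: Optional[SignatureMap] = None
) -> Dict[str, Callable[..., Any]]:
    """Combine signatures: overrides win on key clashes, new keys are added."""
    merged = dict(maintained)  # copy so the maintained map is left untouched
    if overrides:
        merged.update(overrides)
    return merged


def default_fn(x):
    return x


def patched_fn(x):
    return x + 1


def extra_fn(x):
    return -x


result = merge_signatures(
    {"serving_default": default_fn},
    {"serving_default": patched_fn, "extra": extra_fn},
)
print(sorted(result))                # ['extra', 'serving_default']
print(result["serving_default"](1))  # 2
```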

load(model_path, **kwargs)

Loads the exported model from disk.

Parameters:
  • model_path (str) – The directory from which to load the model.

  • **kwargs – Additional keyword arguments passed to the underlying loader.

Returns:

The loaded model instance.