JaxModule

Wraps JAX functions and parameters into a tf.Module.

class orbax.export.jax_module.JaxModule(params, apply_fn, trainable=None, input_polymorphic_shape=None, input_polymorphic_shape_symbol_values=None, jax2tf_kwargs=None, jit_compile=True, pspecs=None, allow_multi_axis_sharding_consolidation=None, export_version=ExportModelType.TF_SAVEDMODEL, jax2obm_kwargs=None, jax2obm_options=None)

An exportable module for JAX functions and parameters.

Holds tf.Variables converted from JAX parameters, as well as TF functions converted from JAX functions and bound with the tf.Variables. It can be exported to TF SavedModel or Orbax Model format.

Example

Wrap a JAX function and parameters into a JaxModule:

import jax
import jax.numpy as jnp
from orbax.export import JaxModule

# Define parameters and a model application function
params = {'weights': jnp.ones((3, 3))}

def my_apply_fn(params, inputs):
  return jnp.dot(inputs, params['weights'])

# Create the exportable module (Defaults to TF_SAVEDMODEL format)
jax_module = JaxModule(
    params=params,
    apply_fn=my_apply_fn
)

__init__(params, apply_fn, trainable=None, input_polymorphic_shape=None, input_polymorphic_shape_symbol_values=None, jax2tf_kwargs=None, jit_compile=True, pspecs=None, allow_multi_axis_sharding_consolidation=None, export_version=ExportModelType.TF_SAVEDMODEL, jax2obm_kwargs=None, jax2obm_options=None)

JaxModule constructor.

Parameters:
  • params (PyTree) – a pytree of JAX parameters or parameter specs (e.g. `jax.ShapeDtypeStruct`s).

  • apply_fn (Union[Callable[[PyTree, PyTree], PyTree], ApplyFnInfo, Mapping[str, Callable[[PyTree, PyTree], PyTree]], Mapping[str, ApplyFnInfo]]) – A single ApplyFn (taking model_params and model_inputs), a single ApplyFnInfo object (containing an ApplyFn and input/output keys), or a mapping of method keys to ApplyFns or ApplyFnInfo objects. A single ApplyFn or ApplyFnInfo is automatically assigned the key constants.DEFAULT_METHOD_KEY, which can be used to look up the TF function converted from it.

  • trainable (Union[bool, PyTree, None]) – a pytree with the same structure as params and boolean leaves indicating whether each parameter is trainable. Alternatively, a single boolean value indicating whether all parameters are trainable. By default all parameters are non-trainable. The default is subject to change, so it is recommended to specify the value explicitly. Currently trainable is only relevant for TF SavedModel export.

  • input_polymorphic_shape (Union[PyTree, Mapping[str, PyTree], None]) – the polymorphic shape for the inputs of apply_fn. If apply_fn is a mapping, input_polymorphic_shape must be a mapping of method key to the input polymorphic shape for the method.

  • input_polymorphic_shape_symbol_values (Union[PyTree, Mapping[str, PyTree], None]) – optional mapping of symbol names present in input_polymorphic_shape to their possible values (e.g. {'batch_size': (1, 2), 'seq_len': (128, 512)}). When there are multiple apply_fns in the form of a flat mapping, this argument must be a flat mapping with the same keys (e.g. {'serving_default': {'batch_size': (1, 2), 'seq_len': (128, 512)}}). When this argument is set, the polymorphic shape is concretized to the set of all possible concrete input shape combinations. This is only relevant for export model type constants.ExportModelType.ORBAX_MODEL.

  • jax2tf_kwargs (Optional[Mapping[str, Any]]) – options passed to jax2tf. polymorphic_shape is inferred from input_polymorphic_shape and should not be set. with_gradient, if set, should be consistent with the trainable argument above. If jax2tf_kwargs is unspecified, the default jax2tf options are applied. If apply_fn is a mapping, jax2tf_kwargs must either be unspecified or a mapping of method key to the jax2tf kwargs for the method. jax2tf_kwargs is only relevant for TF SavedModel export.

  • jit_compile (Union[bool, Mapping[str, bool]]) – whether to jit compile the jax2tf converted functions. If apply_fn is a mapping, this can either be a boolean applied to all functions or a mapping of method key to the jit compile option for the method. jit_compile is only relevant for TF SavedModel export, since all methods are jit compiled for Orbax Model export.

  • pspecs (Optional[PyTree]) – an optional pytree of PartitionSpecs for the params, in the same structure as params. If set, the leaves of params must be jax.Array, and the JaxModule must be created within a DTensor export context, i.e. inside a `with maybe_enable_dtensor_export_on(mesh)` block. DTensor export is only supported for TF SavedModel export.

  • allow_multi_axis_sharding_consolidation (Optional[bool]) – Disallowed by default. When set to true, allows consolidating a JAX array's multi-axis sharding into DTensor single-axis sharding during checkpoint conversion. This enables support for JAX models sharded across multiple axis names. This is only relevant for TF SavedModel export.

  • export_version (ExportModelType) – The model export version. Either TF_SAVEDMODEL or ORBAX_MODEL.

  • jax2obm_kwargs (Optional[Mapping[str, Any]]) – DEPRECATED: use jax2obm_options instead. Options passed to the Orbax Model export. Accepted arguments are 'native_serialization_platforms', which must be a tuple of OrbaxNativeSerializationType.

  • jax2obm_options (Union[Jax2ObmOptions, Mapping[str, Jax2ObmOptions], None]) – Options for jax2obm conversion. If apply_fn is a mapping, this can also be a mapping from method keys to Jax2ObmOptions. Currently, when it is a mapping, most options must be shared across the different apply functions, except for enable_auto_layout and native_serialization_disabled_checks.
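
The concretization described for input_polymorphic_shape_symbol_values can be pictured as a Cartesian product over the listed symbol values. The following is a minimal sketch in plain Python, not Orbax's implementation: the helper names enumerate_symbol_bindings and concretize, and the shape string "batch_size, seq_len", are hypothetical and exist only for illustration.

```python
import itertools


def enumerate_symbol_bindings(symbol_values):
    """Yields one {symbol: value} dict per combination of symbol values."""
    names = list(symbol_values)
    for combo in itertools.product(*(symbol_values[n] for n in names)):
        yield dict(zip(names, combo))


def concretize(shape_spec, binding):
    """Substitutes bound symbols into a comma-separated shape string."""
    dims = [d.strip() for d in shape_spec.split(",")]
    # Symbolic dimensions are looked up; anything else is read as an integer.
    return tuple(binding[d] if d in binding else int(d) for d in dims)


# Hypothetical symbol values, mirroring the example in the parameter text.
symbol_values = {"batch_size": (1, 2), "seq_len": (128, 512)}
bindings = list(enumerate_symbol_bindings(symbol_values))
shapes = [concretize("batch_size, seq_len", b) for b in bindings]
```

With two candidate values for each of the two symbols, the sketch produces four concrete input shapes, one per combination.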

Raises:

ValueError – If jax2obm_kwargs and jax2obm_options are both provided, or if input_polymorphic_shape_symbol_values or ApplyFnInfo are provided but export_version is not constants.ExportModelType.ORBAX_MODEL, or if export_version is not supported.
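
The trainable argument of the constructor accepts either a single boolean or a tree of booleans shaped like params. The sketch below shows how the two forms relate, using nested dicts with placeholder strings standing in for arrays. It is a dependency-free illustration, not Orbax code, and broadcast_trainable is a hypothetical helper; with JAX installed, jax.tree_util.tree_map would be the usual way to build such a tree.

```python
def broadcast_trainable(params, trainable):
    """Returns a tree of booleans with the same nested-dict structure as params."""
    if not isinstance(trainable, bool):
        return trainable  # Already a per-parameter tree; use it as given.
    if isinstance(params, dict):
        return {k: broadcast_trainable(v, trainable) for k, v in params.items()}
    return trainable  # Leaf: one flag per parameter.


# Placeholder strings stand in for JAX arrays (assumption for illustration).
params = {"dense": {"kernel": "array", "bias": "array"}, "embedding": "array"}

# A single boolean applies to every parameter.
all_frozen = broadcast_trainable(params, False)

# A per-parameter tree marks only some parameters as trainable.
per_param = {"dense": {"kernel": True, "bias": True}, "embedding": False}
selected = broadcast_trainable(params, per_param)
```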

property apply_fn_map: Mapping[str, Callable[[PyTree, PyTree], PyTree] | ApplyFnInfo]

Returns a mapping from method keys to ApplyFn or ApplyFnInfo objects.

Each value in the mapping is either an ApplyFn (a callable that takes model parameters and inputs) or an ApplyFnInfo object. ApplyFnInfo wraps an ApplyFn along with its input and output keys, and is used for specifying preprocessing/postprocessing dependencies when exporting to constants.ExportModelType.ORBAX_MODEL format.

If a single ApplyFn or ApplyFnInfo was provided during initialization, it is keyed by constants.DEFAULT_METHOD_KEY.

Return type:

Mapping[str, Union[Callable[[PyTree, PyTree], PyTree], ApplyFnInfo]]
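
The keying rule for a single function versus a mapping can be sketched as follows. This is an illustration under stated assumptions rather than the library's code: to_apply_fn_map is a hypothetical helper, ApplyFnInfo is left out for brevity, and the DEFAULT_METHOD_KEY string below is a placeholder, since the real value is defined in the Orbax constants module.

```python
from typing import Any, Callable, Mapping, Union

# Placeholder value (assumption); the real key is constants.DEFAULT_METHOD_KEY.
DEFAULT_METHOD_KEY = "default_method_placeholder"

ApplyFn = Callable[[Any, Any], Any]


def to_apply_fn_map(
    apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]],
) -> Mapping[str, ApplyFn]:
    """Wraps a single callable under the default key; copies a mapping as is."""
    if callable(apply_fn):
        return {DEFAULT_METHOD_KEY: apply_fn}
    return dict(apply_fn)


def scale(params, inputs):
    """A toy ApplyFn taking model parameters and model inputs."""
    return params["w"] * inputs


single = to_apply_fn_map(scale)
multi = to_apply_fn_map({"serve": scale, "score": scale})
```

After this normalization, callers can always look up a method by key, whether one function or several were supplied.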

property model_params: PyTree

Returns the model parameters.

Return type:

PyTree

property model_param_names: Sequence[str]

Returns the list of model parameter names.

The name format matches the one used by JSV to look up parameters in the checkpoint (e.g. “params.key.subkey”).

Return type:

Sequence[str]
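
To make the dotted name format concrete, the sketch below flattens a nested dict into names such as "params.key.subkey". It only illustrates the format shown in the example above: flatten_param_names is a hypothetical helper, the "params" root prefix is taken from that example, and the ordering here follows dict insertion order, which may differ from what the library returns.

```python
def flatten_param_names(tree, prefix="params"):
    """Returns dotted names for every leaf of a nested-dict parameter tree."""
    names = []
    if isinstance(tree, dict):
        for key, value in tree.items():
            names.extend(flatten_param_names(value, f"{prefix}.{key}"))
    else:
        names.append(prefix)  # Reached a leaf: the accumulated path is its name.
    return names


# Placeholder strings stand in for JAX arrays (assumption for illustration).
params = {"dense": {"kernel": "array", "bias": "array"}, "embedding": "array"}
names = flatten_param_names(params)
```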

property export_version: ExportModelType

Returns the export version.

Return type:

ExportModelType

export_module()

Returns the export module.

Return type:

OrbaxModuleBase

property jax2tf_kwargs_map: Mapping[str, Any]

Returns the mapping from method key to the jax2tf kwargs for the method.

Return type:

Mapping[str, Any]

property input_polymorphic_shape_map: Mapping[str, PyTree]

Returns the mapping from method key to the input polymorphic shape for the method.

Return type:

Mapping[str, PyTree]

property with_gradient: bool

Returns the with_gradient value.

Return type:

bool

update_variables(params)

Updates the variables associated with self.

Parameters:

params (PyTree) – A PyTree of JAX parameters. The PyTree structure must be the same as that of the params used to initialize the model. Additionally, the shape and dtype of each parameter must be the same as the original parameter.
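
The compatibility requirement on update_variables (same tree structure, same shape and dtype per leaf) can be expressed as a small check. The sketch below is an assumption-laden illustration, not the library's validation: LeafSpec and check_compatible are hypothetical names, leaves are plain shape/dtype records instead of arrays, and raising ValueError on mismatch is an assumption, as the text above does not say how a mismatch is reported.

```python
from typing import NamedTuple, Tuple


class LeafSpec(NamedTuple):
    """Stand-in for an array leaf: only shape and dtype matter here."""
    shape: Tuple[int, ...]
    dtype: str


def check_compatible(old, new, path="params"):
    """Raises ValueError if new differs from old in structure, shape, or dtype."""
    if isinstance(old, dict) != isinstance(new, dict):
        raise ValueError(f"{path}: structure mismatch")
    if isinstance(old, dict):
        if set(old) != set(new):
            raise ValueError(f"{path}: keys differ: {sorted(old)} vs {sorted(new)}")
        for key in old:
            check_compatible(old[key], new[key], f"{path}.{key}")
        return
    if old.shape != new.shape or old.dtype != new.dtype:
        raise ValueError(f"{path}: expected {old}, got {new}")


original = {"weights": LeafSpec((3, 3), "float32")}
same_spec = {"weights": LeafSpec((3, 3), "float32")}
wrong_shape = {"weights": LeafSpec((4, 3), "float32")}

check_compatible(original, same_spec)  # Passes silently.
```

Running the same check against wrong_shape raises ValueError, because the leaf shape (4, 3) no longer matches the original (3, 3).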

orbax_module()

Returns the OrbaxModule associated with this JaxModule.

Return type:

OrbaxModuleBase

property methods: Mapping[str, Callable[[...], Any]]

Named methods in TF context.

Return type:

Mapping[str, Callable[…, Any]]

property jax_methods: Mapping[str, Callable[[...], Any]]

Named methods in JAX context for validation.

Return type:

Mapping[str, Callable[…, Any]]

obm_module_to_jax_exported_map(model_inputs)

Converts the orbax.export JaxModule to jax_export.Exported.

Parameters:

model_inputs (PyTree) – The model inputs.

Return type:

Mapping[str, Exported]

Returns:

A mapping from method key to jax_export.Exported.

property save_shlo_to_file: bool

Returns True if StableHLO should be saved as an external file.

Return type:

bool