JaxModule

Wraps JAX functions and parameters into a tf.Module.

class orbax.export.jax_module.JaxModule(params, apply_fn, trainable=None, input_polymorphic_shape=None, jax2tf_kwargs=None, jit_compile=True, name=None, pspecs=None, allow_multi_axis_sharding_conslidation=None)

An exportable module for JAX functions and parameters.

Holds tf.Variables converted from JAX parameters, as well as TF functions converted from JAX functions and bound with the tf.Variables.

__init__(params, apply_fn, trainable=None, input_polymorphic_shape=None, jax2tf_kwargs=None, jit_compile=True, name=None, pspecs=None, allow_multi_axis_sharding_conslidation=None)

JaxModule constructor.

Parameters:
  • params (Any) – a pytree of JAX parameters.

  • apply_fn (Union[Callable[[Any, Any], Any], Mapping[str, Callable[[Any, Any], Any]]]) – A JAX ApplyFn (i.e. of signature apply_fn(params, x)), or a mapping of method key to ApplyFn. If it is an ApplyFn, it will be assigned a key JaxModule.DEFAULT_METHOD_KEY automatically, which can be used to look up the TF function converted from it.

  • trainable (Union[bool, Any, None]) – a pytree with the same structure as params whose boolean leaves indicate whether each parameter is trainable. Alternatively, a single boolean that applies to all parameters. By default, all parameters are non-trainable. This default may change in the future, so it is recommended to specify the value explicitly.

  • input_polymorphic_shape (Union[Any, Mapping[str, Any], None]) – the polymorphic shape for the inputs of apply_fn. If apply_fn is a mapping, input_polymorphic_shape must be a mapping of method key to the input polymorphic shape for the method.

  • jax2tf_kwargs (Optional[Mapping[str, Any]]) – options passed to jax2tf. polymorphic_shape is inferred from input_polymorphic_shape and should not be set. with_gradient, if set, should be consistent with the trainable argument above. If jax2tf_kwargs is unspecified, the default jax2tf options are applied. If apply_fn is a mapping, jax2tf_kwargs must either be unspecified or a mapping of method key to the jax2tf kwargs for the method.

  • jit_compile (Union[bool, Mapping[str, bool]]) – whether to jit compile the jax2tf converted functions. If apply_fn is a mapping, this can either be a boolean applied to all functions or a mapping of method key to the jit compile option for the method.

  • name (Optional[str]) – the name of the module.

  • pspecs (Optional[Any]) – an optional pytree of PartitionSpecs for the params, with the same structure as params. If set, the leaves of params must be jax.Array, and the JaxModule must be created within a DTensor export context, i.e. inside a with maybe_enable_dtensor_export_on(mesh) block.

  • allow_multi_axis_sharding_conslidation (Optional[bool]) – disallowed by default. When set to True, allows consolidating a JAX array's multi-axis sharding into DTensor single-axis sharding during checkpoint conversion. This enables support for JAX models sharded across multiple axis names.
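The constructor arguments above can be illustrated with a minimal sketch. It assumes the jax, tensorflow, and orbax-export packages are installed; the toy linear model, the helper names build_linear_module and demo, and the parameter shapes are hypothetical and chosen only for illustration. The batch-polymorphic shape string "b, ..." is one possible choice, not a requirement.

```python
def build_linear_module():
    """Wraps a toy linear layer (hypothetical model) in a JaxModule."""
    # Imports are deferred so that defining this helper does not itself
    # require JAX, TensorFlow, or Orbax to be importable.
    import jax.numpy as jnp
    from orbax.export import JaxModule

    # A pytree (here a plain dict) of JAX parameters.
    params = {
        "w": jnp.ones((3, 2), dtype=jnp.float32),
        "b": jnp.zeros((2,), dtype=jnp.float32),
    }

    def apply_fn(params, x):
        # ApplyFn signature described above: apply_fn(params, x).
        return x @ params["w"] + params["b"]

    return JaxModule(
        params,
        apply_fn,
        trainable=False,                   # stated explicitly, as recommended
        input_polymorphic_shape="b, ...",  # symbolic batch dimension
    )


def demo():
    """Looks up the converted TF function and calls it on a dummy batch."""
    import numpy as np
    from orbax.export import JaxModule

    module = build_linear_module()
    # A single ApplyFn is registered under the default method key.
    tf_fn = module.methods[JaxModule.DEFAULT_METHOD_KEY]
    return tf_fn(np.ones((4, 3), dtype=np.float32))
```

Calling demo() returns a tf.Tensor computed by the converted function; because the batch dimension is declared polymorphic, inputs with other batch sizes should be accepted as well.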

update_variables(params)

Updates the variables associated with self.

Parameters:
  • params (Any) – A PyTree of JAX parameters. The PyTree structure must be the same as that of the params used to initialize the model. Additionally, the shape and dtype of each parameter must be the same as the original parameter.
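As a rough sketch of how update_variables might be used, the example below swaps new values into an existing module and calls the same TF function before and after. It assumes jax, tensorflow, and orbax-export are installed; the toy model and the helper name swap_in_new_weights are hypothetical. The new pytree is derived from the original with jax.tree_util.tree_map so that structure, shapes, and dtypes match.

```python
def swap_in_new_weights():
    """Returns outputs of the TF function before and after a variable update."""
    # Deferred imports: defining this helper does not require the libraries.
    import jax
    import jax.numpy as jnp
    import numpy as np
    from orbax.export import JaxModule

    def apply_fn(params, x):
        return x @ params["w"] + params["b"]

    params = {
        "w": jnp.ones((3, 2), dtype=jnp.float32),
        "b": jnp.zeros((2,), dtype=jnp.float32),
    }
    module = JaxModule(
        params, apply_fn, trainable=False, input_polymorphic_shape="b, ..."
    )
    tf_fn = module.methods[JaxModule.DEFAULT_METHOD_KEY]

    x = np.ones((4, 3), dtype=np.float32)
    before = tf_fn(x).numpy()

    # Same tree structure, shapes, and dtypes as `params`; only values change.
    new_params = jax.tree_util.tree_map(lambda p: p * 2.0 + 1.0, params)
    module.update_variables(new_params)

    after = tf_fn(x).numpy()
    return before, after
```

A pytree with a different structure, or with leaves of a different shape or dtype, does not satisfy the requirements stated above and should not be passed to update_variables.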

property methods: Mapping[str, Callable[[...], Any]]

Named methods in TF context.

Return type:

Mapping[str, Callable[…, Any]]

property jax_methods: Mapping[str, Callable[[...], Any]]

Named methods in JAX context for validation.

Return type:

Mapping[str, Callable[…, Any]]
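One way the two properties could be used together is to check that each converted TF function agrees numerically with its JAX counterpart. The sketch below assumes jax, tensorflow, and orbax-export are installed, and it assumes that the callables in both mappings take only the model inputs (parameters already bound); that assumption should be checked against the installed Orbax version. The method keys "predict" and "features" and the helper name compare_tf_and_jax_methods are hypothetical.

```python
def compare_tf_and_jax_methods():
    """Returns the max absolute TF-vs-JAX output difference per method key."""
    # Deferred imports: defining this helper does not require the libraries.
    import jax.numpy as jnp
    import numpy as np
    from orbax.export import JaxModule

    params = {
        "w": jnp.ones((3, 2), dtype=jnp.float32),
        "b": jnp.zeros((2,), dtype=jnp.float32),
    }

    def predict(params, x):
        return x @ params["w"] + params["b"]

    def features(params, x):
        return jnp.tanh(x @ params["w"])

    # With a mapping of ApplyFns, input_polymorphic_shape is also a mapping
    # keyed by the same method keys.
    module = JaxModule(
        params,
        {"predict": predict, "features": features},
        trainable=False,
        input_polymorphic_shape={"predict": "b, ...", "features": "b, ..."},
    )

    x = np.ones((4, 3), dtype=np.float32)
    max_abs_diff = {}
    for key, tf_fn in module.methods.items():
        tf_out = tf_fn(x).numpy()
        # Assumption: jax_methods callables accept just the inputs.
        jax_out = np.asarray(module.jax_methods[key](x))
        max_abs_diff[key] = float(np.max(np.abs(tf_out - jax_out)))
    return max_abs_diff
```

Small nonzero differences can be expected from floating-point effects, so a tolerance-based comparison is usually more appropriate than exact equality.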