JaxModule
Wraps JAX functions and parameters into a tf.Module.
- class orbax.export.jax_module.JaxModule(params, apply_fn, trainable=None, input_polymorphic_shape=None, jax2tf_kwargs=None, jit_compile=True, name=None, pspecs=None, allow_multi_axis_sharding_conslidation=None)
An exportable module for JAX functions and parameters.
Holds tf.Variables converted from JAX parameters, as well as TF functions converted from JAX functions and bound with the tf.Variables.
- __init__(params, apply_fn, trainable=None, input_polymorphic_shape=None, jax2tf_kwargs=None, jit_compile=True, name=None, pspecs=None, allow_multi_axis_sharding_conslidation=None)
JaxModule constructor.
- Parameters:
  - params (Any) – A pytree of JAX parameters.
  - apply_fn (Union[Callable[[Any, Any], Any], Mapping[str, Callable[[Any, Any], Any]]]) – A JAX ApplyFn (i.e. a function with signature apply_fn(params, x)), or a mapping of method key to ApplyFn. If it is a single ApplyFn, it is automatically assigned the key JaxModule.DEFAULT_METHOD_KEY, which can be used to look up the TF function converted from it.
  - trainable (Union[bool, Any, None]) – A pytree with the same structure as params and boolean leaves indicating whether each parameter is trainable. Alternatively, a single boolean that applies to all parameters. By default all parameters are non-trainable. This default may change in the future, so specifying the value explicitly is recommended.
  - input_polymorphic_shape (Union[Any, Mapping[str, Any], None]) – The polymorphic shape for the inputs of apply_fn. If apply_fn is a mapping, input_polymorphic_shape must be a mapping of method key to the input polymorphic shape for that method.
  - jax2tf_kwargs (Optional[Mapping[str, Any]]) – Options passed to jax2tf. The polymorphic_shape option is inferred from input_polymorphic_shape and should not be set. The with_gradient option, if set, should be consistent with the trainable argument above. If jax2tf_kwargs is unspecified, the default jax2tf options are applied. If apply_fn is a mapping, jax2tf_kwargs must either be unspecified or be a mapping of method key to the jax2tf kwargs for that method.
  - jit_compile (Union[bool, Mapping[str, bool]]) – Whether to JIT-compile the jax2tf-converted functions. If apply_fn is a mapping, this can be either a single boolean applied to all functions or a mapping of method key to the JIT-compile option for that method.
  - name (Optional[str]) – The name of the module.
  - pspecs (Optional[Any]) – An optional pytree of PartitionSpecs for params, with the same structure as params. If set, the leaves of params must be jax.Array, and the JaxModule must be created within a DTensor export context, i.e. inside with maybe_enable_dtensor_export_on(mesh).
  - allow_multi_axis_sharding_conslidation (Optional[bool]) – Disallowed by default. When set to True, allows consolidating a JAX array's sharding over multiple axes into a DTensor single-axis sharding during checkpoint conversion. This enables support for JAX models that are sharded across multiple axis names.
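The constructor arguments described above can be prepared along the following lines. This is a minimal sketch rather than an excerpt from the Orbax documentation: the parameter names w and b, the method keys predict and embed, and the shape string "b, ..." are illustrative assumptions (the string is assumed to follow jax2tf's shape-polymorphism syntax). The JaxModule construction is placed in a function with a lazy import, so everything else runs with JAX alone.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters: a pytree (here a flat dict) of JAX arrays.
params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}


def predict(params, x):
    """An ApplyFn: takes (params, x) and returns the model output."""
    return x @ params["w"] + params["b"]


def embed(params, x):
    """A second ApplyFn that shares the same parameters."""
    return x @ params["w"]


# A mapping of method key to ApplyFn exposes several methods at once.
apply_fns = {"predict": predict, "embed": embed}

# `trainable` may be a single bool or a pytree of bools shaped like `params`.
# Here only the weight matrix is marked trainable (an illustrative choice).
trainable = {"w": True, "b": False}

# With a mapping of ApplyFns, polymorphic shapes are keyed the same way.
# "b, ..." is assumed to leave the leading (batch) dimension symbolic.
input_polymorphic_shape = {"predict": "b, ...", "embed": "b, ..."}


def build_jax_module():
    """Constructs the module; requires orbax-export and TensorFlow."""
    from orbax.export import JaxModule  # imported lazily on purpose

    return JaxModule(
        params,
        apply_fns,
        trainable=trainable,
        input_polymorphic_shape=input_polymorphic_shape,
    )
```

Passing trainable explicitly follows the recommendation above, since the default may change. When only one function is exported, apply_fn can be the function itself and input_polymorphic_shape a single value instead of a mapping.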
- update_variables(params)
Updates the variables associated with self.
- Parameters:
  - params (Any) – A PyTree of JAX parameters. The PyTree structure must be the same as that of the params used to initialize the model. Additionally, the shape and dtype of each parameter must be the same as those of the original parameter.
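The compatibility requirement above (same PyTree structure, same shape and dtype per leaf) can be checked before calling update_variables. The helper below is a hypothetical sketch and not part of Orbax; it assumes every leaf is an array exposing shape and dtype attributes.

```python
import jax
import jax.numpy as jnp


def check_compatible(old_params, new_params):
    """Raises ValueError unless new_params matches old_params in tree
    structure and in the shape and dtype of every leaf."""
    old_def = jax.tree_util.tree_structure(old_params)
    new_def = jax.tree_util.tree_structure(new_params)
    if old_def != new_def:
        raise ValueError(f"Structure mismatch: {old_def} vs {new_def}")

    old_leaves = jax.tree_util.tree_leaves(old_params)
    new_leaves = jax.tree_util.tree_leaves(new_params)
    for index, (old, new) in enumerate(zip(old_leaves, new_leaves)):
        if old.shape != new.shape or old.dtype != new.dtype:
            raise ValueError(
                f"Leaf {index} mismatch: expected {old.shape} {old.dtype}, "
                f"got {new.shape} {new.dtype}"
            )
```

A typical use would be check_compatible(params, new_params) followed by jax_module.update_variables(new_params), where jax_module is an existing JaxModule instance built from params.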
- property methods: Mapping[str, Callable[[...], Any]]
Named methods in TF context.
- Return type:
  Mapping[str, Callable[..., Any]]
- property jax_methods: Mapping[str, Callable[[...], Any]]
Named methods in JAX context for validation.
- Return type:
  Mapping[str, Callable[..., Any]]