General Utilities#

Utilities for Orbax export.

TensorSpecWithDefault#

class orbax.export.utils.TensorSpecWithDefault(tensor_spec, default_val, is_primary=False)#

Extends tf.TensorSpec to hold a default value.

Constraints due to Python function calling conventions:
  • For a given Python function parameter, either every corresponding tensor in the signature must be a TensorSpecWithDefault, or none of them may be.

  • Parameters with default values must be ordered after those without defaults.
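
As a point of reference for the ordering constraint, Python's own rule for required and defaulted parameters can be checked with the standard library. The sketch below uses only the `inspect` module, not Orbax or TensorFlow, and the helper `build_signature` and the parameter names `features` and `temperature` are made up for illustration.

```python
import inspect

P = inspect.Parameter


def build_signature(names_and_defaults):
    """Builds an inspect.Signature from (name, default) pairs.

    Use inspect.Parameter.empty as the default for a required parameter.
    """
    params = [
        P(name, P.POSITIONAL_OR_KEYWORD, default=default)
        for name, default in names_and_defaults
    ]
    return inspect.Signature(params)


# Required parameter first, defaulted parameter second: accepted.
ok = build_signature([("features", P.empty), ("temperature", 1.0)])

# Defaulted parameter first, required parameter second: rejected by Python.
try:
    build_signature([("temperature", 1.0), ("features", P.empty)])
    rejected = False
except ValueError:
    rejected = True
```

Here `ok` is a valid signature, while the second call raises `ValueError`, so `rejected` ends up `True`.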

__eq__(other)#

Return self==value.

__hash__ = None#
__init__(tensor_spec, default_val, is_primary=False)#

CallableSignatures#

class orbax.export.utils.CallableSignatures(sess, signature_defs)#

Holds TF SignatureDefs as Python callables.

__init__(sess, signature_defs)#
classmethod from_saved_model(model_dir, tags, sess_config=None)#

Loads a SavedModel and reconstructs its signatures as Python callables.

The signatures of the object loaded by the tf.saved_model.load API don’t support default values, so this class can be used to load the model in TF1 and reconstruct the signatures. Example:

>>> loaded = CallableSignatures.from_saved_model(model_dir, ['serve'])
>>> outputs = loaded.signatures['serving_default'](**inputs)

The TF2 version of this example is

>>> loaded_tf2 = tf.saved_model.load(model_dir, ['serve'])
>>> outputs = loaded_tf2.signatures['serving_default'](**inputs)

However, the callables in loaded_tf2.signatures don’t have any default inputs.

Parameters:
  • model_dir (str) – SavedModel directory.

  • tags (list[str]) – Tags to identify the metagraph to load. Same as the tags argument in tf.saved_model.load.

  • sess_config (Optional[Any]) – (Optional.) A ConfigProto protocol buffer with configuration options for the session.

Returns:

A CallableSignatures object whose signatures property maps signature names to the callables.

property signatures#

Returns a mapping from signature names to Python callables.

Utility functions#

orbax.export.utils.with_default_args(tf_fn, input_signature)#

Creates a TF function with default args specified in input_signature.

Parameters:
  • tf_fn (Callable[…, Any]) – the TF function.

  • input_signature (Sequence[Any]) – the input signature. Every leaf is a tf.TensorSpec, or an orbax.export.TensorSpecWithDefault if a default value is specified.

Return type:

PolymorphicFunction

Returns:

A tf function with default arguments.
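
The general idea of giving a callable default arguments from a separate specification can be sketched in plain Python. This is only a framework-free illustration under stated assumptions, not the Orbax implementation: `with_defaults_sketch`, `scale`, and their parameters are hypothetical names, and ordinary Python values stand in for tensors and tensor specs.

```python
import inspect


def with_defaults_sketch(fn, names, defaults):
    """Wraps `fn` so that parameters listed in `defaults` become optional.

    Hypothetical helper for illustration only. `names` is the ordered list of
    parameter names and `defaults` maps a subset of them to default values.
    """
    params = [
        inspect.Parameter(
            name,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=defaults.get(name, inspect.Parameter.empty),
        )
        for name in names
    ]
    # Raises ValueError if a required parameter follows a defaulted one.
    signature = inspect.Signature(params)

    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()  # Fill in any omitted defaulted parameters.
        return fn(*bound.args, **bound.kwargs)

    return wrapper


def scale(x, factor):
    """Multiplies every element of the list `x` by `factor`."""
    return [v * factor for v in x]


# `factor` becomes optional with a default of 2.
scale_with_default = with_defaults_sketch(scale, ["x", "factor"], {"factor": 2})
```

Calling `scale_with_default([1, 2])` uses the default factor, while `scale_with_default([1, 2], factor=3)` overrides it.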

orbax.export.utils.make_auto_batching_function(input_signature)#

Creates an auto-batching function from input signature.

An auto-batching function is a function whose input tensors can have either a batch size of “b” or 1, and whose output tensors have a batch size of “b”, where “b” is the batch size of the primary input tensors.

Requirements:
  • All input tensors must have a leading batch dimension.

  • There must be at least one primary tensor. A primary tensor is a tensor whose tensor spec is either a tf.TensorSpec or a TensorSpecWithDefault whose is_primary attribute is True.

  • All primary tensors must have the same batch size.

  • All non-primary tensors must have a batch size of 1, or the same as the primary batch size.

Example:

>>> input_signature = (
...     tf.TensorSpec([None], tf.int32, name='primary'),
...     TensorSpecWithDefault(
...         tf.TensorSpec([None], tf.int32, name='optional'), [1]
...     ),
... )
>>> batching_fn = utils.make_auto_batching_function(input_signature)
>>> batching_fn(tf.constant([0, 0]), tf.constant([1]))
(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 0], dtype=int32)>,
 <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)>)

Parameters:

input_signature (Sequence[Any]) – a sequence of PyTrees whose leaves are tf.TensorSpec or TensorSpecWithDefault.

Return type:

Callable[…, Any]

Returns:

A TF function whose output tensors all have the same batch size.
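
The batch-size rules described for auto-batching can be illustrated without TensorFlow. The following is a minimal sketch under stated assumptions, not the library's implementation: plain Python lists stand in for tensors, the list length plays the role of the leading batch dimension, and `auto_batch_sketch` and `is_primary` are hypothetical names.

```python
def auto_batch_sketch(tensors, is_primary):
    """Replicates batch-1 inputs so every output has the primary batch size.

    `tensors` is a list of Python lists whose length is the batch dimension;
    `is_primary` is a parallel list of booleans. Both names are hypothetical.
    """
    primary_sizes = {len(t) for t, p in zip(tensors, is_primary) if p}
    if not primary_sizes:
        raise ValueError("At least one primary tensor is required.")
    if len(primary_sizes) != 1:
        raise ValueError("All primary tensors must share one batch size.")
    (batch_size,) = primary_sizes

    outputs = []
    for tensor, primary in zip(tensors, is_primary):
        if len(tensor) == batch_size:
            # Already at the primary batch size: pass through unchanged.
            outputs.append(list(tensor))
        elif not primary and len(tensor) == 1:
            # Repeat the single row to match the primary batch size.
            outputs.append(list(tensor) * batch_size)
        else:
            raise ValueError(
                "Non-primary batch size must be 1 or the primary batch size."
            )
    return tuple(outputs)


# Mirrors the documented example: a primary input of batch size 2 and an
# optional input of batch size 1.
batched = auto_batch_sketch([[0, 0], [1]], is_primary=[True, False])
```

With these inputs, the optional value is repeated so that both outputs have a batch size of 2.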