General Utilities

Utilities for Orbax export.

TensorSpecWithDefault

class orbax.export.utils.TensorSpecWithDefault(tensor_spec, default_val, is_primary=False)

Extends tf.TensorSpec to hold a default value.

This class associates a specific default value with a TensorFlow tensor specification. It automatically converts the provided default value into a tf.Tensor and validates that it strictly matches the shape and dtype of the provided tensor_spec.

Constraints due to Python function calling conventions:
  • For each Python function parameter, either all of its corresponding tensor specs in the signature are TensorSpecWithDefault instances, or none of them are.

  • Parameters with default values must be ordered after parameters without default values.
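The ordering constraint is the same one Python applies to any function definition. The standard-library sketch below (no Orbax or TensorFlow involved) shows it: inspect.Signature accepts a required parameter followed by a defaulted one, and rejects the reverse order. The helper build_signature is hypothetical and only illustrates the rule.

```python
import inspect

P = inspect.Parameter


def build_signature(names_and_defaults):
    """Hypothetical helper: build a signature from (name, default) pairs.

    A default of inspect.Parameter.empty marks a required parameter.
    """
    params = [
        P(name, P.POSITIONAL_OR_KEYWORD, default=default)
        for name, default in names_and_defaults
    ]
    # inspect.Signature checks parameter order and raises ValueError when a
    # required parameter follows a parameter that has a default.
    return inspect.Signature(params)


# Required x first, defaulted y second: accepted.
ok = build_signature([("x", P.empty), ("y", [1.0, 1.0])])
print(ok)  # (x, y=[1.0, 1.0])

# Defaulted y first, required x second: rejected.
try:
    build_signature([("y", [1.0, 1.0]), ("x", P.empty)])
except ValueError as err:
    print("rejected:", err)
```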

Example

Create a TensorSpec with a default fallback value:

import tensorflow as tf
from orbax.export.utils import TensorSpecWithDefault

# Define a spec for a 1D tensor
spec = tf.TensorSpec(shape=(2,), dtype=tf.float32, name="input_a")

# Create the extended spec with a default list
# The list is automatically converted to a tf.Tensor upon initialization
default_spec = TensorSpecWithDefault(
    tensor_spec=spec,
    default_val=[1.0, 2.0]
)
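To make the convert-then-validate behaviour concrete without depending on TensorFlow, here is a minimal pure-Python sketch. SimpleSpec and SpecWithDefault are hypothetical stand-ins, not Orbax classes: shapes are tuples, dtypes are Python types such as float, and a nested list plays the role of the tf.Tensor. Treating a None dimension as "any size" is an assumption meant to mirror TensorSpec shape compatibility.

```python
import dataclasses
from typing import Any, Optional, Tuple


@dataclasses.dataclass
class SimpleSpec:
    """Hypothetical stand-in for tf.TensorSpec: a shape and a Python dtype."""
    shape: Tuple[Optional[int], ...]  # None marks a dimension of unknown size
    dtype: type


def _infer_shape(value):
    """Shape of a rectangular nested list, e.g. [[1, 2]] -> (1, 2)."""
    if isinstance(value, (list, tuple)):
        inner = {_infer_shape(v) for v in value}
        if len(inner) > 1:
            raise ValueError("ragged default value")
        return (len(value),) + (inner.pop() if inner else ())
    return ()


def _cast(value, dtype):
    """Cast every leaf of a nested list to dtype."""
    if isinstance(value, (list, tuple)):
        return [_cast(v, dtype) for v in value]
    return dtype(value)


@dataclasses.dataclass
class SpecWithDefault:
    """Hypothetical sketch of the convert-then-validate step."""
    tensor_spec: SimpleSpec
    default_val: Any
    is_primary: bool = False

    def __post_init__(self):
        # Convert with the spec's dtype (TensorFlow would build a tf.Tensor here).
        self.default_val = _cast(self.default_val, self.tensor_spec.dtype)
        shape = _infer_shape(self.default_val)
        want = self.tensor_spec.shape
        matches = len(shape) == len(want) and all(
            w is None or w == s for w, s in zip(want, shape)
        )
        if not matches:
            raise ValueError(f"default shape {shape} does not match spec shape {want}")


spec = SpecWithDefault(SimpleSpec(shape=(2,), dtype=float), default_val=[1, 2])
print(spec.default_val)  # [1.0, 2.0]
```

Passing default_val=[1.0, 2.0, 3.0] with the same spec raises ValueError in this sketch, because the inferred shape (3,) does not match (2,).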

tensor_spec

The underlying tf.TensorSpec defining the expected shape and dtype.

Type:

tensorflow.python.framework.tensor.TensorSpec

default_val

The default value to use. Upon initialization, this is automatically converted to a tf.Tensor using the dtype from tensor_spec.

Type:

Any

is_primary

Whether this tensor is a primary input tensor. A primary input tensor is a tensor whose batch size already matches, or will be tiled to match, the batch size of all other primary input tensors. A non-primary input tensor must have a batch size of 1 or the same batch size as the primary tensors. This attribute is used mainly by orbax.export.utils.make_auto_batching_function.

Type:

bool

__eq__(other)

Return self==value.

__hash__ = None
__init__(tensor_spec, default_val, is_primary=False)
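The __hash__ = None entry means instances are unhashable, so they cannot be used as dict keys or set members. Python does this for any class that defines __eq__ without defining __hash__, and a dataclass created with the defaults eq=True and frozen=False gets exactly this pair of members. Whether TensorSpecWithDefault is actually implemented as a dataclass is an assumption; the sketch uses a hypothetical Spec class to show the behaviour.

```python
import dataclasses


@dataclasses.dataclass  # eq=True and frozen=False are the defaults
class Spec:
    """Hypothetical class with fields shaped like TensorSpecWithDefault's."""
    name: str
    default_val: object
    is_primary: bool = False


a = Spec("y", [1.0, 1.0])
b = Spec("y", [1.0, 1.0])

print(a == b)                 # True: the generated __eq__ compares fields
print(Spec.__hash__ is None)  # True: eq=True, frozen=False sets __hash__ to None

try:
    hash(a)
except TypeError as err:
    print("unhashable:", err)
```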

CallableSignatures

class orbax.export.utils.CallableSignatures(sess, signature_defs)

Holds TF SignatureDefs as Python callables.

__init__(sess, signature_defs)
classmethod from_saved_model(model_dir, tags, sess_config=None)

Loads a SavedModel and reconstructs its signatures as Python callables.

Signatures of an object loaded with the tf.saved_model.load API don't support default values, so this class can be used instead to load the model with TF1 and reconstruct the signatures as callables. Example:

>>> loaded = CallableSignatures.from_saved_model(model_dir, ['serve'])
>>> outputs = loaded.signatures['serving_default'](**inputs)

The equivalent TF2 code is:

>>> loaded_tf2 = tf.saved_model.load(model_dir, ['serve'])
>>> outputs = loaded_tf2.signatures['serving_default'](**inputs)

However, the callables in loaded_tf2.signatures don't have any default inputs.
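Conceptually, the signatures mapping pairs each signature name with a callable that fills in any inputs the caller omits. The framework-free sketch below models only that keyword-argument behaviour; it does not load a SavedModel or open a TF1 session, and make_signature_callable and serving_fn are hypothetical names rather than Orbax APIs.

```python
from typing import Any, Callable, Dict, Mapping


def make_signature_callable(
    fn: Callable[..., Dict[str, Any]], defaults: Mapping[str, Any]
) -> Callable[..., Dict[str, Any]]:
    """Hypothetical helper: wrap fn so omitted keyword inputs take their defaults."""

    def call(**inputs: Any) -> Dict[str, Any]:
        # Caller-supplied inputs override the stored defaults.
        merged = {**defaults, **inputs}
        return fn(**merged)

    return call


def serving_fn(x, y):
    # Hypothetical stand-in for running one SignatureDef in a session.
    return {"output": [a + b for a, b in zip(x, y)]}


signatures = {
    "serving_default": make_signature_callable(serving_fn, {"y": [1.0, 1.0]}),
}

print(signatures["serving_default"](x=[2.0, 2.0]))                # {'output': [3.0, 3.0]}
print(signatures["serving_default"](x=[2.0, 2.0], y=[0.5, 0.5]))  # {'output': [2.5, 2.5]}
```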

Parameters:
  • model_dir (str) – SavedModel directory.

  • tags (list[str]) – Tags to identify the metagraph to load. Same as the tags argument in tf.saved_model.load.

  • sess_config (Any) – (Optional.) A ConfigProto protocol buffer with configuration options for the session.

Returns:

A CallableSignatures instance whose signatures property maps signature names to the callables.

property signatures

Returns a mapping from signature names to Python callables.

Utility functions

orbax.export.utils.with_default_args(tf_fn, input_signature)

Creates a TF function with default args specified in input_signature.

This utility wraps a standard Python or TensorFlow function and rewrites its Python signature to include default values extracted from the provided input_signature.

Example

Create a TensorFlow function where the second argument has a default value:

import tensorflow as tf
from orbax.export.utils import TensorSpecWithDefault, with_default_args

def add_tensors(x, y):
  return x + y

# Define the signature: `x` is required, `y` has a default value
signature = [
    tf.TensorSpec(shape=(2,), dtype=tf.float32, name='x'),
    TensorSpecWithDefault(
        tensor_spec=tf.TensorSpec(shape=(2,), dtype=tf.float32, name='y'),
        default_val=[1.0, 1.0]
    )
]

# Create the new tf.function with defaults applied
fn_with_defaults = with_default_args(add_tensors, signature)

# The function can now be executed with only the required 'x' argument
result = fn_with_defaults(tf.constant([2.0, 2.0]))

Parameters:
  • tf_fn (Callable[..., Any]) – the TF function.

  • input_signature (Sequence[PyTree]) – the input signature. Each leaf is a tf.TensorSpec, or an orbax.export.TensorSpecWithDefault if a default value is specified.

Return type:

PolymorphicFunction

Returns:

A tf function with default arguments.
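The signature rewriting that with_default_args performs can be illustrated without TensorFlow. The sketch below uses inspect to attach defaults to an ordinary Python function; with_defaults_sketch, add_lists and _NO_DEFAULT are hypothetical names. The real utility takes its defaults from TensorSpecWithDefault leaves and returns a tf.function, which this sketch does not attempt.

```python
import functools
import inspect
from typing import Any, Callable, Sequence

_NO_DEFAULT = inspect.Parameter.empty  # marks a parameter without a default


def with_defaults_sketch(fn: Callable[..., Any], defaults: Sequence[Any]) -> Callable[..., Any]:
    """Hypothetical sketch: attach one default per parameter to fn's signature."""
    sig = inspect.signature(fn)
    if len(defaults) != len(sig.parameters):
        raise ValueError("need exactly one default (or _NO_DEFAULT) per parameter")
    params = [p.replace(default=d) for p, d in zip(sig.parameters.values(), defaults)]
    # Signature.replace re-validates: a defaulted parameter before a required
    # one raises ValueError, matching the ordering constraint.
    new_sig = sig.replace(parameters=params)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = new_sig.bind(*args, **kwargs)
        bound.apply_defaults()  # fill omitted arguments from new_sig
        return fn(*bound.args, **bound.kwargs)

    wrapper.__signature__ = new_sig  # what inspect.signature(wrapper) reports
    return wrapper


def add_lists(x, y):
    return [a + b for a, b in zip(x, y)]


add_with_default = with_defaults_sketch(add_lists, [_NO_DEFAULT, [1.0, 1.0]])
print(inspect.signature(add_with_default))  # (x, y=[1.0, 1.0])
print(add_with_default([2.0, 2.0]))         # [3.0, 3.0]
```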

orbax.export.utils.make_auto_batching_function(input_signature)

Creates an auto-batching function from an input signature.

An auto-batching function is a function whose input tensors can have either a batch size of “b” or 1, and whose output tensors have a batch size of “b”, where “b” is the batch size of the primary input tensors.

Requirements:
  • All input tensors must have a leading batch dimension.

  • There must be at least one primary tensor. A primary tensor is a tensor whose tensor spec is either a tf.TensorSpec or a TensorSpecWithDefault whose is_primary attribute is True.

  • All primary tensors must have the same batch size.

  • All non-primary tensors must have a batch size of 1, or the same batch size as the primary tensors.
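The batch-size requirements above amount to a small check that also yields the target batch size "b". The sketch below expresses that check over plain integers (the leading dimension of each input); resolve_batch_size is a hypothetical helper, not part of orbax.export.utils.

```python
from typing import Sequence


def resolve_batch_size(primary_sizes: Sequence[int], non_primary_sizes: Sequence[int]) -> int:
    """Hypothetical check of the requirements above; returns the batch size b.

    Each entry is the leading (batch) dimension of one input tensor.
    """
    if not primary_sizes:
        raise ValueError("at least one primary tensor is required")
    b = primary_sizes[0]
    if any(size != b for size in primary_sizes):
        raise ValueError(f"primary batch sizes differ: {list(primary_sizes)}")
    for size in non_primary_sizes:
        if size not in (1, b):
            raise ValueError(f"non-primary batch size {size} is neither 1 nor {b}")
    return b


print(resolve_batch_size([2], [1]))        # 2
print(resolve_batch_size([3, 3], [1, 3]))  # 3
```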

Example

>>> import tensorflow as tf
>>> from orbax.export import utils
>>> from orbax.export.utils import TensorSpecWithDefault
>>> input_signature = (
...     tf.TensorSpec([None], tf.int32, name='primary'),
...     TensorSpecWithDefault(
...         tf.TensorSpec([None], tf.int32, name='optional'), [1]
...     ),
... )
>>> batching_fn = utils.make_auto_batching_function(input_signature)
>>> batching_fn(tf.constant([0, 0]), tf.constant([1]))
(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 0], dtype=int32)>,
 <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)>)

Parameters:

input_signature (Sequence[PyTree]) – a sequence of PyTrees whose leaf nodes are tf.TensorSpec or TensorSpecWithDefault.

Return type:

Callable[..., Any]

Returns:

A TF function whose output tensors all have the same batch size.
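The tiling step in the example can be modelled with plain Python lists, treating each list element as one batch row. This sketch is framework-free and hypothetical: tile_to_batch and auto_batch_sketch are not Orbax APIs, and the real function returns a TF function that operates on tensors.

```python
from typing import List, Sequence, Tuple


def tile_to_batch(rows: Sequence, b: int) -> List:
    """Hypothetical helper: repeat a batch-size-1 input b times; pass size-b inputs through."""
    if len(rows) == b:
        return list(rows)
    if len(rows) == 1:
        return [rows[0]] * b
    raise ValueError(f"expected batch size 1 or {b}, got {len(rows)}")


def auto_batch_sketch(primary: Sequence, optional: Sequence = (1,)) -> Tuple[List, List]:
    """Hypothetical two-input version of the example above.

    primary sets the batch size b; optional is a non-primary input whose
    default is [1], tiled up to b when it has batch size 1.
    """
    b = len(primary)
    return list(primary), tile_to_batch(optional, b)


print(auto_batch_sketch([0, 0], [1]))  # ([0, 0], [1, 1])
print(auto_batch_sketch([0, 0, 0]))    # ([0, 0, 0], [1, 1, 1])
```

The first call mirrors the batching_fn call in the example: the primary input fixes b = 2 and the size-1 optional input is repeated to match.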