General Utilities
Utilities for Orbax export.
TensorSpecWithDefault
- class orbax.export.utils.TensorSpecWithDefault(tensor_spec, default_val, is_primary=False)
Extends tf.TensorSpec to hold a default value.
This class associates a specific default value with a TensorFlow tensor specification. It automatically converts the provided default value into a tf.Tensor and validates that it strictly matches the shape and dtype of the provided tensor_spec.
- Constraints due to Python function calling conventions:
For a Python function parameter, either all of its corresponding tensor values in the signature must use a TensorSpecWithDefault, or none of them may.
Parameters with default values must be ordered after non-default ones.
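This ordering rule is Python's own: a signature may not place a parameter without a default after one that has a default. The standard-library inspect module enforces the same rule, which this small sketch (plain Python, no Orbax or TensorFlow involved) uses to show which orderings are accepted.

```python
import inspect

P = inspect.Parameter


def build_signature(params):
    """Builds an inspect.Signature, or returns None if Python rejects the order."""
    try:
        return inspect.Signature(params)
    except ValueError:
        # Raised when a parameter without a default follows one with a default.
        return None


# Valid: the required `x` comes before `y`, which has a default.
ok = build_signature([
    P('x', P.POSITIONAL_OR_KEYWORD),
    P('y', P.POSITIONAL_OR_KEYWORD, default=1.0),
])

# Invalid: the defaulted `y` is placed before the required `x`.
bad = build_signature([
    P('y', P.POSITIONAL_OR_KEYWORD, default=1.0),
    P('x', P.POSITIONAL_OR_KEYWORD),
])

print(ok)   # (x, y=1.0)
print(bad)  # None
```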
Example
Create a TensorSpec with a default fallback value:
```python
import tensorflow as tf
from orbax.export.utils import TensorSpecWithDefault

# Define a spec for a 1D tensor
spec = tf.TensorSpec(shape=(2,), dtype=tf.float32, name="input_a")

# Create the extended spec with a default list.
# The list is automatically converted to a tf.Tensor upon initialization.
default_spec = TensorSpecWithDefault(
    tensor_spec=spec,
    default_val=[1.0, 2.0],
)
```
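The conversion and validation that TensorSpecWithDefault applies to default_val can be illustrated without TensorFlow. The following is a minimal sketch, not Orbax's implementation: SpecSketch and make_default are hypothetical names, shapes are plain tuples with None marking an unknown dimension, and a Python callable such as float stands in for a TF dtype such as tf.float32.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in for tf.TensorSpec: a shape plus an element converter."""
    shape: Tuple[Optional[int], ...]  # None marks a dimension of unknown size
    dtype: Callable[[Any], Any]       # e.g. float, standing in for tf.float32


def infer_shape(value: Any) -> Tuple[int, ...]:
    """Returns the shape of a rectangular nested list; scalars have shape ()."""
    if isinstance(value, (list, tuple)):
        inner = [infer_shape(v) for v in value]
        if any(s != inner[0] for s in inner):
            raise ValueError('ragged nested list')
        return (len(value),) + (inner[0] if inner else ())
    return ()


def convert(value: Any, dtype: Callable[[Any], Any]) -> Any:
    """Casts every scalar leaf with dtype, mimicking conversion to a tensor."""
    if isinstance(value, (list, tuple)):
        return [convert(v, dtype) for v in value]
    return dtype(value)


def make_default(spec: SpecSketch, default_val: Any) -> Any:
    """Converts default_val to spec.dtype, then checks it fits spec.shape."""
    converted = convert(default_val, spec.dtype)
    shape = infer_shape(converted)
    fits = len(shape) == len(spec.shape) and all(
        want is None or want == got for want, got in zip(spec.shape, shape))
    if not fits:
        raise ValueError(f'default shape {shape} does not match spec {spec.shape}')
    return converted


spec = SpecSketch(shape=(2,), dtype=float)
print(make_default(spec, [1, 2]))  # [1.0, 2.0]
try:
    make_default(spec, [1.0, 2.0, 3.0])
except ValueError as err:
    print(err)  # default shape (3,) does not match spec (2,)
```

Converting first and validating second mirrors the order described above: the value is cast with the spec's dtype, and only then compared against the spec's shape.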
- tensor_spec
The underlying tf.TensorSpec defining the expected shape and dtype.
- Type:
tensorflow.python.framework.tensor.TensorSpec
- default_val
The default value to use. Upon initialization, this is automatically converted to a tf.Tensor using the dtype from tensor_spec.
- Type:
Any
- is_primary
Whether this tensor is a primary input tensor. Primary input tensors determine the batch size of the call, and all of them must share the same batch size. A non-primary input tensor must have a batch size of 1 (it is then tiled to the primary batch size) or the same batch size as the primary tensors. Used primarily by orbax.export.utils.make_auto_batching_function.
- Type:
bool
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(tensor_spec, default_val, is_primary=False)
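The combination of a generated __eq__ with __hash__ = None is what Python produces for a class that defines value equality without opting into hashing; the standard dataclasses module does this for a non-frozen dataclass with eq=True (the default). The sketch below uses a hypothetical SpecLike class and assumes, without confirmation from this page, that TensorSpecWithDefault behaves the same way: instances compare by value but cannot be dict keys or set members.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class SpecLike:
    """Hypothetical class mirroring the three fields documented above."""
    tensor_spec: Any
    default_val: Any
    is_primary: bool = False


a = SpecLike('spec', [1.0, 2.0])
b = SpecLike('spec', [1.0, 2.0])

print(a == b)             # True: the generated __eq__ compares fields
print(SpecLike.__hash__)  # None: eq=True without frozen=True disables hashing
try:
    hash(a)
except TypeError as err:
    print(err)            # unhashable type: 'SpecLike'
```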
CallableSignatures
- class orbax.export.utils.CallableSignatures(sess, signature_defs)
Holds TF SignatureDefs as Python callables.
- classmethod from_saved_model(model_dir, tags, sess_config=None)
Loads a SavedModel and reconstructs its signatures as Python callables.
The signatures of an object loaded with the tf.saved_model.load API don't support default values, so this class can instead be used to load the model in TF1 and reconstruct the signatures. Example:
```python
>>> loaded = CallableSignatures.from_saved_model(model_dir, ['serve'])
>>> outputs = loaded.signatures['serving_default'](**inputs)
```
The TF2 version of this example is:
```python
>>> loaded_tf2 = tf.saved_model.load(model_dir, ['serve'])
>>> outputs = loaded_tf2.signatures['serving_default'](**inputs)
```
However, the callables in loaded_tf2.signatures don't have any default inputs.
- Parameters:
  - model_dir (str) – SavedModel directory.
  - tags (list[str]) – Tags to identify the metagraph to load. Same as the tags argument in tf.saved_model.load.
  - sess_config (Any) – (Optional.) A [ConfigProto](tensorflow/tensorflow) protocol buffer with configuration options for the session.
- Returns:
A mapping of signature names to the callables.
- property signatures
Returns a mapping from signature names to Python callables.
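The usage pattern above (a mapping from signature names to callables that take keyword inputs, where omitted inputs fall back to defaults) can be simulated in plain Python. Everything here is hypothetical: FakeSession and make_signature_callable are illustrative names, and the assumption that defaults live in the graph itself (in the spirit of TF1's tf.compat.v1.placeholder_with_default) is not a description of how CallableSignatures is implemented.

```python
from typing import Any, Callable, Dict, Mapping, Sequence


class FakeSession:
    """Hypothetical stand-in for a TF1 session running a single graph."""

    def __init__(self, graph_fn: Callable[..., Dict[str, Any]],
                 required: Sequence[str], graph_defaults: Mapping[str, Any]):
        self._graph_fn = graph_fn
        self._required = list(required)
        self._graph_defaults = dict(graph_defaults)  # values baked into the graph

    def run(self, feed_dict: Mapping[str, Any]) -> Dict[str, Any]:
        missing = [name for name in self._required if name not in feed_dict]
        if missing:
            raise ValueError(f'missing required inputs: {missing}')
        # Unfed inputs fall back to the graph's own default values.
        return self._graph_fn(**{**self._graph_defaults, **feed_dict})


def make_signature_callable(sess: FakeSession) -> Callable[..., Dict[str, Any]]:
    """Wraps a session as a keyword-argument callable, one per signature."""
    def signature_fn(**inputs: Any) -> Dict[str, Any]:
        return sess.run(feed_dict=inputs)
    return signature_fn


sess = FakeSession(lambda x, y: {'sum': x + y},
                   required=['x'], graph_defaults={'y': 1.0})
signatures = {'serving_default': make_signature_callable(sess)}

print(signatures['serving_default'](x=2.0))         # {'sum': 3.0}
print(signatures['serving_default'](x=2.0, y=5.0))  # {'sum': 7.0}
```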
Utility functions
- orbax.export.utils.with_default_args(tf_fn, input_signature)
Creates a TF function with default args specified in input_signature.
This utility wraps a standard Python or TensorFlow function and rewrites its Python signature to include default values extracted from the provided input_signature.
Example
Create a TensorFlow function where the second argument has a default value:
```python
import tensorflow as tf
from orbax.export.utils import TensorSpecWithDefault, with_default_args


def add_tensors(x, y):
    return x + y


# Define the signature: `x` is required, `y` has a default value
signature = [
    tf.TensorSpec(shape=(2,), dtype=tf.float32, name='x'),
    TensorSpecWithDefault(
        tensor_spec=tf.TensorSpec(shape=(2,), dtype=tf.float32, name='y'),
        default_val=[1.0, 1.0],
    ),
]

# Create the new tf.function with defaults applied
fn_with_defaults = with_default_args(add_tensors, signature)

# The function can now be executed with only the required `x` argument
result = fn_with_defaults(tf.constant([2.0, 2.0]))
```
- Parameters:
  - tf_fn (Callable[…,Any]) – The TF function.
  - input_signature (Sequence[PyTree]) – The input signature. Each leaf is a tf.TensorSpec, or an orbax.export.TensorSpecWithDefault if a default value is specified.
- Return type:
PolymorphicFunction
- Returns:
A tf.function with default arguments.
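The core idea, rewriting a function's Python signature so that trailing parameters gain defaults, can be sketched with the standard library alone. In this hypothetical with_defaults helper, plain values stand in for TensorSpecWithDefault entries and a REQUIRED sentinel stands in for a plain tf.TensorSpec; it does not build a tf.function and is not Orbax's implementation.

```python
import functools
import inspect
from typing import Any, Callable, Sequence

REQUIRED = object()  # hypothetical sentinel meaning "no default", like a plain tf.TensorSpec


def with_defaults(fn: Callable[..., Any], defaults: Sequence[Any]) -> Callable[..., Any]:
    """Returns a wrapper whose signature gives fn's parameters the given defaults."""
    params = list(inspect.signature(fn).parameters.values())
    if len(params) != len(defaults):
        raise ValueError('need exactly one entry per parameter')
    new_params = [
        p if d is REQUIRED else p.replace(default=d)
        for p, d in zip(params, defaults)
    ]
    # inspect.Signature raises ValueError if a required parameter follows a defaulted one.
    new_sig = inspect.Signature(new_params)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = new_sig.bind(*args, **kwargs)
        bound.apply_defaults()  # fill in any omitted trailing arguments
        return fn(*bound.args, **bound.kwargs)

    wrapper.__signature__ = new_sig  # what inspect.signature(wrapper) reports
    return wrapper


def add(x, y):
    return [a + b for a, b in zip(x, y)]


add_with_default = with_defaults(add, [REQUIRED, [1.0, 1.0]])
print(inspect.signature(add_with_default))  # (x, y=[1.0, 1.0])
print(add_with_default([2.0, 2.0]))         # [3.0, 3.0]
```

Because the new signature is validated by inspect.Signature, the "defaults after non-defaults" constraint from TensorSpecWithDefault surfaces here as a ValueError at wrap time.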
- orbax.export.utils.make_auto_batching_function(input_signature)
Creates an auto-batching function from input signature.
An auto-batching function is a function whose input tensors can have either a batch size of “b” or 1, and whose output tensors have a batch size of “b”, where “b” is the batch size of the primary input tensors.
- Requirements:
  - All input tensors must have a leading batch dimension.
  - There must be at least one primary tensor. A primary tensor is a tensor whose tensor spec is either a tf.TensorSpec or a TensorSpecWithDefault whose is_primary attribute is True.
  - All primary tensors must have the same batch size.
  - All non-primary tensors must have a batch size of 1, or the same as the primary batch size.
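The requirements above amount to a consistency check over leading (batch) dimensions. This is a sketch under stated assumptions: primary_batch_size is a hypothetical helper, and each input is described only by a (batch_size, is_primary) pair, whereas the real function works on TensorFlow specs and tensors.

```python
from typing import Sequence, Tuple


def primary_batch_size(inputs: Sequence[Tuple[int, bool]]) -> int:
    """Returns the shared primary batch size "b", or raises if a requirement fails.

    Each entry is a hypothetical (batch_size, is_primary) pair for one input tensor.
    """
    primary_sizes = {size for size, is_primary in inputs if is_primary}
    if not primary_sizes:
        raise ValueError('at least one primary tensor is required')
    if len(primary_sizes) > 1:
        raise ValueError(f'primary batch sizes differ: {sorted(primary_sizes)}')
    b = primary_sizes.pop()
    for size, is_primary in inputs:
        if not is_primary and size not in (1, b):
            raise ValueError(f'non-primary batch size {size} is neither 1 nor {b}')
    return b


print(primary_batch_size([(2, True), (1, False)]))             # 2
print(primary_batch_size([(3, True), (3, True), (3, False)]))  # 3
```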
Example
```python
>>> input_signature = (
...     tf.TensorSpec([None], tf.int32, name='primary'),
...     TensorSpecWithDefault(
...         tf.TensorSpec([None], tf.int32, name='optional'), [1]
...     ),
... )
>>> batching_fn = utils.make_auto_batching_function(input_signature)
>>> batching_fn(tf.constant([0, 0]), tf.constant([1]))
(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 0], dtype=int32)>, <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)>)
```
- Parameters:
  - input_signature (Sequence[PyTree]) – A sequence of PyTrees whose leaf nodes are tf.TensorSpec or TensorSpecWithDefault.
- Return type:
Callable[…,Any]
- Returns:
A TF function whose output tensors all have the same batch size.
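The tiling in the doctest above (the batch-size-1 optional input becomes [1, 1]) can be sketched with plain Python lists standing in for tensors. This illustrates the rule under stated assumptions and is not Orbax's implementation: auto_batch is a hypothetical helper, inputs are 1-D lists whose length is the batch size, and primary flags are passed explicitly instead of being read from TensorSpecWithDefault.is_primary.

```python
from typing import List, Sequence


def auto_batch(inputs: Sequence[List[int]], is_primary: Sequence[bool]) -> List[List[int]]:
    """Tiles batch-size-1 non-primary inputs up to the primary batch size."""
    primary = next((x for x, p in zip(inputs, is_primary) if p), None)
    if primary is None:
        raise ValueError('at least one primary input is required')
    b = len(primary)
    outputs = []
    for x, is_p in zip(inputs, is_primary):
        if len(x) == b:
            outputs.append(list(x))      # already at the primary batch size
        elif not is_p and len(x) == 1:
            outputs.append(list(x) * b)  # repeat the single entry b times
        else:
            raise ValueError(f'cannot batch input of size {len(x)} to {b}')
    return outputs


print(auto_batch([[0, 0], [1]], [True, False]))  # [[0, 0], [1, 1]]
```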