General Utilities
Utilities for Orbax export.
TensorSpecWithDefault
- class orbax.export.utils.TensorSpecWithDefault(tensor_spec, default_val, is_primary=False)
Extends tf.TensorSpec to hold a default value.
- Constraints due to Python function calling conventions:
  - For a Python function parameter, either all corresponding tensor values in the signature must have a TensorSpecWithDefault, or none of them should.
  - Parameters with default values should be ordered after non-default ones.
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(tensor_spec, default_val, is_primary=False)
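Both constraints follow ordinary Python signature rules. The TensorFlow-free sketch below illustrates them; Leaf, param_default, and build_signature are hypothetical stand-ins (not part of orbax.export), and treating a multi-leaf parameter's default as the tuple of its leaf defaults is an assumption made only for this illustration.

```python
import inspect
from dataclasses import dataclass
from typing import Any, Dict, Sequence

# Sentinel marking a leaf that has no default (hypothetical, not Orbax API).
_NO_DEFAULT = object()


@dataclass(frozen=True)
class Leaf:
    """Hypothetical stand-in for one leaf of a parameter's input signature."""
    name: str
    default: Any = _NO_DEFAULT


def param_default(leaves: Sequence[Leaf]) -> Any:
    """Derives a parameter's default value, enforcing the all-or-none rule."""
    has_default = [leaf.default is not _NO_DEFAULT for leaf in leaves]
    if any(has_default) and not all(has_default):
        raise ValueError("either all leaves of a parameter have defaults or none do")
    if leaves and all(has_default):
        # Illustrative assumption: collect the leaf defaults into a tuple.
        return tuple(leaf.default for leaf in leaves)
    return inspect.Parameter.empty


def build_signature(params: Dict[str, Sequence[Leaf]]) -> inspect.Signature:
    """Builds a Python signature; inspect.Signature itself raises ValueError
    when a parameter without a default follows one that has a default."""
    return inspect.Signature([
        inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD,
                          default=param_default(leaves))
        for name, leaves in params.items()
    ])
```

With the required parameter first, build_signature({"x": [Leaf("x")], "y": [Leaf("y", 1)]}) yields the signature (x, y=(1,)); swapping the two entries makes inspect.Signature raise ValueError, the same rule that forces default-valued parameters after required ones.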
CallableSignatures
- class orbax.export.utils.CallableSignatures(sess, signature_defs)
Holds TF SignatureDefs as Python callables.
- classmethod from_saved_model(model_dir, tags, sess_config=None)
Loads a SavedModel and reconstructs its signatures as Python callables.
The signatures of the object loaded by the tf.saved_model.load API don't support default values, so this class can be used to load the model in TF1 and reconstruct the signatures. Example:
>>> loaded = CallableSignatures.from_saved_model(model_dir, ['serve'])
>>> outputs = loaded.signatures['serving_default'](**inputs)
The TF2 version of this example is:
>>> loaded_tf2 = tf.saved_model.load(model_dir, ['serve'])
>>> outputs = loaded_tf2.signatures['serving_default'](**inputs)
However, the callables in loaded_tf2.signatures don't have any default inputs.
- Parameters:
model_dir (str) – SavedModel directory.
tags (list[str]) – Tags to identify the metagraph to load. Same as the tags argument in tf.saved_model.load.
sess_config (Optional[Any]) – (Optional.) A [ConfigProto](tensorflow/tensorflow) protocol buffer with configuration options for the session.
- Returns:
A mapping of signature names to the callables.
- property signatures
Returns a mapping from signature names to Python callables.
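The practical difference described above is whether a signature callable can fill in inputs the caller omits. As a TensorFlow-free sketch of that idea only (an assumption for illustration, not how CallableSignatures is implemented, which works through a TF1 session), the hypothetical helper with_input_defaults wraps a raw callable that requires every input and supplies stored defaults for any input left out. serving_default and its scale input are likewise hypothetical.

```python
from typing import Any, Callable, Dict, Mapping


def with_input_defaults(
    raw_fn: Callable[..., Dict[str, Any]],
    defaults: Mapping[str, Any],
) -> Callable[..., Dict[str, Any]]:
    """Wraps raw_fn so that inputs listed in `defaults` may be omitted."""

    def wrapped(**inputs: Any) -> Dict[str, Any]:
        merged = dict(defaults)  # start from the stored default inputs
        merged.update(inputs)    # caller-provided inputs take precedence
        return raw_fn(**merged)

    return wrapped


# Hypothetical signature that requires both inputs, like a loaded TF2 signature.
def serving_default(x: float, scale: float) -> Dict[str, float]:
    return {"y": x * scale}


# Mapping of signature names to callables that honor the default inputs.
signatures = {"serving_default": with_input_defaults(serving_default, {"scale": 2.0})}
```

Calling signatures["serving_default"](x=3.0) then works without passing scale, whereas calling serving_default(x=3.0) directly raises TypeError for the missing input.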
Utility functions
- orbax.export.utils.with_default_args(tf_fn, input_signature)
Creates a TF function with default args specified in input_signature.
- Parameters:
tf_fn (Callable[..., Any]) – the TF function.
input_signature (Sequence[Any]) – the input signature. Each leaf is a tf.TensorSpec, or an orbax.export.TensorSpecWithDefault if a default value is specified.
- Return type:
PolymorphicFunction
- Returns:
A TF function with default arguments.
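To show what "a function with default arguments" means at the Python level, here is a TensorFlow-free sketch that assumes a flat signature of (name, default) pairs. with_defaults_sketch, REQUIRED, and add are hypothetical names; the real function takes PyTrees of tf.TensorSpec/TensorSpecWithDefault leaves and returns a PolymorphicFunction.

```python
import functools
import inspect
from typing import Any, Callable, Sequence, Tuple

# Marker for a parameter with no default (reuses inspect's sentinel).
REQUIRED = inspect.Parameter.empty


def with_defaults_sketch(
    fn: Callable[..., Any], specs: Sequence[Tuple[str, Any]]
) -> Callable[..., Any]:
    """Wraps fn so that parameters listed with a default may be omitted."""
    sig = inspect.Signature([
        inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD,
                          default=default)
        for name, default in specs
    ])

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = sig.bind(*args, **kwargs)  # raises TypeError if a required arg is missing
        bound.apply_defaults()             # fill in omitted default-valued args
        return fn(*bound.args, **bound.kwargs)

    wrapper.__signature__ = sig  # expose the defaults to inspect.signature
    return wrapper


def add(x, y):
    return x + y


add_with_default = with_defaults_sketch(add, [("x", REQUIRED), ("y", 10)])
```

Here add_with_default(1) returns 11 and inspect.signature(add_with_default) reports (x, y=10). Binding against an explicit signature keeps positional and keyword calls consistent, which is the behavior a caller expects from default arguments.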
- orbax.export.utils.make_auto_batching_function(input_signature)
Creates an auto-batching function from input signature.
An auto-batching function is a function whose input tensors can have either a batch size of “b” or 1, and whose output tensors have a batch size of “b”, where “b” is the batch size of the primary input tensors.
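The batch-size contract can be illustrated with plain Python lists standing in for tensors, where a list's length plays the role of its leading batch dimension. tile_to_batch is a hypothetical helper for this sketch, not an orbax.export function, and it assumes the primary batch size "b" is already known.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def tile_to_batch(rows: Sequence[T], b: int) -> List[T]:
    """Returns `rows` with batch size b, repeating a batch-size-1 input b times."""
    if len(rows) == b:
        return list(rows)
    if len(rows) == 1:
        return [rows[0]] * b  # repeats the single row (same object) b times
    raise ValueError(f"batch size {len(rows)} is neither 1 nor {b}")


# Mirrors the documented example: primary input [0, 0] fixes b = 2.
outputs = (tile_to_batch([0, 0], 2), tile_to_batch([1], 2))
```

outputs is ([0, 0], [1, 1]), matching the shapes in the example for make_auto_batching_function.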
- Requirements:
  - All input tensors must have a leading batch dimension.
  - There must be at least one primary tensor. A primary tensor is a tensor whose tensor spec is either a tf.TensorSpec or a TensorSpecWithDefault whose is_primary attribute is True.
  - All primary tensors must have the same batch size.
  - All non-primary tensors must have a batch size of 1, or the same as the primary batch size.
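These requirements amount to a check on concrete batch sizes. The sketch below is a hypothetical, TensorFlow-free helper (primary_batch_size is not an orbax.export function) that takes one batch size and one is_primary flag per input tensor and returns the primary batch size "b", or raises ValueError when a requirement is violated.

```python
from typing import Sequence


def primary_batch_size(batch_sizes: Sequence[int], is_primary: Sequence[bool]) -> int:
    """Returns the shared primary batch size, checking the stated requirements."""
    primary = {n for n, p in zip(batch_sizes, is_primary) if p}
    if not primary:
        raise ValueError("at least one primary tensor is required")
    if len(primary) > 1:
        raise ValueError(f"primary tensors disagree on batch size: {sorted(primary)}")
    (b,) = primary
    # Non-primary inputs may be broadcast from 1 or already match b.
    for n, p in zip(batch_sizes, is_primary):
        if not p and n not in (1, b):
            raise ValueError(f"non-primary batch size {n} is neither 1 nor {b}")
    return b
```

For the inputs in the example below, primary_batch_size([2, 1], [True, False]) returns 2.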
Example
>>> input_signature = (
...     tf.TensorSpec([None], tf.int32, name='primary'),
...     TensorSpecWithDefault(
...         tf.TensorSpec([None], tf.int32, name='optional'), [1]
...     ),
... )
>>> batching_fn = utils.make_auto_batching_function(input_signature)
>>> batching_fn(tf.constant([0, 0]), tf.constant([1]))
(<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 0], dtype=int32)>, <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)>)
- Parameters:
input_signature (Sequence[Any]) – a sequence of PyTrees whose leaf nodes are tf.TensorSpec or TensorSpecWithDefault.
- Return type:
Callable[..., Any]
- Returns:
A TF function whose output tensors all have the same batch size.