model.core.python.function

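The Function base class, together with the StableHLO tensor types (ShloDType, ShloTensorSpec) used to describe its signatures.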

class orbax.experimental.model.core.python.function.ShloDType(*values)[source]

The dtype of a StableHLO tensor; used as the type of ShloTensorSpec.dtype.

class orbax.experimental.model.core.python.function.ShloTensorSpec(shape, dtype, sharding=None, layout=None, name=None)[source]

A specification for the shape, dtype, sharding, and layout of a StableHLO tensor.

shape

The shape of the tensor.

Type:

Sequence[int | None] | None

dtype

The dtype of the tensor.

Type:

orbax.experimental.model.core.python.function.ShloDType

sharding

The sharding of the tensor. None means the sharding is unspecified.

Type:

xla.xla_data_pb2.OpSharding | None

layout

The layout of the tensor. None means the default layout is used.

Type:

xla.xla_data_pb2.LayoutProto | None

name

The name of the tensor.

Type:

str | None

__eq__(other)

Return self==value.

__hash__ = None

Because __hash__ is None, ShloTensorSpec instances are unhashable.

__init__(shape, dtype, sharding=None, layout=None, name=None)
class orbax.experimental.model.core.python.function.Function(*, input_signature, output_signature, data_names=None)[source]

An abstract base class for functions whose signatures are StableHLO types.

input_signature

The input signature of the function.

Type:

orbax.experimental.model.core.python.function.ShloTensorSpec | list[Tree] | tuple[Tree, …] | dict[str, Tree] | None

output_signature

The output signature of the function.

Type:

orbax.experimental.model.core.python.function.ShloTensorSpec | list[Tree] | tuple[Tree, …] | dict[str, Tree] | None

data_names

The checkpoint data names used by the function.

Type:

Sequence[str] | None

signature

The pair (input_signature, output_signature).

__eq__(other)

Return self==value.

__hash__ = None

Because __hash__ is None, Function instances are unhashable.

__init__(*, input_signature, output_signature, data_names=None)