model.core.python.function
The Function base class.
- class orbax.experimental.model.core.python.function.ShloTensorSpec(shape, dtype, sharding=None, layout=None, name=None)
A specification for the shape, dtype, sharding, and layout of a StableHLO tensor.
- shape
The shape of the tensor.
- Type:
Sequence[int | None] | None
- dtype
The dtype of the tensor.
- sharding
The sharding of the tensor. None means unspecified sharding.
- Type:
xla.xla_data_pb2.OpSharding | None
- layout
The layout of the tensor. None means the default layout is used.
- Type:
xla.xla_data_pb2.LayoutProto | None
- name
The name of the tensor.
- Type:
str | None
- __eq__(other)
Return self == other.
- __hash__ = None
Instances are unhashable, so they cannot be used as dict keys or set members.
- __init__(shape, dtype, sharding=None, layout=None, name=None)
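The __eq__ and __hash__ = None entries suggest that ShloTensorSpec behaves like a plain (non-frozen) Python dataclass: two specs compare equal field by field, and a spec cannot be hashed. The following sketch illustrates that behavior with a hypothetical stand-in, ShloTensorSpecSketch, not the real class. It uses a placeholder string for dtype because the real dtype type is not documented here, and it leaves sharding and layout as None.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class ShloTensorSpecSketch:
    """Hypothetical stand-in mirroring the documented ShloTensorSpec fields.

    Assumption: the real class is dataclass-like, so eq=True and
    frozen=False make __eq__ compare fields and set __hash__ to None.
    """
    shape: Optional[Sequence[Optional[int]]]
    dtype: Any                    # placeholder; real dtype type not documented here
    sharding: Any = None          # xla_data_pb2.OpSharding | None in the real class
    layout: Any = None            # xla_data_pb2.LayoutProto | None in the real class
    name: Optional[str] = None


def is_hashable(obj: object) -> bool:
    """Return True if hash(obj) succeeds."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


a = ShloTensorSpecSketch(shape=(None, 128), dtype="f32", name="x")
b = ShloTensorSpecSketch(shape=(None, 128), dtype="f32", name="x")
c = ShloTensorSpecSketch(shape=(None, 128), dtype="f32", name="y")

print(a == b)                                 # True: every field matches
print(a == c)                                 # False: the names differ
print(ShloTensorSpecSketch.__hash__ is None)  # True
print(is_hashable(a))                         # False: cannot go in a set or dict key
```

If the real class is indeed unhashable, code that needs to deduplicate or index specs has to key them on a derived hashable value (for example a tuple of the shape and name) rather than on the spec object itself.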
- class orbax.experimental.model.core.python.function.Function(*, input_signature, output_signature, data_names=None)
An abstract base class for functions whose signatures are StableHLO types.
- input_signature
The input signature of the function.
- Type:
orbax.experimental.model.core.python.function.ShloTensorSpec | list[Tree] | tuple[Tree, …] | dict[str, Tree] | None
- output_signature
The output signature of the function.
- Type:
orbax.experimental.model.core.python.function.ShloTensorSpec | list[Tree] | tuple[Tree, …] | dict[str, Tree] | None
- data_names
The names of the checkpoint data used by the function.
- Type:
Sequence[str] | None
- signature
The pair (input_signature, output_signature).
- __eq__(other)
Return self == other.
- __hash__ = None
Instances are unhashable, so they cannot be used as dict keys or set members.
- __init__(*, input_signature, output_signature, data_names=None)