# Source code for orbax.checkpoint._src.metadata.sharding
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ShardingMetadata representing Sharding property."""
from __future__ import annotations
import abc
import dataclasses
import enum
import json
import logging
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
import jax
import numpy as np
from orbax.checkpoint._src.sharding_utils import make_single_device_sharding
_AXIS_TYPE_MAP = {str(val): val for val in jax.sharding.AxisType}
PartitionSpecElement = Union[None, str, Tuple[str, ...]]
_PARTITION_SPEC = 'partition_spec'
_SHARDING = '_sharding'
_SHARDING_TYPE = 'sharding_type'
_DEVICE_STR = 'device_str'
_MESH_AXES = 'axis_names'
_MESH_AXIS_TYPES = 'axis_types'
_MESH_SHAPE = 'shape'
_DEVICES_SHAPE = 'shape'
_DEVICE_MESH = 'device_mesh'
_MEMORY_KIND = 'memory_kind'
_ID = 'id'
class ShardingTypes(enum.Enum):
  """Values stored under the 'sharding_type' key of serialized metadata.

  `POSITIONAL_SHARDING` is never written by this module's serializers;
  `from_serialized_string` rejects it with a deprecation error.
  """

  NAMED_SHARDING = 'NamedSharding'
  SINGLE_DEVICE_SHARDING = 'SingleDeviceSharding'
  POSITIONAL_SHARDING = 'PositionalSharding'
@dataclasses.dataclass
class DeviceMetadata:
  """Identifies one device by its `jax.Device.id`.

  Only the integer id is recorded. Nothing here depends on the device
  platform.
  """

  id: int

  @classmethod
  def from_dict(cls, data: dict[str, Any]) -> DeviceMetadata:
    """Builds an instance from a dict such as ``{'id': 3}``.

    Args:
      data: Mapping holding the device id under the 'id' key.

    Returns:
      A new instance of `cls`.

    Raises:
      KeyError: If `data` has no 'id' entry.
    """
    return cls(id=data[_ID])

  @classmethod
  def from_jax_device(cls, device: jax.Device) -> DeviceMetadata:
    """Records the `id` of a live `jax.Device`."""
    return cls(id=device.id)

  def __eq__(self, other: object) -> bool:
    # For foreign operands, return NotImplemented instead of reading
    # `other.id`. Python then falls back to identity comparison, so
    # `metadata == None` is False rather than an AttributeError.
    if not isinstance(other, DeviceMetadata):
      return NotImplemented
    return self.id == other.id
@dataclasses.dataclass
class DeviceMetadataMesh:
  """A nested sequence of `DeviceMetadata` mirroring a mesh's device layout.

  The nesting of `mesh` matches `jax.sharding.Mesh.devices` after conversion
  to nested Python lists. Each device is replaced by its `DeviceMetadata`.
  """

  mesh: Sequence[Any]

  @classmethod
  def from_dict(cls, data: dict[str, Any]) -> DeviceMetadataMesh:
    """Rebuilds a mesh from the output of `dataclasses.asdict`.

    Args:
      data: Dict with a 'mesh' entry. That entry is a nested sequence whose
        leaves are mappings containing an 'id' key.

    Returns:
      A mesh with the same nesting, each leaf mapping converted to
      `DeviceMetadata`.
    """
    return cls(
        mesh=jax.tree.map(
            DeviceMetadata.from_dict,
            data['mesh'],
            # Treat the per-device dicts as leaves; otherwise the tree
            # traversal would descend into them as pytree nodes.
            is_leaf=lambda x: isinstance(x, Mapping) and _ID in x,
        )
    )

  @classmethod
  def from_jax_mesh(
      cls, mesh: jax.sharding.Mesh
  ) -> Optional[DeviceMetadataMesh]:
    """Converts the devices of a `jax.sharding.Mesh` into `DeviceMetadata`.

    The nesting of `mesh.devices` is preserved; a NumPy device array is first
    converted to nested lists. Every device is recorded by id regardless of
    platform. This implementation never returns None; the `Optional` return
    annotation is kept only for interface compatibility.

    Args:
      mesh: The mesh whose devices should be recorded.

    Returns:
      A `DeviceMetadataMesh` mirroring `mesh.devices`.
    """
    devices = mesh.devices
    if isinstance(devices, np.ndarray):
      devices = devices.tolist()
    return cls(
        mesh=jax.tree.map(
            DeviceMetadata.from_jax_device,
            devices,
            is_leaf=lambda x: isinstance(x, jax.Device),
        )
    )

  def to_jax_device_mesh(self):
    """Resolves the stored ids back to devices from `jax.devices()`.

    Returns:
      A nested sequence of `jax.Device` with the same structure as `mesh`.

    Raises:
      ValueError: If any stored id is not among the currently available
        devices.
    """
    # Query the device list once; it is used both for lookup and for the
    # error message.
    available_devices = jax.devices()
    devices_by_id = {device.id: device for device in available_devices}

    def build_device(m: DeviceMetadata) -> jax.Device:
      device = devices_by_id.get(m.id)
      if device is None:
        raise ValueError(
            'The available devices are different from the devices used to'
            ' save the checkpoint. Please restore checkpoint by passing'
            f' new shardings for target devices. Original={self.mesh},'
            f' current available={available_devices}'
        )
      return device

    return jax.tree.map(
        build_device,
        self.mesh,
        is_leaf=lambda x: isinstance(x, DeviceMetadata),
    )

  def __eq__(self, other: object) -> bool:
    # Optional mesh fields may hold None. Returning NotImplemented for
    # non-mesh operands makes `mesh == None` evaluate to False instead of
    # raising AttributeError on `None.mesh`.
    if not isinstance(other, DeviceMetadataMesh):
      return NotImplemented
    return self.mesh == other.mesh
@dataclasses.dataclass
class ShardingMetadata(abc.ABC):
  """Abstract, serializable description of a `jax.sharding.Sharding`.

  Concrete subclasses cover:
    jax.sharding.NamedSharding
    jax.sharding.SingleDeviceSharding
  Each converts to and from a live JAX sharding and to and from a serialized
  string form.
  """

  # ---- Constructors -------------------------------------------------------

  @classmethod
  @abc.abstractmethod
  def from_jax_sharding(cls, jax_sharding) -> ShardingMetadata:
    """Builds metadata that describes the given `jax.sharding.Sharding`."""

  @classmethod
  @abc.abstractmethod
  def from_deserialized_dict(
      cls, deserialized_dict: dict[str, str]
  ) -> ShardingMetadata:
    """Builds metadata from the dict obtained by decoding a serialized string."""

  # ---- Conversions --------------------------------------------------------

  @abc.abstractmethod
  def to_jax_sharding(self) -> jax.sharding.Sharding:
    """Reconstructs a live `jax.sharding.Sharding` from this metadata."""

  @abc.abstractmethod
  def to_serialized_string(self) -> str:
    """Serializes this metadata to a string."""
@dataclasses.dataclass
class NamedShardingMetadata(ShardingMetadata):
  """Serializable description of a `jax.sharding.NamedSharding`.

  Attributes:
    shape: Mesh shape, one size per mesh axis (from `mesh.shape.values()`).
    axis_names: Mesh axis names, in mesh order.
    partition_spec: PartitionSpec entries. Each entry is ``None``, an axis
      name, or a tuple of axis names.
    axis_types: Optional per-axis `jax.sharding.AxisType` values.
    device_mesh: Optional recorded device layout. When None,
      `to_jax_sharding` lays out `jax.devices()` instead.
  """

  shape: np.ndarray
  axis_names: List[str]
  partition_spec: Tuple[
      PartitionSpecElement, ...
  ]  # Each element is either ``None``, a string, or a tuple of strings.
  axis_types: Optional[Tuple[jax.sharding.AxisType, ...]] = None
  # Optional device mesh. If it's None, use jax.devices(),
  # otherwise, the stored device_mesh will be used to recreate NamedSharding.
  device_mesh: Optional[DeviceMetadataMesh] = None

  @classmethod
  def from_jax_sharding(
      cls, jax_sharding: jax.sharding.NamedSharding
  ) -> NamedShardingMetadata:
    """Captures the mesh, axis types, spec and device layout of a sharding."""
    mesh = jax_sharding.mesh
    return cls(
        shape=np.array(list(mesh.shape.values())),
        axis_names=list(mesh.axis_names),
        axis_types=tuple(mesh.axis_types),
        partition_spec=tuple(jax_sharding.spec),
        device_mesh=DeviceMetadataMesh.from_jax_mesh(mesh),
    )

  def to_jax_sharding(self) -> jax.sharding.NamedSharding:
    """Rebuilds the `NamedSharding`.

    Devices come from the recorded `device_mesh` when present, otherwise from
    `jax.devices()`. They are then reshaped to `shape`.

    Raises:
      ValueError: If a recorded device is not currently available, or (from
        NumPy's reshape) if the device count does not match `shape`.
    """
    if self.device_mesh is not None:
      mesh_devices = self.device_mesh.to_jax_device_mesh()
    else:
      mesh_devices = jax.devices()
    mesh = jax.sharding.Mesh(
        np.asarray(mesh_devices).reshape(self.shape),
        axis_names=self.axis_names,
        axis_types=self.axis_types,
    )
    return jax.sharding.NamedSharding(
        mesh, spec=jax.sharding.PartitionSpec(*self.partition_spec)
    )

  @classmethod
  def from_deserialized_dict(
      cls, deserialized_dict: dict[str, Any]
  ) -> NamedShardingMetadata:
    """Builds metadata from a JSON-decoded `to_serialized_string` payload.

    Args:
      deserialized_dict: Must contain the 'shape', 'axis_names' and
        'partition_spec' keys. The 'axis_types' and 'device_mesh' keys are
        optional.

    Returns:
      The reconstructed `NamedShardingMetadata`.

    Raises:
      ValueError: If any required key is missing.
      KeyError: If a stored axis type string is not in `_AXIS_TYPE_MAP`.
    """
    required_keys = (_MESH_SHAPE, _MESH_AXES, _PARTITION_SPEC)
    if not all(key in deserialized_dict for key in required_keys):
      raise ValueError(
          f'Sharding data not found in deserialized_dict: {deserialized_dict}'
      )

    axis_types = None
    axis_types_raw = deserialized_dict.get(_MESH_AXIS_TYPES)
    # Test against None rather than truthiness so an empty list round-trips
    # to () instead of being dropped.
    if axis_types_raw is not None:
      axis_types = tuple(_AXIS_TYPE_MAP[s] for s in axis_types_raw)

    # JSON encodes tuples as lists. Convert multi-axis entries back to tuples
    # so the spec compares equal to the original after a round trip.
    partition_spec = tuple(
        tuple(element) if isinstance(element, list) else element
        for element in deserialized_dict[_PARTITION_SPEC]
    )

    device_mesh = None
    if device_mesh_dict := deserialized_dict.get(_DEVICE_MESH):
      device_mesh = DeviceMetadataMesh.from_dict(device_mesh_dict)

    return cls(
        shape=np.array(deserialized_dict[_MESH_SHAPE]),
        axis_names=list(deserialized_dict[_MESH_AXES]),
        axis_types=axis_types,
        partition_spec=partition_spec,
        device_mesh=device_mesh,
    )

  def to_serialized_string(self) -> str:
    """Serializes to a JSON string tagged with the NamedSharding type."""
    # Key insertion order is preserved so the JSON output is unchanged.
    sharding_data = {
        _SHARDING_TYPE: ShardingTypes.NAMED_SHARDING.value,
        _MESH_SHAPE: self.shape.tolist(),
        _MESH_AXES: self.axis_names,
    }
    if self.axis_types is not None:
      sharding_data[_MESH_AXIS_TYPES] = [str(a) for a in self.axis_types]
    sharding_data[_PARTITION_SPEC] = self.partition_spec
    if self.device_mesh is not None:
      sharding_data[_DEVICE_MESH] = dataclasses.asdict(self.device_mesh)
    return json.dumps(sharding_data)

  def __repr__(self):
    return (
        f'NamedShardingMetadata(shape={self.shape},'
        f' axis_names={self.axis_names}, axis_types={self.axis_types},'
        f' partition_spec={self.partition_spec},'
        f' device_mesh={self.device_mesh})'
    )

  def __eq__(self, other):
    # Other metadata types (e.g. SingleDeviceShardingMetadata) compare
    # unequal instead of raising AttributeError on `other.shape`.
    if not isinstance(other, NamedShardingMetadata):
      return False
    # If only one side has a device mesh, the two are unequal. Deciding that
    # here avoids comparing a DeviceMetadataMesh against None.
    if (self.device_mesh is None) != (other.device_mesh is None):
      return False
    return (
        np.array_equal(self.shape, other.shape)
        and self.axis_names == other.axis_names
        and self.axis_types == other.axis_types
        and self.partition_spec == other.partition_spec
        and self.device_mesh == other.device_mesh
    )
@dataclasses.dataclass
class SingleDeviceShardingMetadata(ShardingMetadata):
  """Serializable description of a `jax.sharding.SingleDeviceSharding`.

  The device is identified by its string form, `str(device)`.
  """

  device_str: str

  @staticmethod
  def _canonical_device_str(device_str: str) -> str:
    # CPU devices may be named 'TFRT_CPU_<n>' or 'cpu:<n>' depending on the
    # JAX version; map the former spelling onto the latter.
    return device_str.replace('TFRT_CPU_', 'cpu:')

  @classmethod
  def from_jax_sharding(
      cls, jax_sharding: jax.sharding.SingleDeviceSharding
  ) -> SingleDeviceShardingMetadata:
    """Records the string name of the sharding's device."""
    only_device = next(iter(jax_sharding.device_set))
    return cls(device_str=str(only_device))

  def to_jax_sharding(self) -> jax.sharding.SingleDeviceSharding:
    """Finds the recorded device among `jax.local_devices()` and wraps it.

    Raises:
      ValueError: If no local device matches `device_str`.
    """
    local_by_name = {
        self._canonical_device_str(str(local)): local
        for local in jax.local_devices()
    }
    wanted = self._canonical_device_str(self.device_str)
    device = local_by_name.get(wanted, None)
    if not device:
      raise ValueError(
          f'Device {wanted} was not found in jax.local_devices().'
      )
    return make_single_device_sharding(device)

  @classmethod
  def from_deserialized_dict(
      cls, deserialized_dict: dict[str, str]
  ) -> SingleDeviceShardingMetadata:
    """Builds metadata from a dict carrying a non-None 'device_str'.

    Raises:
      ValueError: If 'device_str' is absent or None.
    """
    device_str = deserialized_dict.get(_DEVICE_STR)
    if device_str is None:
      raise ValueError(
          f'Device str not found in deserialized_dict: {deserialized_dict}'
      )
    return cls(device_str=device_str)

  def to_serialized_string(self) -> str:
    """Serializes to a JSON string tagged with the SingleDeviceSharding type."""
    return json.dumps({
        _SHARDING_TYPE: ShardingTypes.SINGLE_DEVICE_SHARDING.value,
        _DEVICE_STR: self.device_str,
    })

  def __repr__(self):
    return f'SingleDeviceShardingMetadata(device_str={self.device_str})'

  def __eq__(self, other):
    if isinstance(other, SingleDeviceShardingMetadata):
      mine = self._canonical_device_str(self.device_str)
      theirs = self._canonical_device_str(other.device_str)
      return mine == theirs
    return False
def from_jax_sharding(jax_sharding) -> Optional[ShardingMetadata]:
  """Converts a `jax.sharding.Sharding` to the matching `ShardingMetadata`.

  Only `NamedSharding` and `SingleDeviceSharding` are handled. Any other input
  logs a warning and yields None.
  """
  if isinstance(jax_sharding, jax.sharding.NamedSharding):
    return NamedShardingMetadata.from_jax_sharding(jax_sharding)
  if isinstance(jax_sharding, jax.sharding.SingleDeviceSharding):
    return SingleDeviceShardingMetadata.from_jax_sharding(jax_sharding)
  logging.warning(
      'Conversion for %s has not been implemented.', type(jax_sharding)
  )
  return None
def from_serialized_string(serialized_str) -> ShardingMetadata:
  """Parses a JSON string produced by a `to_serialized_string` method.

  The concrete class is chosen from the 'sharding_type' tag.

  Raises:
    KeyError: If the 'sharding_type' tag is missing.
    ValueError: For the deprecated 'PositionalSharding' tag, or when the
      selected class rejects the payload.
    NotImplementedError: For any other unrecognized tag.
  """
  deserialized_dict = json.loads(serialized_str)
  sharding_type = deserialized_dict[_SHARDING_TYPE]

  if sharding_type == ShardingTypes.NAMED_SHARDING.value:
    return NamedShardingMetadata.from_deserialized_dict(deserialized_dict)
  if sharding_type == ShardingTypes.SINGLE_DEVICE_SHARDING.value:
    return SingleDeviceShardingMetadata.from_deserialized_dict(
        deserialized_dict
    )
  if sharding_type == ShardingTypes.POSITIONAL_SHARDING.value:
    raise ValueError(
        'jax.sharding.PositionalSharding has been deprecated. Please use'
        ' jax.NamedSharding instead.')
  raise NotImplementedError(
      f'Conversion for {sharding_type} has not been'
      ' implemented.'
  )
def get_sharding_or_none(serialized_string):
  """Decodes serialized sharding metadata into a live `jax.sharding.Sharding`.

  Args:
    serialized_string: An object whose `.item()` returns the serialized
      sharding string (presumably a 0-d NumPy array; confirm with callers).

  Returns:
    The reconstructed sharding, or None if a `ValueError` was raised while
    decoding or rebuilding it. The error is logged. Other exceptions, such
    as NotImplementedError for an unknown sharding type, propagate.
  """
  try:
    metadata = from_serialized_string(serialized_string.item())
    return metadata.to_jax_sharding()
  except ValueError as e:
    logging.error(e)
    return None