# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for writing in msgpack format. Derived from flax.serialization."""
import enum
from typing import Any, Dict
import jax
import msgpack
import numpy as np
# On-the-wire / disk serialization format
# We encode state-dicts via msgpack, using its custom type extension.
# https://github.com/msgpack/msgpack/blob/master/spec.md
#
# - ndarrays and jax Arrays are serialized to a nested msgpack-encoded tuple
#   of (shape-tuple, dtype-name (e.g. 'float32'), row-major array-bytes).
# Note: only simple ndarray types are supported, no objects or fields.
#
# - native complex scalars are converted to nested msgpack-encoded tuples
# (real, imag).
def _ndarray_to_bytes(arr) -> bytes:
  """Save ndarray to simple msgpack encoding.

  Args:
    arr: a numpy ndarray or a jax.Array. A jax.Array is first copied into a
      numpy array.

  Returns:
    msgpack-encoded bytes of the tuple (shape, dtype name, row-major bytes).

  Raises:
    ValueError: if the dtype contains Python objects or is a structured
      (record) dtype.
  """
  if isinstance(arr, jax.Array):
    arr = np.array(arr)
  dtype = arr.dtype
  # Only the dtype *name* is written, so a field layout could never be
  # restored. `isalignedstruct` alone misses packed (unaligned) structured
  # dtypes, hence the additional `fields` check.
  if dtype.hasobject or dtype.isalignedstruct or dtype.fields is not None:
    raise ValueError('Object and structured dtypes not supported '
                     'for serialization of ndarrays.')
  tpl = (arr.shape, dtype.name, arr.tobytes('C'))
  return msgpack.packb(tpl, use_bin_type=True)
def _dtype_from_name(name):
  """Resolve a serialized dtype name, handling JAX bfloat16 correctly.

  Args:
    name: dtype name, either as `bytes` (the form `_ndarray_from_bytes` passes
      in, since it unpacks with `raw=True`) or as `str`.

  Returns:
    `jax.numpy.bfloat16` when the name is 'bfloat16', otherwise
    `np.dtype(name)`.
  """
  # Accept both spellings so the special case does not silently depend on
  # whether the caller decoded the name.
  if name in (b'bfloat16', 'bfloat16'):
    return jax.numpy.bfloat16
  return np.dtype(name)
def _ndarray_from_bytes(data: bytes) -> np.ndarray:
  """Decode an ndarray from its simple msgpack encoding.

  Args:
    data: msgpack-encoded bytes holding (shape, dtype name, array bytes).

  Returns:
    The array viewed over the decoded buffer, reshaped in row-major order.
  """
  # raw=True leaves the dtype name and the array payload as bytes.
  shape, dtype_name, payload = msgpack.unpackb(data, raw=True)
  flat = np.frombuffer(payload, dtype=_dtype_from_name(dtype_name))
  return flat.reshape(shape, order='C')
@enum.unique
class _MsgpackExtType(enum.IntEnum):
  """Ids of the msgpack extension types used by this module."""

  NDARRAY = 1  # payload built by `_ndarray_to_bytes`
  NATIVE_COMPLEX = 2  # msgpack-packed (real, imag) pair
  NPSCALAR = 3  # numpy scalar, stored as a 0-d ndarray payload
  TUPLE = 4  # msgpack-packed list of the tuple's items
def _msgpack_ext_pack(x):
  """Messagepack encoders for custom types.

  Passed as the `default` hook to `msgpack.packb`.

  Args:
    x: object to encode.

  Returns:
    A `msgpack.ExtType` tagged with the matching `_MsgpackExtType` id.

  Raises:
    ValueError: if `x` is not an ndarray, jax.Array, numpy scalar, native
      complex or tuple.
  """
  # TODO(flax-dev): Array here only work when they are fully addressable.
  # If they are not fully addressable, use the GDA path for checkpointing.
  if isinstance(x, (np.ndarray, jax.Array)):
    return msgpack.ExtType(_MsgpackExtType.NDARRAY, _ndarray_to_bytes(x))
  # `np.issctype` was removed in NumPy 2.0; numpy scalars are exactly the
  # instances of `np.generic`.
  if isinstance(x, np.generic):
    # pack scalar as ndarray
    return msgpack.ExtType(
        _MsgpackExtType.NPSCALAR, _ndarray_to_bytes(np.asarray(x))
    )
  if isinstance(x, complex):
    return msgpack.ExtType(
        _MsgpackExtType.NATIVE_COMPLEX, msgpack.packb((x.real, x.imag))
    )
  if isinstance(x, tuple):
    # Tuples are packed as a nested msgpack list; the same hook is reused so
    # that custom types inside the tuple are encoded too.
    return msgpack.ExtType(
        _MsgpackExtType.TUPLE,
        msgpack.packb(
            list(x),
            strict_types=True,
            use_bin_type=True,
            default=_msgpack_ext_pack,
        ),
    )
  raise ValueError(f'Unsupported msgpack object: {x}')
def _msgpack_ext_unpack(code, data):
  """Messagepack decoders for custom types.

  Args:
    code: msgpack extension type id.
    data: payload bytes of the extension object.

  Returns:
    The decoded ndarray, complex, numpy scalar or tuple.

  Raises:
    ValueError: if `code` is not one of the `_MsgpackExtType` ids.
  """
  if code == _MsgpackExtType.NDARRAY:
    return _ndarray_from_bytes(data)
  if code == _MsgpackExtType.NATIVE_COMPLEX:
    complex_tuple = msgpack.unpackb(data)
    return complex(complex_tuple[0], complex_tuple[1])
  if code == _MsgpackExtType.NPSCALAR:
    ar = _ndarray_from_bytes(data)
    return ar[()]  # unpack ndarray to scalar
  if code == _MsgpackExtType.TUPLE:
    return tuple(
        msgpack.unpackb(data, raw=False, ext_hook=_msgpack_ext_unpack)
    )
  raise ValueError(f'Unsupported msgpack code: {code}')
# Chunking array leaves
# msgpack has a hard limit of 2**31 - 1 bytes per object leaf. To circumvent
# this limit for giant arrays (e.g. embedding tables), we traverse the tree
# and break up arrays near the limit into flattened array chunks.
# True limit is 2**31 - 1, but leave a margin for encoding padding.
MAX_CHUNK_SIZE = 1 << 30  # 2**30; compared against array byte counts.
def _np_convert_in_place(d):
  """Replace jax.Array leaves of (nested) dicts with numpy arrays, in place.

  Only dict containers are traversed. A bare jax.Array argument cannot be
  replaced in place, so its numpy copy is returned instead.

  Args:
    d: a dict (possibly nested), a jax.Array, or any other object.

  Returns:
    `d` itself (mutated when it is a dict), or a numpy copy when `d` is a
    jax.Array.
  """
  if isinstance(d, jax.Array):
    return np.array(d)
  if not isinstance(d, dict):
    return d
  for key in d:
    leaf = d[key]
    if isinstance(leaf, jax.Array):
      # Rebinding an existing key does not resize the dict, so this is safe
      # while iterating.
      d[key] = np.array(leaf)
    elif isinstance(leaf, dict):
      _np_convert_in_place(leaf)
  return d
def _tuple_to_dict(tpl):
  """Return a dict mapping each element's index (as a string) to the element."""
  return {str(index): item for index, item in enumerate(tpl)}


def _dict_to_tuple(dct):
  """Return the values of `dct` ordered by the string keys '0', '1', ..."""
  return tuple(dct[str(index)] for index in range(len(dct)))
def _chunk(arr) -> Dict[str, Any]:
  """Split an array into the canonical dict-of-chunks representation.

  Args:
    arr: array to split.

  Returns:
    A dict with the marker key '__msgpack_chunked_array__', the original
    'shape' and the flattened 'chunks', the latter two encoded as
    index-keyed dicts.
  """
  # Elements per chunk: as many as fit into MAX_CHUNK_SIZE bytes, but never
  # fewer than one.
  elems_per_chunk = max(1, MAX_CHUNK_SIZE // arr.dtype.itemsize)
  flat = arr.reshape(-1)
  pieces = [
      flat[start:start + elems_per_chunk]
      for start in range(0, flat.size, elems_per_chunk)
  ]
  return {
      '__msgpack_chunked_array__': True,
      'shape': _tuple_to_dict(arr.shape),
      'chunks': _tuple_to_dict(pieces),
  }
def _unchunk(data: Dict[str, Any]):
  """Rebuild an array from its canonical dict-of-chunks representation.

  Args:
    data: dict carrying the '__msgpack_chunked_array__' marker together with
      index-keyed 'shape' and 'chunks' dicts.

  Returns:
    The concatenated chunks reshaped to the stored shape.
  """
  assert '__msgpack_chunked_array__' in data
  target_shape = _dict_to_tuple(data['shape'])
  pieces = _dict_to_tuple(data['chunks'])
  return np.concatenate(pieces).reshape(target_shape)
def _chunk_array_leaves_in_place(d):
  """Replace oversized ndarray leaves of (nested) dicts by chunked dicts.

  Only dict containers are traversed. An ndarray is oversized when it holds
  more than MAX_CHUNK_SIZE bytes.

  Args:
    d: a dict (possibly nested), an ndarray, or any other object.

  Returns:
    `d` itself (mutated when it is a dict), or the chunked form when `d` is
    an oversized ndarray.
  """
  if isinstance(d, np.ndarray):
    return _chunk(d) if d.nbytes > MAX_CHUNK_SIZE else d
  if not isinstance(d, dict):
    return d
  for key in d:
    leaf = d[key]
    if isinstance(leaf, np.ndarray):
      if leaf.nbytes > MAX_CHUNK_SIZE:
        d[key] = _chunk(leaf)
    elif isinstance(leaf, dict):
      _chunk_array_leaves_in_place(leaf)
  return d
def _unchunk_array_leaves_in_place(d):
  """Turn chunked-array dicts inside (nested) dicts back into arrays.

  Only dict containers are traversed. A dict is treated as a chunked array
  when it contains the '__msgpack_chunked_array__' key.

  Args:
    d: a dict (possibly nested) or any other object.

  Returns:
    The rebuilt array when `d` itself is a chunked-array dict, otherwise `d`
    (mutated in place when it is a dict).
  """
  if not isinstance(d, dict):
    return d
  if '__msgpack_chunked_array__' in d:
    return _unchunk(d)
  for key in d:
    child = d[key]
    if not isinstance(child, dict):
      continue
    if '__msgpack_chunked_array__' in child:
      d[key] = _unchunk(child)
    else:
      _unchunk_array_leaves_in_place(child)
  return d
def msgpack_serialize(pytree, in_place: bool = False) -> bytes:
  """Save data structure to bytes in msgpack format.

  Low-level function that only supports python trees with array leaves,
  for custom objects use `to_bytes`. It splits arrays above MAX_CHUNK_SIZE into
  multiple chunks.

  Args:
    pytree: python tree of dict, list, tuple with python primitives
      and array leaves.
    in_place: boolean specifying if pytree should be modified in place.

  Returns:
    msgpack-encoded bytes of pytree.
  """
  if not in_place:
    # An identity tree_map rebuilds the containers, so the in-place
    # conversion and chunking below do not touch the caller's tree.
    pytree = jax.tree_util.tree_map(lambda x: x, pytree)
  pytree = _np_convert_in_place(pytree)
  pytree = _chunk_array_leaves_in_place(pytree)
  return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)
def msgpack_restore(encoded_pytree: bytes):
  """Restore data structure from bytes in msgpack format.

  Low-level function that only supports python trees with array leaves,
  for custom objects use `from_bytes`.

  Args:
    encoded_pytree: msgpack-encoded bytes of python tree.

  Returns:
    Python tree of dict, list, tuple with python primitive
    and array leaves.
  """
  state_dict = msgpack.unpackb(
      encoded_pytree, ext_hook=_msgpack_ext_unpack, raw=False)
  # Arrays that were split into chunks during serialization are reassembled.
  return _unchunk_array_leaves_in_place(state_dict)