# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OCDBT utilities for Orbax checkpointing."""
import asyncio
import re
import threading
from typing import Optional, Union
from absl import logging
from etils import epath
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
import tensorstore as ts
# Regex matching a trailing path segment made only of dot-separated integers.
_SHARDING_SUFFIX_RE = r'/\d+(\.\d+)*$' # /0, /0.0, /1.0.1, etc.
# Regex matching a trailing `/.zarray` path segment.
_ZARRAY_SUFFIX_RE = r'/\.zarray$'
# Literal form of the `/.zarray` suffix, for plain string suffix checks.
_ZARRAY_SUFFIX = '/.zarray'
async def _validate_params(
    ts_kv_store: ts.KvStore,
    use_zarr3: bool,
) -> None:
  """Validates the params present in a TensorStore KvStore.

  Every key listed by the kvstore is reduced to a param name and placed in one
  of two groups: names that have a `.zarray` entry, and names that have at
  least one entry that is not `.zarray`. A param passes validation only when
  it appears in both groups.

  Supports zarr2.

  NOTE: Support for zarr3 will be added later. When `use_zarr3` is True this
  function logs a message and returns without validating anything.

  Args:
    ts_kv_store: Open kvstore to validate, with transaction if applicable.
    use_zarr3: If True, use zarr3 driver, otherwise, use zarr driver.

  Raises:
    ValueError: If some param has a `.zarray` entry but no other entry, or has
      other entries but no `.zarray` entry.
  """
  # TODO: b/362328389 - Add support for zarr3.
  if use_zarr3:
    logging.info(
        'Param validation support for Zarr3 will be added later (b/362328389).'
    )
    return

  raw_ts_params = await ts_kv_store.list()
  if not raw_ts_params:
    # TODO: b/361090820 - Raise error once we confirm that Bennu writing empty
    # states is a bug.
    # e.g. //learning/deepmind/jax/roc/formats/roc_orbax:roc_orbax_test
    logging.info(
        'Skipping param validation: No params found in TensorStore'
        ' KvStore: %s.',
        ts_kv_store,
    )
    return

  process_index = multihost.process_index()
  current_thread_name = threading.current_thread().name
  # The verbosity check does not depend on the key being processed, so do it
  # once instead of up to three times per key.
  verbose = logging.vlog_is_on(1)

  # [a/.zarray, a/0, b/.zarray, b/0.0, b/0.1, c/0, d/.zarray] -> {a, b, d}
  with_zarray = set()
  # [a/.zarray, a/0, b/.zarray, b/0.0, b/0.1, c/0, d/.zarray] -> {a, b, c}
  without_zarray = set()
  for raw_key in raw_ts_params:
    ts_param = raw_key.decode('utf-8')
    if verbose:
      logging.vlog(
          1,
          '[process=%s][thread=%s] Validating raw param: %s',
          process_index,
          current_thread_name,
          ts_param,
      )
    # b/0.0 -> b, a/0 -> a, a/.zarray -> a/.zarray
    ts_param = re.sub(_SHARDING_SUFFIX_RE, '', ts_param)
    if ts_param.endswith(_ZARRAY_SUFFIX):
      # a/.zarray -> a
      ts_param = re.sub(_ZARRAY_SUFFIX_RE, '', ts_param)
      with_zarray.add(ts_param)
      if verbose:
        logging.vlog(
            1,
            '[process=%s][thread=%s] Collecting param with .zarray: %s',
            process_index,
            current_thread_name,
            ts_param,
        )
    else:
      # b -> b
      without_zarray.add(ts_param)
      if verbose:
        logging.vlog(
            1,
            '[process=%s][thread=%s] Collecting param without .zarray: %s',
            process_index,
            current_thread_name,
            ts_param,
        )

  unique = with_zarray | without_zarray
  logging.vlog(
      1,
      '[process=%s][thread=%s] Validating params in TensorStore KvStore.',
      process_index,
      current_thread_name,
  )

  # Names seen only through their `.zarray` entry.
  missing_params = unique - without_zarray
  if missing_params:
    formatted_missing_params = ' \n'.join(sorted(missing_params))
    raise ValueError(
        f'Save failed: {len(missing_params)}/{len(unique)} params are missing'
        f' in checkpoint:\n{formatted_missing_params}.\nTensorstore KvStore:'
        f' {ts_kv_store}.'
    )

  # Names seen only through non-`.zarray` entries.
  missing_zarrays = unique - with_zarray
  if missing_zarrays:
    formatted_missing_zarrays = ' \n'.join(sorted(missing_zarrays))
    raise ValueError(
        f'Save failed: {len(missing_zarrays)}/{len(unique)} params are missing'
        f' .zarray in checkpoint:\n{formatted_missing_zarrays}.\nTensorstore'
        f' KvStore: {ts_kv_store}.'
    )
async def merge_ocdbt_per_process_files(
    directory: epath.Path,
    ts_context: ts.Context,
    use_zarr3: bool,
    enable_validation: bool = True,
) -> None:
  """Merges OCDBT files written to per-process subdirectories.

  With Tensorstore's OCDBT format, arrays are initially written to per-process
  subdirectories, depending on which host is doing the writing. This function
  can be called to merge the per-process files into a global key-value store.

  The original per-process subdirectories are not and should not be deleted -
  the global kvstore continues to reference them.

  All copies, and the optional validation, run against a single atomic
  transaction that is committed only as the last step.

  NOTE: If no suitable subdirs with OCDBT checkpoints are found, this function
  does not raise any error and no merged checkpoint is created.

  Args:
    directory: checkpoint location.
    ts_context: Tensorstore context.
    use_zarr3: If True, use zarr3 driver, otherwise, use zarr driver for params
      validation.
    enable_validation: If True, validate params after merging. May have a
      performance impact.

  Raises:
    ValueError: Propagated from param validation when `enable_validation` is
      True and the merged params are inconsistent. The transaction is not
      committed in that case.
  """
  open_ops = []
  for process_dir in directory.glob(f'{ts_utils.PROCESS_SUBDIR_PREFIX}*'):
    child_kvstore_tspec = ts_utils.build_kvstore_tspec_for_merge(
        directory.as_posix(),
        process_dir.name,
    )
    logging.vlog(1, 'child_kvstore_tspec: %s', child_kvstore_tspec)
    open_ops.append(ts_utils.open_kv_store(child_kvstore_tspec, ts_context))
  if not open_ops:  # No per-process OCDBT checkpoint found!
    logging.warning(
        '[process=%s][thread=%s] Skipping merge of OCDBT checkpoints: No'
        ' per-process OCDBT checkpoint subdirs found in %s, ',
        multihost.process_index(),
        threading.current_thread().name,
        directory,
    )
    return

  parent_kvstore_tspec = ts_utils.build_kvstore_tspec(
      directory.as_posix(), use_ocdbt=True
  )
  ts_utils.add_ocdbt_write_options(parent_kvstore_tspec)
  # The parent is appended last so it can be split off from the children after
  # all stores have been opened concurrently.
  open_ops.append(ts_utils.open_kv_store(parent_kvstore_tspec, ts_context))
  *children, parent = await asyncio.gather(*open_ops)

  txn = ts.Transaction(atomic=True)
  copy_ops = [
      child.experimental_copy_range_to(parent.with_transaction(txn))
      for child in children
  ]
  await asyncio.gather(*copy_ops)

  # Validate merged params through the same transaction, before committing.
  if enable_validation:
    await _validate_params(parent.with_transaction(txn), use_zarr3=use_zarr3)
  await txn.commit_async()
def get_process_index_for_subdir(
    use_ocdbt: bool,
    override_ocdbt_process_id: Optional[str] = None,
) -> Optional[Union[int, str]]:
  """If OCDBT + merge feature is in use, returns a process index.

  Args:
    use_ocdbt: Whether OCDBT is in use.
    override_ocdbt_process_id: Optional identifier that takes precedence over
      the current process index. Any falsy value (None or an empty string) is
      treated as "no override".

  Returns:
    None when `use_ocdbt` is False. Otherwise `override_ocdbt_process_id` if it
    is truthy, else the value of `multihost.process_index()`.
  """
  if not use_ocdbt:
    return None
  return override_ocdbt_process_id or multihost.process_index()