Source code for orbax.checkpoint._src.handlers.json_checkpoint_handler

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JsonCheckpointHandler class.

Implementation of CheckpointHandler interface.
"""

from __future__ import annotations

import dataclasses
import json
from typing import Any, List, Mapping, Optional

from etils import epath
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.handlers import async_checkpoint_handler

CheckpointArgs = checkpoint_args.CheckpointArgs
register_with_handler = checkpoint_args.register_with_handler


class JsonCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
  """Checkpoint handler that stores a nested dictionary as a JSON file."""

  def __init__(
      self,
      filename: Optional[str] = None,
      *,
      multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
  ):
    """Initializes JsonCheckpointHandler.

    Args:
      filename: name of the file written inside the checkpoint directory. A
        missing or empty value falls back to 'metadata'.
      multiprocessing_options: See orbax.checkpoint.options. Only its
        `primary_host` field is read here.
    """
    self._filename = filename if filename else 'metadata'
    self._primary_host = multiprocessing_options.primary_host

  async def _save_fn(self, x, directory):
    """Serializes `x` with `json.dumps` and writes it under `directory`.

    The write only happens when `utils.is_primary_host` accepts the configured
    primary host; otherwise this coroutine returns without touching storage.

    Args:
      x: JSON-serializable object to write.
      directory: directory that receives the file named `self._filename`.
    """
    if not utils.is_primary_host(self._primary_host):
      return
    # Build the path before serializing so failures surface in the same order
    # as they always have.
    target = directory / self._filename
    serialized = json.dumps(x)
    target.write_text(serialized)

  async def async_save(
      self,
      directory: epath.Path,
      item: Optional[Mapping[str, Any]] = None,
      args: Optional[JsonSaveArgs] = None,
  ) -> Optional[List[future.Future]]:
    """Schedules a JSON save of the given item.

    Args:
      directory: save location directory.
      item: Deprecated, use `args` instead. Ignored when `args` is provided.
      args: JsonSaveArgs carrying the item to save.

    Returns:
      A single-element list holding the commit future that wraps the write.

    Raises:
      ValueError: if a `CheckpointArgs` instance was passed positionally as
        `item` instead of via the `args=` keyword.
    """
    if isinstance(item, CheckpointArgs):
      raise ValueError(
          'Make sure to specify kwarg name `args=` when providing'
          ' `JsonSaveArgs`.'
      )
    # `args`, when present, takes precedence over the deprecated `item`.
    payload = item if args is None else args.item
    commit_future = future.CommitFutureAwaitingContractedSignals(
        self._save_fn(payload, directory), name='json_ch_save'
    )
    return [commit_future]

  def save(
      self,
      directory: epath.Path,
      item: Optional[Mapping[str, Any]] = None,
      args: Optional[JsonSaveArgs] = None,
  ):
    """Saves the given item and blocks until every commit future resolves.

    Runs `async_save` through `asyncio_utils.run_sync` and then calls
    `result()` on each commit future it returned.

    Args:
      directory: save location directory.
      item: Deprecated, use `args` instead.
      args: JsonSaveArgs carrying the item to save.
    """

    async def _save_and_wait():
      pending = await self.async_save(directory, item, args)
      # `async_save` is typed as Optional; treat a falsy result as no work.
      for commit_future in pending or ():
        commit_future.result()

    asyncio_utils.run_sync(_save_and_wait())

  def restore(
      self,
      directory: epath.Path,
      item: Optional[Mapping[str, Any]] = None,
      args: Optional[JsonRestoreArgs] = None,
  ) -> Mapping[str, Any]:
    """Reads the JSON file from `directory` and returns its parsed content.

    Args:
      directory: restore location directory.
      item: unused.
      args: unused.

    Returns:
      The value produced by `json.loads` on the file's text.

    Raises:
      FileNotFoundError: if the file does not exist.
    """
    del item, args
    path = directory / self._filename
    if path.exists():
      return json.loads(path.read_text())
    raise FileNotFoundError(f'File {path} not found.')
@register_with_handler(JsonCheckpointHandler, for_save=True)
@dataclasses.dataclass
class JsonSaveArgs(CheckpointArgs):
  """Save arguments registered for `JsonCheckpointHandler`.

  Attributes:
    item (required): a nested dictionary to be written as JSON.
  """

  item: Mapping[str, Any]
@register_with_handler(JsonCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class JsonRestoreArgs(CheckpointArgs):
  """Restore arguments registered for `JsonCheckpointHandler`.

  Attributes:
    item: unused, but included for legacy-compatibility reasons. New code
      should not set this attribute.
  """

  item: Optional[bytes] = None