# Source code for orbax.checkpoint.experimental.v1._src.context.context

# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Orbax context for customized checkpointing."""

from __future__ import annotations

from collections.abc import Iterable
import contextvars
import copy
import dataclasses
import typing
from typing import Any, Callable

from absl import logging
from etils import epy
from orbax.checkpoint.experimental.v1._src.context import options as options_lib


# Holds the `Context` that is currently active, i.e. the one whose `with` block
# is being executed; `None` when no `Context` block is active. The value is set
# and reset by `Context.__contextmanager__` and read by `get_context`.
#
# Being a `contextvars.ContextVar`, the value is scoped to the current
# execution context: a raw worker thread does not see the value set by the
# thread that spawned it, while asyncio tasks run in a copy of the creating
# context and therefore see a reference to the same `Context` instance.
_CONTEXT: contextvars.ContextVar[Context | None] = contextvars.ContextVar(
    'orbax_context', default=None
)


def get_context(default: Context | None = None) -> Context:
  """Returns the currently active `Context`, or a default if no context is active.

  If called within a `with ocp.Context(...)` block, this function returns the
  `Context` object associated with that block (the active context).

  If called outside of any `with` block, this function returns `default`
  if it is provided. If `default` is not provided or `None`, it returns a
  new `Context` instance initialized with default options.

  Note: If a context is active, the `default` parameter is ignored, and the
  active context is always returned. To ensure that an explicitly provided
  context takes precedence over any active context, use the pattern:
  `ctx = explicit_context if explicit_context is not None else get_context()`.

  Args:
    default: A `Context` object to return if no context is active.

  Returns:
    The active `Context` or a default `Context`.
  """
  # Look up the active context first so that a fallback `Context` (which
  # instantiates every option object) is only built when it is really needed.
  active = _CONTEXT.get()
  if active is not None:
    return active
  if default is not None:
    return default
  return Context()


def _get_option(
    opt_value: Any,
    parent_opt: Any,
    default_factory: Callable[[], Any],
) -> Any:
  """Resolves a configuration option dataclass during Context initialization.

  Enforces the following order of precedence:
    1. Direct keyword argument (`opt_value`): Returns a deep copy to insulate
      the context from external mutation.
    2. Parent inheritance (`parent_opt`): Returns a deep copy of the parent's
      option dataclass, ensuring the new child context is fully insulated from
      any future mutations to the parent context (and vice versa).
    3. Fallback (`default_factory`): Creates a fresh default instance.

  Args:
    opt_value: An explicitly provided option dataclass instance, or None.
    parent_opt: The corresponding option dataclass from a parent Context, or
      None.
    default_factory: A callable that produces a fresh default option instance.

  Returns:
    The resolved option dataclass instance.
  """
  if opt_value is not None:
    return copy.deepcopy(opt_value)
  if parent_opt is not None:
    return copy.deepcopy(parent_opt)
  return default_factory()


@typing.final
class Context(epy.ContextManager):
  """Context for customized checkpointing.

  This class manages the configuration options (e.g., async, multiprocessing,
  array handling) used during Orbax checkpoint operations using a mutable
  namespace pattern.

  Creating a new :py:class:`.Context` within an existing :py:class:`.Context`
  sets all parameters from scratch by default. To inherit properties from a
  parent :py:class:`.Context`, pass the parent context as the first positional
  or explicit `context` keyword argument. The new context will inherit the
  parent's properties but can be mutated independently.

  WARNING: Context variables are thread-local by default. If you dispatch a
  checkpointing operation to a raw worker thread (e.g. `threading.Thread`),
  that thread will not inherit the context and will fall back to default
  settings. Furthermore, when background tasks or coroutines (e.g.
  `asyncio.create_task`, `save_async`) inherit the active context, they inherit
  a reference to the exact same underlying `Context` instance in memory.
  Consequently, if the main thread exits the `with ctx:` block and mutates the
  configuration of `ctx`, those changes will take effect immediately across any
  ongoing background asynchronous operations. To avoid unintended side effects,
  prefer creating a new `Context` instance (`ctx = ocp.Context(parent_ctx)`) for
  separate asynchronous operations rather than mutating a shared context
  instance mid-flight.

  Note: When testing or mixing checkpointer instances and free functions,
  explicitly wrap free functions inside their own `with ocp.Context(...)` block,
  or pass explicit contexts to Checkpointer constructors, to ensure each actor
  receives its correct active configuration independent of the surrounding
  context.

  Example:
    Basic usage and explicit inheritance::

      from orbax.checkpoint import v1 as ocp

      # Basic usage
      ctx = ocp.Context()
      ctx.pytree.loading.partial_load = True
      with ctx:
        ocp.save(directory, tree)

      # Inheriting properties from an existing context
      ctx1 = ocp.Context()
      ctx1.pytree.loading.partial_load = True
      with ctx1 as outer_ctx:
        # inner_ctx inherits partial_load, but mutates array saving
        ctx2 = ocp.Context(outer_ctx)
        ctx2.array.saving.use_zarr3 = False
        with ctx2 as inner_ctx:
          ocp.save(directory, tree)

    Context is not shared across threads::

      from concurrent.futures import ThreadPoolExecutor
      from orbax.checkpoint import v1 as ocp

      executor = ThreadPoolExecutor(max_workers=1)
      ctx = ocp.Context()
      ctx.pytree.loading.partial_load = True
      with ctx:  # Thread #1 creates Context.
        # The following save call is executed in Thread #2, which sees
        # a "default" Context, NOT the one created above.
        executor.submit(ocp.save, directory, tree)

  Attributes:
    pytree: Options for PyTree checkpointing. See
      :class:`~orbax.checkpoint.experimental.v1.options.PyTreeOptions`.
    array: Options for saving and loading array (and array-like objects). See
      :class:`~orbax.checkpoint.experimental.v1.options.ArrayOptions`.
    asynchronous: Options for controlling asynchronous behavior. See
      :class:`~orbax.checkpoint.experimental.v1.options.AsyncOptions`.
    multiprocessing: Options for multiprocessing behavior. See
      :class:`~orbax.checkpoint.experimental.v1.options.MultiprocessingOptions`.
    file: Options for working with the file system. See
      :class:`~orbax.checkpoint.experimental.v1.options.FileOptions`.
    checkpointables: Options for controlling checkpointables behavior. See
      :class:`~orbax.checkpoint.experimental.v1.options.CheckpointablesOptions`.
    pathways: Options for Pathways checkpointing. See
      :class:`~orbax.checkpoint.experimental.v1.options.PathwaysOptions`.
    checkpoint_layout: The layout of the checkpoint. Defaults to ORBAX. See
      :class:`~orbax.checkpoint.experimental.v1.options.CheckpointLayout`.
    deletion: Options for controlling deletion behavior. See
      :class:`~orbax.checkpoint.experimental.v1.options.DeletionOptions`.
    memory: Options for controlling memory limits during save / load. See
      :class:`~orbax.checkpoint.experimental.v1.options.MemoryOptions`.
    safetensors: Safetensors-specific options. See
      :class:`~orbax.checkpoint.experimental.v1.options.SafetensorsOptions`.
  """

  def __init__(
      self,
      context: Context | None = None,
      *,
      pytree_options: options_lib.PyTreeOptions | None = None,
      array_options: options_lib.ArrayOptions | None = None,
      async_options: options_lib.AsyncOptions | None = None,
      multiprocessing_options: options_lib.MultiprocessingOptions | None = None,
      file_options: options_lib.FileOptions | None = None,
      checkpointables_options: options_lib.CheckpointablesOptions | None = None,
      pathways_options: options_lib.PathwaysOptions | None = None,
      checkpoint_layout: options_lib.CheckpointLayout | None = None,
      deletion_options: options_lib.DeletionOptions | None = None,
      memory_options: options_lib.MemoryOptions | None = None,
      safetensors_options: options_lib.SafetensorsOptions | None = None,
  ):
    """Initializes the options held by this context.

    Every option is resolved with `_get_option`: an explicitly passed keyword
    value wins, otherwise the parent `context`'s value is used, otherwise a
    fresh default is created. Explicit and inherited values are deep-copied.

    Args:
      context: Optional parent `Context` to inherit options from.
      pytree_options: Deprecated. Explicit PyTree options.
      array_options: Deprecated. Explicit array options.
      async_options: Deprecated. Explicit async options.
      multiprocessing_options: Deprecated. Explicit multiprocessing options.
      file_options: Deprecated. Explicit file options.
      checkpointables_options: Deprecated. Explicit checkpointables options.
      pathways_options: Deprecated. Explicit Pathways options.
      checkpoint_layout: Deprecated. Explicit checkpoint layout; defaults to
        `CheckpointLayout.ORBAX` when neither given nor inherited.
      deletion_options: Deprecated. Explicit deletion options.
      memory_options: Deprecated. Explicit memory options.
      safetensors_options: Deprecated. Explicit Safetensors options.
    """
    if any(
        opt is not None
        for opt in (
            pytree_options,
            array_options,
            async_options,
            multiprocessing_options,
            file_options,
            checkpointables_options,
            pathways_options,
            checkpoint_layout,
            deletion_options,
            memory_options,
            safetensors_options,
        )
    ):
      # TODO: b/513156122 - Passing option objects directly to Context.__init__
      # is deprecated in favor of mutable dot-notation configuration (e.g.
      # ctx.array.saving...). Remove these keyword parameters.
      logging.warning(
          'Passing direct option objects to Context.__init__ is deprecated'
          ' in favor of mutable dot-notation configuration (e.g.'
          ' ctx.array.saving.use_ocdbt = ...). These keyword arguments will be'
          ' removed in a future release.'
      )

    def _inherited(name: str) -> typing.Any:
      # Value of the parent's attribute `name`, or None without a parent.
      return getattr(context, name) if context is not None else None

    self._pytree_options = _get_option(
        pytree_options,
        _inherited('pytree_options'),
        options_lib.PyTreeOptions,
    )
    self._array_options = _get_option(
        array_options,
        _inherited('array_options'),
        options_lib.ArrayOptions,
    )
    self._async_options = _get_option(
        async_options,
        _inherited('async_options'),
        options_lib.AsyncOptions,
    )
    self._multiprocessing_options = _get_option(
        multiprocessing_options,
        _inherited('multiprocessing_options'),
        options_lib.MultiprocessingOptions,
    )
    self._file_options = _get_option(
        file_options,
        _inherited('file_options'),
        options_lib.FileOptions,
    )
    self._checkpointables_options = _get_option(
        checkpointables_options,
        _inherited('checkpointables_options'),
        options_lib.CheckpointablesOptions,
    )
    self._pathways_options = _get_option(
        pathways_options,
        _inherited('pathways_options'),
        options_lib.PathwaysOptions,
    )
    self._checkpoint_layout = _get_option(
        checkpoint_layout,
        _inherited('checkpoint_layout'),
        lambda: options_lib.CheckpointLayout.ORBAX,
    )
    self._deletion_options = _get_option(
        deletion_options,
        _inherited('deletion_options'),
        options_lib.DeletionOptions,
    )
    self._memory_options = _get_option(
        memory_options,
        _inherited('memory_options'),
        options_lib.MemoryOptions,
    )
    self._safetensors_options = _get_option(
        safetensors_options,
        _inherited('safetensors_options'),
        options_lib.SafetensorsOptions,
    )

  def _check_not_frozen(self) -> None:
    """Raises RuntimeError if this context is currently registered as frozen.

    A context's id is added to `options_lib.FROZEN_IDS` for the duration of
    its `with` block (see `__contextmanager__`).
    """
    if id(self) in options_lib.FROZEN_IDS.get():
      raise RuntimeError(
          'Cannot mutate an active Context. Ensure all configuration options'
          ' are set before entering the `with` context block.'
      )

  @property
  def array(self) -> options_lib.ArrayOptions:
    return self._array_options

  @property
  def asynchronous(self) -> options_lib.AsyncOptions:
    return self._async_options

  @property
  def pytree(self) -> options_lib.PyTreeOptions:
    return self._pytree_options

  @property
  def file(self) -> options_lib.FileOptions:
    return self._file_options

  @property
  def multiprocessing(self) -> options_lib.MultiprocessingOptions:
    return self._multiprocessing_options

  @property
  def checkpointables(self) -> options_lib.CheckpointablesOptions:
    return self._checkpointables_options

  @property
  def pathways(self) -> options_lib.PathwaysOptions:
    return self._pathways_options

  @property
  def deletion(self) -> options_lib.DeletionOptions:
    return self._deletion_options

  @property
  def memory(self) -> options_lib.MemoryOptions:
    return self._memory_options

  @property
  def safetensors(self) -> options_lib.SafetensorsOptions:
    return self._safetensors_options

  @property
  def checkpoint_layout(self) -> options_lib.CheckpointLayout:
    return self._checkpoint_layout

  @checkpoint_layout.setter
  def checkpoint_layout(self, value: options_lib.CheckpointLayout) -> None:
    # Rejected with RuntimeError while this context is active.
    self._check_not_frozen()
    self._checkpoint_layout = value

  # TODO: b/513156122 - Migrate internal read sites to short-hand properties and
  # remove legacy aliases in the next refactor.
  # --- Legacy aliases for internal read access compatibility ---
  @property
  def pytree_options(self) -> options_lib.PyTreeOptions:
    return self._pytree_options

  @property
  def array_options(self) -> options_lib.ArrayOptions:
    return self._array_options

  @property
  def async_options(self) -> options_lib.AsyncOptions:
    return self._async_options

  @property
  def multiprocessing_options(self) -> options_lib.MultiprocessingOptions:
    return self._multiprocessing_options

  @property
  def file_options(self) -> options_lib.FileOptions:
    return self._file_options

  @property
  def checkpointables_options(self) -> options_lib.CheckpointablesOptions:
    return self._checkpointables_options

  @property
  def pathways_options(self) -> options_lib.PathwaysOptions:
    return self._pathways_options

  @property
  def deletion_options(self) -> options_lib.DeletionOptions:
    return self._deletion_options

  @property
  def memory_options(self) -> options_lib.MemoryOptions:
    return self._memory_options

  @property
  def safetensors_options(self) -> options_lib.SafetensorsOptions:
    return self._safetensors_options

  def __contextmanager__(self) -> Iterable[Context]:
    """Activates this context for the duration of the `with` block.

    On entry, the ids of this context and of its nested option dataclasses
    are added to `options_lib.FROZEN_IDS` and this context becomes the value
    of `_CONTEXT`. On exit both variables are restored to their previous
    values, even if the block raised.

    NOTE(review): presumably the option dataclasses consult `FROZEN_IDS` on
    attribute assignment; confirm in the options module.

    Yields:
      This context.
    """
    option_ids = _collect_ids(self)
    prev_frozen = options_lib.FROZEN_IDS.get()
    # Union with the previous set so enclosing active contexts stay frozen.
    guard_token = options_lib.FROZEN_IDS.set(prev_frozen | option_ids)
    token = _CONTEXT.set(self)
    try:
      yield self
    finally:
      _CONTEXT.reset(token)
      options_lib.FROZEN_IDS.reset(guard_token)
def _collect_ids(ctx: Context) -> frozenset[int]: """Collects all object ids from the context and its options. This function traverses the context object and all its attributes. This is used to freeze the context and all its options, so that they cannot be modified after the context is entered. Args: ctx: The context object to collect ids from. Returns: A frozenset of all object ids from the context and its options. """ ids = {id(ctx)} def _traverse(obj: typing.Any) -> None: if id(obj) in ids: return if dataclasses.is_dataclass(obj): ids.add(id(obj)) for field in dataclasses.fields(obj): _traverse(getattr(obj, field.name)) elif isinstance(obj, (list, tuple, set)): for item in obj: _traverse(item) elif isinstance(obj, dict): for value in obj.values(): _traverse(value) for obj in vars(ctx).values(): _traverse(obj) return frozenset(ids)