Source code for orbax.checkpoint.transform_utils

# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides utils for transforming PyTrees from one version to another."""

import dataclasses
import functools
import operator
import re
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

from absl import logging
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils

PyTree = Any
ValueFn = Callable[[Any], Any]
MultiValueFn = Callable[[str, PyTree], Any]
RestoreArgs = type_handlers.RestoreArgs


[docs]@dataclasses.dataclass class Transform: r"""A representation of a transformation applied to pytree keys/values. See `apply_transformations` for usage examples. Transform represents an operation on a single key/value pair. For example, the following mapping:: {'a': Transform(original_key='b')} This denotes that the original key was named 'b', but we are changing it to 'a'. A regex can also be used as follows:: {r'(.*)a(.*)': Transform(original_key=r'\1b\2'} This denotes that the key 'b' should be renamed to 'a'. This may apply to multiple different keys at different levels of nesting. The '/' character denotes a successive level of nesting. We also have the following example:: {'a': Transform(multi_value_fn=lambda kv: kv['b'] * 2)} This signifies that the new key 'a' is the old key 'b' multiplied by two. original_key: Denotes the original name of the key. Represented as a string with '/' denoting successive levels of nesting. If the key corresponding to this Transform is a regex, backreferences (such as \1) will be replaced with the appropriate matched group in the regex. Note: not needed if multi_value_fn is provided. use_fallback: If True, takes the value from the fallback tree. If `default_to_original=True` in `apply_transformations`, the fallback tree is `new_tree`. If `default_to_original=False` in `apply_transformations`, the fallback tree is `original_tree`. value_fn: A function accepting a single value and returning a single value. The value provided as an argument is the value of the transformation key in the original PyTree. multi_value_fn: A function accepting a string and PyTree and returning any value. The string is the result key associated with the returned value, so the function implementation can know for which key it is supposed to return a value for. The PyTree argument will be the original PyTree, and the function should return the value of the key in the new PyTree. multi_value_fn_input_args: A dict of key name (in the original tree) to required input arguments (typically `RestoreArgs` - see `PyTreeCheckpointHandler`). These arguments are not used directly in `apply_transformations`, but are necessary when applying transformations when restoring from a checkpoint in `PyTreeCheckpointHandler`. These arguments identify "dependencies" in the original tree (the checkpoint) which are needed as inputs by the function, and provides additional information needed for restoration. IMPORTANT: using multi_value_fn during `PyTreeCheckpointHandler.restore` REQUIRES inputs to be identified. """ original_key: Optional[Union[str, Tuple[str]]] = None use_fallback: bool = False value_fn: Optional[ValueFn] = None multi_value_fn: Optional[MultiValueFn] = None def __post_init__(self): if self.original_key is not None: assert not self.use_fallback assert self.multi_value_fn is None if self.use_fallback: assert self.original_key is None assert self.value_fn is None assert self.multi_value_fn is None if self.value_fn is not None: assert not self.use_fallback assert self.multi_value_fn is None if self.multi_value_fn is not None: assert self.original_key is None assert not self.use_fallback assert self.value_fn is None
@dataclasses.dataclass class RestoreTransform(Transform): """Transform subclass used only during restoration from checkpoint. value_fn: Same as value_fn in the parent class, but also accepts RestoreArgs as an argument. The returned value should take into account the information provided by RestoreArgs. multi_value_fn: Same as multi_value_fn in the parent class, but also accepts RestoreArgs as an argument. The returned value should take into account the information provided by RestoreArgs. multi_value_fn_input_args: A dict of key name (in the original tree) to required input arguments (typically `RestoreArgs` - see `PyTreeCheckpointHandler`). These arguments are not used directly in `apply_transformations`, but are necessary when applying transformations when restoring from a checkpoint in `PyTreeCheckpointHandler`. These arguments identify "dependencies" in the original tree (the checkpoint) which are needed as inputs by the function, and provides additional information needed for restoration. IMPORTANT: using multi_value_fn during `PyTreeCheckpointHandler.restore` REQUIRES inputs to be identified. """ value_fn: Optional[Callable[[Any, RestoreArgs], Any]] = None multi_value_fn: Optional[Callable[[str, PyTree, RestoreArgs], Any]] = None multi_value_fn_input_args: Optional[Dict[str, Any]] = None def __post_init__(self): super().__post_init__() if self.original_key is not None: assert self.multi_value_fn_input_args is None if self.use_fallback: assert self.multi_value_fn_input_args is None if self.value_fn is not None: assert self.multi_value_fn_input_args is None if self.multi_value_fn_input_args is not None: assert self.original_key is None assert not self.use_fallback assert self.value_fn is None assert self.multi_value_fn is not None # TODO(b/233407026) Add additional error checking.
[docs]def apply_transformations(original_tree: PyTree, transformations: PyTree, new_tree: PyTree, default_to_original: Optional[bool] = True) -> PyTree: r"""Applies transformations to a pytree. Also uses `transformations` to provide structure to the output tree. Example:: original_tree = { 'a': 1, 'b': {'c': 5, 'd': [0, 1, 2, 3]}, 'f': 2, 'b1': {'c': 2}, 'b2': {'c': 3}, } transformations = { 'a1': Transform(original_key='a'), # rename # another way of doing above 'a1': Transform(multi_value_fn=lambda kv: kv['a']), 'b': { # doubled original, and drop b/d 'c': Transform(multi_value_fn=lambda kv: kv['b']['c'] * 2) }, # Copy original into multiple new keys 'c1': Transform(original_key='b/c'), 'c2': Transform(original_key='b/c'), # one to many mapping 'x': Transform(multi_value_fn=lambda kv: kv['b']['d'][0]), 'y': Transform(multi_value_fn=lambda kv: kv['b']['d'][1:]), # many to one mapping 'z': Transform(multi_value_fn=lambda kv: kv['a'] * 2 + sum(kv['b']['d'])), r'x(\d.*)': Transform(original_key=r'b\1') } # defines the structure of the result new_tree = { 'a1': ..., 'a1': ..., 'b': {'c': ...}, 'c1': ..., 'c2': ..., 'x': ..., 'y': ..., 'z': ..., # defined in original_tree and new_tree, but not in transforms. Value # carried over from original_tree. 'f': ..., # This value matters since it is not present in original_tree or # transformations, so the value here will simply be preserved in the # result. 'g': 5, # These are just 'b1', 'b2', but renamed to 'x1', 'x2', with all values # copied over. 'x1': {'c': 2} 'x2': {'c': 3} } Args: original_tree: a PyTree to be transformed. transformations: a PyTree of Transform objects. new_tree: a PyTree defining the structure of the output. A leaf value is only relevant if the key is not present in transformations or original_tree. Note: values in the provided tree must not be None, or they will be filtered out. default_to_original: If True, the values of keys unspecified in transformations will be taken from `original_tree`. If False, they will be taken from `new_tree`. Returns: a transformed PyTree with the structure of `new_tree` """ logging.warning( 'The transformations API will eventually be replaced by an upgraded' ' design. The current API will not be removed until this point, but it' ' will no longer be actively worked on.', ) if not new_tree: return {} original = utils.to_flat_dict(original_tree, sep='/') new = utils.to_flat_dict(new_tree, sep='/') transforms = utils.to_flat_dict(transformations, sep='/') unmatched_new_keys = [] for key in new: transform_found = False for transform_key, transform in transforms.items(): match = re.fullmatch(transform_key, key) if match: transform_found = True if transform.use_fallback: if not default_to_original: if key not in original: raise ValueError( f'{key} not found in origin tree (`use_fallback` requested).') new[key] = original[key] # else simply retain new[key] continue if not (transform.multi_value_fn is None or transform.value_fn is None): raise ValueError( f'Cannot provide both multi_value_fn and value_fn in {transform}') if transform.multi_value_fn is None: if transform.original_key is None: original_key = key else: original_key = match.expand(transform.original_key) if original_key not in original: raise ValueError( f'Transformation key "{original_key}" not found in origin tree.' ) if transform.value_fn is None: value_fn = lambda x: x else: value_fn = transform.value_fn new[key] = value_fn(original[original_key]) else: new[key] = transform.multi_value_fn(key, original_tree) if not transform_found: if key in original: if default_to_original: # carry over directly from original, otherwise use value from new new[key] = original[key] # if default_to_new, do not carry over key from original else: unmatched_new_keys.append(key) if unmatched_new_keys: logging.info('The following keys are not loaded from the original tree ' 'after applying specified transforms: %s', ', '.join(unmatched_new_keys)) return utils.from_flat_dict(new, target=new_tree, sep='/')
[docs]def merge_trees( *trees: Sequence[PyTree], target: Optional[PyTree] = None ) -> PyTree: """Merges the provided PyTrees into a single result. If trees have overlapping keys, the key of the last tree in the list will take precedence. Args: *trees: PyTrees to merge. target: A PyTree to provide structure for the returned value. If not provided, the result will take the form of a dictionary. Returns: A single merged PyTree. """ trees = [utils.to_flat_dict(t) for t in trees] merged = functools.reduce(operator.ior, trees, {}) return utils.from_flat_dict(merged, target=target)
def intersect_trees( *trees: Sequence[PyTree], target: Optional[PyTree] = None ) -> PyTree: """Intersects the provided trees, dropping any keys not in common between all. For overlapping keys, the key of the last tree in the list will take precedence. Args: *trees: PyTrees to intersect. target: A PyTree to provide structure for the returned value. If not provided, the result will take the form of a dictionary. Returns: A single intersected PyTree. """ trees = [utils.to_flat_dict(t) for t in trees] tree_keys = set.intersection(*[set(t.keys()) for t in trees]) return utils.from_flat_dict( {k: trees[-1][k] for k in tree_keys}, target=target )