Customizing Checkpointing Behavior#
Orbax allows users to specify their own logic for dealing with custom “Checkpointables”.
Custom Checkpointables#
First, ensure that you are familiar with the documentation on checkpointables. To recap, a “checkpointable” is a distinct unit of an entire checkpoint. For example, the model state is a checkpointable distinct from the dataset iterator. Embeddings, if used, may also be represented as a separate checkpointable.
Let us consider a toy example. Let’s say that in addition to our PyTree state
(represented as a dictionary of arrays, containing the parameters and optimizer
state) and our dataset iterator (represented using PyGrain), we also have an
object called Point, which has integer properties x and y. (Obviously,
since this object is a dataclass, it would be easy to just convert this to a
PyTree, and save it in the same way as the primary model state. So this example
is a bit contrived, but demonstrates the point well enough.)
Our Point class is defined as follows.
import dataclasses
import json
from typing import Any, Awaitable

import aiofiles
import jax
import numpy as np
import orbax.checkpoint.experimental.v1 as ocp


@dataclasses.dataclass
class Point:
    x: int
    y: int


model_state = {
    'params': np.arange(16),
    'opt_state': np.ones(16),
}
If we just try to save the Point (along with our other checkpointables), it
will fail because the object type is not recognized.
try:
    ocp.save_checkpointables(
        '/tmp/customization/ckpt1',
        dict(model_state=model_state, point=Point(1, 2)),
    )
except BaseException as e:
    print(e)
'Could not identify a valid handler for the checkpointable: "point" and checkpointable type=<class \'__main__.Point\'>. Make sure to register a `CheckpointableHandler` for the object using `register_handler`, or by specifying a local registry (`CheckpointablesOptions`). If a handler is already registered, ensure that `is_handleable` correctly identifies the object as handleable. The available handlers are: [<class \'orbax.checkpoint.experimental.v1._src.handlers.proto_handler.ProtoHandler\'>, <class \'orbax.checkpoint.experimental.v1._src.handlers.json_handler.JsonHandler\'>, <class \'orbax.checkpoint.experimental.v1._src.handlers.stateful_checkpointable_handler.StatefulCheckpointableHandler\'>, <class \'orbax.checkpoint.experimental.v1._src.handlers.json_handler.MetricsHandler\'>, <class \'orbax.checkpoint.experimental.v1._src.handlers.leaf_handler.ShardedArrayHandler\'>, <class \'orbax.checkpoint.experimental.v1._src.handlers.leaf_handler.ArrayHandler\'>, <class \'orbax.checkpoint.experimental.v1._src.handlers.leaf_handler.ScalarHandler\'>, <class \'orbax.checkpoint.experimental.v1._src.handlers.leaf_handler.StringHandler\'>, <class \'orbax.checkpoint.experimental.v1._src.handlers.pytree_handler.PyTreeHandler\'>, <class \'orbax.checkpoint.experimental.v1._src.handlers.pytree_handler.PyTreeHandler\'>]'
There are two possible approaches for implementing support for Point in Orbax.
We will start with the simpler of the two.
Implementing Point as a StatefulCheckpointable#
The Point object must implement the methods of the StatefulCheckpointable
Protocol. We need to implement save and load methods so that Orbax will know
how to deal with the Point object.
from __future__ import annotations

del Point


@dataclasses.dataclass
class Point(ocp.StatefulCheckpointable):
    x: int
    y: int

    async def save(
        self, directory: ocp.path.PathAwaitingCreation
    ) -> Awaitable[None]:
        return self._background_save(
            directory,
            # If the object could be modified by the main thread while being
            # written, it is important to make a copy to prevent race conditions.
            dataclasses.asdict(self),
        )

    async def load(self, directory: ocp.path.Path) -> Awaitable[None]:
        return self._background_load(directory)

    async def _background_save(
        self,
        directory: ocp.path.PathAwaitingCreation,
        value: dict[str, int],
    ):
        # In a multiprocess setting, prevent multiple processes from writing the
        # same thing.
        if jax.process_index() == 0:
            directory = await directory.await_creation()
            async with aiofiles.open(directory / 'point.txt', 'w') as f:
                contents = json.dumps(value)
                await f.write(contents)

    async def _background_load(
        self,
        directory: ocp.path.Path,
    ):
        async with aiofiles.open(directory / 'point.txt', 'r') as f:
            contents = json.loads(await f.read())
            self.x = contents['x']
            self.y = contents['y']
Let’s break this down.
Both the save and load methods consist of two phases: blocking and non-blocking.
Blocking operations must complete before control returns to the caller.
Non-blocking operations may run in a background thread, and are represented by
an Awaitable that is returned to the caller without having been executed yet.
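To make the two phases concrete, here is a minimal sketch that uses only the standard library. The Counter class and the demo function are hypothetical names invented for this illustration and are not part of Orbax; plain file I/O stands in for aiofiles. The blocking phase takes a snapshot, and the non-blocking phase is the coroutine handed back to the caller.

```python
import asyncio
import copy
import json
import pathlib
import tempfile


class Counter:
    """Toy object with mutable state (hypothetical, for illustration only)."""

    def __init__(self, values):
        self.values = values

    async def save(self, directory):
        # Blocking phase: snapshot the state before returning to the caller.
        snapshot = copy.deepcopy(self.values)
        # Non-blocking phase: return the coroutine without awaiting it.
        return self._background_save(directory, snapshot)

    async def _background_save(self, directory, snapshot):
        # Plain synchronous file I/O keeps the sketch dependency-free.
        (directory / "counter.json").write_text(json.dumps(snapshot))


async def demo(directory):
    counter = Counter([1, 2, 3])
    background = await counter.save(directory)  # blocking phase runs here
    counter.values.append(4)  # the caller mutates state after save() returned
    await background  # non-blocking phase runs here
    return json.loads((directory / "counter.json").read_text())


with tempfile.TemporaryDirectory() as tmp:
    saved = asyncio.run(demo(pathlib.Path(tmp)))

print(saved)  # [1, 2, 3]
```

Because the snapshot is taken during the blocking phase, the later append does not leak into the written file.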
When saving, in the case of Point, we make a copy of the properties to prevent
them from being concurrently modified by the main thread while we are writing
them in the background thread. For a jax.Array, we would similarly need to
perform a transfer from device memory to host memory. When the blocking
operations complete, we can construct an awaitable function that writes the
values to a file. Note also that we must wait for the parent directory to be
created, since upper layers of Orbax have already scheduled its creation
asynchronously.
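The directory-waiting step can be pictured with a small standard-library sketch. DirectoryAwaitingCreation below is a hypothetical stand-in written for this example; it only mimics the idea behind ocp.path.PathAwaitingCreation and makes no claim about how Orbax implements it.

```python
import asyncio
import pathlib
import tempfile


class DirectoryAwaitingCreation:
    """Hypothetical stand-in for a path whose creation is still in flight."""

    def __init__(self, path):
        self._path = path
        # Creation is scheduled immediately and completes in the background.
        # This requires a running event loop, so construct it inside a coroutine.
        self._task = asyncio.create_task(self._create())

    async def _create(self):
        await asyncio.sleep(0.01)  # simulate slow storage
        self._path.mkdir(parents=True, exist_ok=True)
        return self._path

    async def await_creation(self):
        return await self._task


async def write_when_ready(directory, text):
    # Writing before this await could fail: the directory may not exist yet.
    path = await directory.await_creation()
    target = path / "point.txt"
    target.write_text(text)
    return target


async def demo(root):
    pending = DirectoryAwaitingCreation(root / "ckpt" / "point")
    target = await write_when_ready(pending, '{"x": 1, "y": 2}')
    return target.read_text()


with tempfile.TemporaryDirectory() as tmp:
    contents = asyncio.run(demo(pathlib.Path(tmp)))
```

The writer never creates the directory itself; it only waits for the creation that was scheduled elsewhere.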
Loading is similar. Typically there are fewer operations that need to happen
synchronously, as the caller should know they cannot do anything with the object
until it is fully loaded. Again, the awaitable function that is run in the
background should return nothing, and instead set relevant properties in self
after loading from disk.
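The in-place loading pattern can be sketched the same way. Settings and demo are hypothetical names used only for this illustration, and synchronous file reads replace aiofiles to keep the example self-contained.

```python
import asyncio
import dataclasses
import json
import pathlib
import tempfile


@dataclasses.dataclass
class Settings:
    """Toy stateful object (hypothetical) that is restored in place."""

    step: int = 0
    name: str = ""

    async def load(self, directory):
        # No blocking work is needed here; hand back the background coroutine.
        return self._background_load(directory)

    async def _background_load(self, directory):
        contents = json.loads((directory / "settings.json").read_text())
        # Update self instead of returning a new object.
        self.step = contents["step"]
        self.name = contents["name"]


async def demo(directory):
    payload = json.dumps({"step": 7, "name": "run"})
    (directory / "settings.json").write_text(payload)
    target = Settings()
    background = await target.load(directory)
    before = dataclasses.replace(target)  # copy taken before the load runs
    result = await background
    return before, target, result


with tempfile.TemporaryDirectory() as tmp:
    before, after, result = asyncio.run(demo(pathlib.Path(tmp)))
```

Here before still holds the uninitialized defaults, after holds the restored values, and result is None because the background coroutine communicates only through self.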
Now we can successfully save the Point.
ocp.save_checkpointables(
    '/tmp/customization/ckpt1',
    dict(model_state=model_state, point=Point(1, 2)),
)
It is important to note that because Point is a stateful checkpointable, we
have to provide a Point object in order to restore it. In typical usage, we
should construct a Point object with “uninitialized” values. Calling
load_checkpointables then updates the provided object as a side effect (it
also returns it).
uninitialized_point = Point(0, 0)
ocp.load_checkpointables(
    '/tmp/customization/ckpt1',
    dict(point=uninitialized_point),
)
uninitialized_point
Point(x=1, y=2)
Supporting Point with CheckpointableHandler#
While StatefulCheckpointable has a simple and powerful interface, it may not
be the right fit in every case. StatefulCheckpointable may be insufficient in
cases such as:
- Point may be defined in some third-party library that we cannot easily control, and thus we could not directly add save and load methods to the class itself.
- When loading, users might need to customize loading behavior in a more dynamic way. For a jax.Array, resharding, casting, and reshaping are common operations. For a Point, users might want to cast x and y between int and float more dynamically.
- We may have multiple different ways to save and load Point that users want to enable in different contexts. In such cases, placing all that different logic within the single Point class may add too much complexity.
For such cases (and others), Orbax provides an interface called
CheckpointableHandler.
First, let’s redefine our Point class and also introduce an AbstractPoint
class. This allows us to specify the type of x or y that should be used for
loading.
del Point

import asyncio
import builtins
from typing import Type

Scalar = int | float


@dataclasses.dataclass
class Point:
    x: Scalar
    y: Scalar


@dataclasses.dataclass
class AbstractPoint:
    x: Type[Scalar]
    y: Type[Scalar]


async def _write_point(
    directory: ocp.path.Path, checkpointable: dict[str, Scalar]
):
    async with aiofiles.open(directory / 'point.txt', 'w') as f:
        contents = json.dumps(checkpointable)
        await f.write(contents)


async def _write_point_metadata(
    directory: ocp.path.Path, checkpointable: dict[str, Scalar]
):
    # Record only the type name of each field, e.g. {"x": "int", "y": "float"}.
    async with aiofiles.open(directory / 'point_metadata.txt', 'w') as f:
        contents = json.dumps(
            {k: type(v).__name__ for k, v in checkpointable.items()}
        )
        await f.write(contents)


class PointHandler(ocp.CheckpointableHandler[Point, AbstractPoint]):

    async def _background_save(
        self,
        directory: ocp.path.PathAwaitingCreation,
        checkpointable: dict[str, Scalar],
    ):
        # In a multiprocess setting, only one process writes the files.
        if jax.process_index() == 0:
            directory = await directory.await_creation()
            await asyncio.gather(
                _write_point(directory, checkpointable),
                _write_point_metadata(directory, checkpointable),
            )

    async def _background_load(
        self,
        directory: ocp.path.Path,
        abstract_checkpointable: AbstractPoint | None = None,
    ) -> Point:
        async with aiofiles.open(directory / 'point.txt', 'r') as f:
            contents = json.loads(await f.read())
        if abstract_checkpointable is None:
            # Restore exactly what was saved.
            return Point(**contents)
        else:
            # Cast each field to the type requested by the caller.
            return Point(
                abstract_checkpointable.x(contents['x']),
                abstract_checkpointable.y(contents['y']),
            )

    async def save(
        self,
        directory: ocp.path.PathAwaitingCreation,
        checkpointable: Point,
    ) -> Awaitable[None]:
        return self._background_save(
            directory, dataclasses.asdict(checkpointable)
        )

    async def load(
        self,
        directory: ocp.path.Path,
        abstract_checkpointable: AbstractPoint | None = None,
    ) -> Awaitable[Point]:
        return self._background_load(directory, abstract_checkpointable)

    async def metadata(self, directory: ocp.path.Path) -> AbstractPoint:
        async with aiofiles.open(directory / 'point_metadata.txt', 'r') as f:
            contents = json.loads(await f.read())
        # Look names up in the builtins module: __builtins__ is only guaranteed
        # to be a module inside __main__ and may be a dict elsewhere.
        return AbstractPoint(
            **{k: getattr(builtins, v) for k, v in contents.items()}
        )

    def is_handleable(self, checkpointable: Any) -> bool:
        return isinstance(checkpointable, Point)

    def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool:
        return isinstance(abstract_checkpointable, AbstractPoint)
This class associates itself with two types, the Checkpointable and the
AbstractCheckpointable (Point and AbstractPoint in this case). Point is
the input for saving, and AbstractPoint (or None) is the input for loading.
In both cases, the parent directory is also passed in.
Saving logic in this class is essentially the same as in our
StatefulCheckpointable definition above.
Loading is different because loading is no longer stateful - it instead accepts
an optional AbstractPoint and returns a newly constructed Point. Providing
None as the input indicates that the object should simply be restored exactly
as it was saved. (Note that for some objects, this may not be possible, and it
may be necessary to raise an error if some input from the user is required to
know how to load.) Otherwise, the provided AbstractCheckpointable serves as
the guide describing how the concrete loaded object (Point in this case)
should be constructed.
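The role of the abstract object during loading can be shown without any I/O. The sketch below reuses the Point and AbstractPoint shapes from above, but restore_point is a hypothetical helper written for this example, not an Orbax function.

```python
import dataclasses
import json
from typing import Optional, Type, Union

Scalar = Union[int, float]


@dataclasses.dataclass
class Point:
    x: Scalar
    y: Scalar


@dataclasses.dataclass
class AbstractPoint:
    x: Type[Scalar]
    y: Type[Scalar]


def restore_point(serialized: str, abstract: Optional[AbstractPoint] = None) -> Point:
    """Rebuilds a Point, optionally casting each field as requested."""
    contents = json.loads(serialized)
    if abstract is None:
        # No guidance given: restore exactly what was saved.
        return Point(**contents)
    # The abstract object holds one callable type per field.
    return Point(x=abstract.x(contents["x"]), y=abstract.y(contents["y"]))


serialized = json.dumps({"x": 1, "y": 2.4})
as_saved = restore_point(serialized)
cast = restore_point(serialized, AbstractPoint(x=float, y=int))

print(as_saved)  # Point(x=1, y=2.4)
print(cast)      # Point(x=1.0, y=2)
```

Note that casting 2.4 to int truncates it to 2, so a caller who supplies an abstract object accepts whatever conversion the requested type performs.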
We also have the capability of defining a metadata method in this class. In
the case of Point, the object is obviously quite lightweight already. For real
use cases, the checkpoint itself may be expensive to load fully, and some
metadata describing important properties that can be loaded cheaply is
essential. The metadata method should return an instance of
AbstractCheckpointable.
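As a rough illustration of cheap metadata, the sketch below stores only the type name of each field and later turns those names back into types. The helpers describe and resolve are hypothetical and written for this example; the lookup goes through the standard builtins module, so it only works for built-in types such as int and float.

```python
import builtins
import json


def describe(values: dict) -> str:
    """Serializes the type name of each value, e.g. {"x": "int"}."""
    return json.dumps({k: type(v).__name__ for k, v in values.items()})


def resolve(serialized: str) -> dict:
    """Maps stored type names back to the built-in types they name."""
    names = json.loads(serialized)
    return {k: getattr(builtins, v) for k, v in names.items()}


metadata_text = describe({"x": 1, "y": 2.4})
types = resolve(metadata_text)

print(metadata_text)  # {"x": "int", "y": "float"}
print(types)          # {'x': <class 'int'>, 'y': <class 'float'>}
```

The metadata file stays tiny regardless of how large the saved values are, which is the property that matters for expensive checkpoints.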
Finally, two additional methods, is_handleable and is_abstract_handleable,
should be defined. These methods accept any object, and decide whether the given
object is an acceptable input for saving or loading, respectively. In most
cases, a simple isinstance check will suffice, but for more generic
constructs, like PyTrees, more involved logic is necessary.
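The following sketch contrasts a simple isinstance check with a more involved structural check. It is an assumption-laden illustration: StringTreeHandler and resolve_handler are hypothetical, and picking the first matching handler is a simplification chosen for this example rather than a description of how Orbax resolves handlers.

```python
import dataclasses
from collections.abc import Mapping
from typing import Any


@dataclasses.dataclass
class Point:
    x: int
    y: int


class PointHandler:
    def is_handleable(self, checkpointable: Any) -> bool:
        # A single concrete type: isinstance is enough.
        return isinstance(checkpointable, Point)


class StringTreeHandler:
    """Accepts (possibly nested) mappings whose leaves are all strings."""

    def is_handleable(self, checkpointable: Any) -> bool:
        # A generic container: every leaf has to be inspected.
        return isinstance(checkpointable, Mapping) and all(
            isinstance(v, str) or self.is_handleable(v)
            for v in checkpointable.values()
        )


def resolve_handler(handlers, checkpointable):
    """Returns the first handler that accepts the object (assumed policy)."""
    for handler in handlers:
        if handler.is_handleable(checkpointable):
            return handler
    raise ValueError(f"No handler for type {type(checkpointable).__name__}")


HANDLERS = [PointHandler(), StringTreeHandler()]
point_handler = resolve_handler(HANDLERS, Point(1, 2))
tree_handler = resolve_handler(HANDLERS, {"a": {"b": "c"}})
```

An object such as {"a": 1} matches neither handler, so resolve_handler raises a ValueError, which is analogous in spirit to the error shown at the start of this page.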
We can now register PointHandler in order to deal with Point objects.
ocp.handlers.register_handler(PointHandler)
__main__.PointHandler
ocp.save_checkpointables(
    '/tmp/customization/ckpt2',
    dict(model_state=model_state, point=Point(1, 2.4)),
)
Since the AbstractPoint is optional, we do not need to specify any arguments
to load everything successfully.
ocp.load_checkpointables('/tmp/customization/ckpt2')
{'point': Point(x=1, y=2.4),
'model_state': {'opt_state': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
'params': array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])}}
However, if desired, we can specify an abstract checkpointable to customize the types of the restored values.
ocp.load_checkpointables(
    '/tmp/customization/ckpt2', dict(point=AbstractPoint(x=float, y=int))
)
{'point': Point(x=1.0, y=2)}
We can use checkpointables_metadata to load the metadata, in the form of an
AbstractPoint.
ocp.checkpointables_metadata('/tmp/customization/ckpt2').metadata['point']
AbstractPoint(x=<class 'int'>, y=<class 'float'>)