Skip to content

Commit

Permalink
Support custom PyTree metadata.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704424560
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Jan 3, 2025
1 parent 605185a commit 86857c0
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 15 deletions.
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
make the change unnoticeable to most users, but also has additional accessible
properties not included in any tree mapping operations.

### Added
- User-provided custom PyTree metadata.


## [0.11.0] - 2024-12-30

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -436,6 +437,7 @@ async def async_save(
raise ValueError('Found empty item.')
save_args = args.save_args
ocdbt_target_data_file_size = args.ocdbt_target_data_file_size
custom_metadata = args.custom_metadata

save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save')
byte_limiter = serialization.get_byte_limiter(self._save_concurrent_bytes)
Expand Down Expand Up @@ -476,7 +478,11 @@ async def async_save(
if multihost.is_primary_host(self._primary_host):
commit_futures.append(
self._write_metadata_file(
directory, param_infos, save_args, self._use_zarr3
directory,
param_infos=param_infos,
save_args=save_args,
custom_metadata=custom_metadata,
use_zarr3=self._use_zarr3,
)
)

Expand Down Expand Up @@ -728,8 +734,10 @@ class TrainState:
def _write_metadata_file(
self,
directory: epath.Path,
*,
param_infos: PyTree,
save_args: PyTree,
custom_metadata: tree_types.JsonType | None = None,
use_zarr3: bool = False,
) -> future.Future:
def _save_fn():
Expand All @@ -740,6 +748,7 @@ def _save_fn():
param_infos,
save_args=save_args,
use_zarr3=use_zarr3,
custom=custom_metadata,
pytree_metadata_options=self._pytree_metadata_options,
)
logging.vlog(
Expand Down Expand Up @@ -816,12 +825,14 @@ def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata:
tree containing metadata.
"""
is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory)
internal_tree_metadata = self._read_metadata_file(directory)
return tree_metadata.build_default_tree_metadata(
self._read_metadata_file(directory).as_user_metadata(
internal_tree_metadata.as_user_metadata(
directory,
self._type_handler_registry,
use_ocdbt=is_ocdbt_checkpoint,
),
custom=internal_tree_metadata.custom,
)

def finalize(self, directory: epath.Path) -> None:
Expand Down Expand Up @@ -873,12 +884,19 @@ class BasePyTreeSaveArgs(CheckpointArgs):
enable_pinned_host_transfer: True by default. If False, disables transfer to
pinned host when copying from device to host, regardless of the presence
of pinned host memory.
custom_metadata: A JSON-serializable object (typically just a nested
dictionary containing string keys and basic type values) that stores user-
specified metadata. This metadata is stored along with the Orbax-internal
PyTree metadata. This can be used to supplement information about the
PyTree checkpoint with information about e.g. the model used to generate
the checkpoint.
"""

item: PyTree
save_args: Optional[PyTree] = None
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True
custom_metadata: tree_types.JsonType | None = None


@register_with_handler(BasePyTreeCheckpointHandler, for_restore=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from orbax.checkpoint._src.serialization import serialization
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -428,6 +429,7 @@ def _get_impl_save_args(
save_args=args.save_args,
ocdbt_target_data_file_size=args.ocdbt_target_data_file_size,
enable_pinned_host_transfer=args.enable_pinned_host_transfer,
custom_metadata=args.custom_metadata,
)


Expand Down Expand Up @@ -1046,12 +1048,19 @@ class PyTreeSaveArgs(CheckpointArgs):
enable_pinned_host_transfer: True by default. If False, disables transfer to
pinned host when copying from device to host, regardless of the presence
of pinned host memory.
custom_metadata: A JSON-serializable object (typically just a nested
dictionary containing string keys and basic type values) that stores user-
specified metadata. This metadata is stored along with the Orbax-internal
PyTree metadata. This can be used to supplement information about the
PyTree checkpoint with information about e.g. the model used to generate
the checkpoint.
"""

item: PyTree
save_args: Optional[PyTree] = None
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True
custom_metadata: tree_types.JsonType | None = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils


Expand Down Expand Up @@ -144,15 +145,19 @@ async def async_save(
'Make sure to specify kwarg name `args=` when providing'
' `StandardSaveArgs`.'
)
custom_metadata = None
if args is not None:
item = args.item
save_args = args.save_args
custom_metadata = args.custom_metadata

self._validate_save_state(item, save_args=save_args)
return await self._impl.async_save(
directory,
args=pytree_checkpoint_handler.PyTreeSaveArgs(
item=item, save_args=save_args
item=item,
save_args=save_args,
custom_metadata=custom_metadata,
),
)

Expand Down Expand Up @@ -266,10 +271,17 @@ class StandardSaveArgs(CheckpointArgs):
save_args: a PyTree with the same structure of `item`, which consists of
`ocp.SaveArgs` objects as values. `None` can be used for values where no
`SaveArgs` are specified.
custom_metadata: A JSON-serializable object (typically just a nested
dictionary containing string keys and basic type values) that stores user-
specified metadata. This metadata is stored along with the Orbax-internal
PyTree metadata. This can be used to supplement information about the
PyTree checkpoint with information about e.g. the model used to generate
the checkpoint.
"""

item: PyTree
save_args: Optional[PyTree] = None
custom_metadata: tree_types.JsonType | None = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Test for standard_checkpoint_handler.py."""

# pylint: disable=protected-access, missing-function-docstring

import functools
from typing import Any

Expand Down Expand Up @@ -127,7 +129,6 @@ def test_basic_no_item_arg(self):
test_utils.assert_tree_equal(self, self.pytree, restored)

def test_shape_dtype_struct(self):
"""Test case."""
self.handler.save(
self.directory, args=self.save_args_cls(self.mixed_pytree)
)
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_custom_layout(self):
custom_layout = Layout(
device_local_layout=DLL(
major_to_minor=arr.layout.device_local_layout.major_to_minor[::-1], # pytype: disable=attribute-error
_tiling=arr.layout.device_local_layout._tiling, # pylint: disable=protected-access # pytype: disable=attribute-error
_tiling=arr.layout.device_local_layout._tiling, # pytype: disable=attribute-error
),
sharding=arr.sharding,
)
Expand Down Expand Up @@ -210,7 +211,6 @@ def test_custom_layout(self):

@parameterized.parameters((True,), (False,))
def test_change_shape(self, strict: bool):
"""Test case."""
if not hasattr(self.restore_args_cls, 'strict'):
self.skipTest('strict option not supported for this handler')
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',))
Expand Down Expand Up @@ -255,7 +255,6 @@ def test_restore_unsupported_type(self):
self.handler.restore(self.directory, args=self.restore_args_cls(pytree))

def test_cast(self):
"""Test case."""
# TODO(dicentra): casting from int dtypes currently doesn't work
# in the model surgery context.
save_args = jax.tree.map(
Expand Down Expand Up @@ -289,7 +288,6 @@ def check_dtype(x, dtype):
jax.tree.map(lambda x: check_dtype(x, jnp.bfloat16), restored)

def test_flax_model(self):
"""Test case."""

@flax.struct.dataclass
class Params(flax.struct.PyTreeNode):
Expand Down Expand Up @@ -318,12 +316,10 @@ def make_params():
test_utils.assert_tree_equal(self, params, restored)

def test_empty_error(self):
"""Test case."""
with self.assertRaises(ValueError):
self.handler.save(self.directory, args=self.save_args_cls({}))

def test_empty_dict_node(self):
"""Test case."""
item = {'a': {}, 'b': 3}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand All @@ -332,7 +328,6 @@ def test_empty_dict_node(self):
self.assertDictEqual(restored, item)

def test_empty_none_node(self):
"""Test case."""
item = {'c': None, 'd': 2}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand All @@ -341,7 +336,6 @@ def test_empty_none_node(self):
self.assertDictEqual(restored, item)

def test_none_node_in_restore_args(self):
"""Test case."""
devices = np.asarray(jax.devices())
mesh = jax.sharding.Mesh(devices, ('x',))
mesh_axes = jax.sharding.PartitionSpec(
Expand All @@ -358,7 +352,6 @@ def test_none_node_in_restore_args(self):
test_utils.assert_tree_equal(self, restored, {'b': None})

def test_masked_shape_dtype_struct(self):
"""Test case."""

def _should_mask(keypath):
return keypath[0].key == 'a' or (
Expand Down Expand Up @@ -398,3 +391,14 @@ def _none(keypath, x):
# Restore it without any item.
restored = self.handler.restore(self.directory)
test_utils.assert_tree_equal(self, expected, restored)

def test_custom_metadata(self):
custom_metadata = {'foo': 1}
self.handler.save(
self.directory,
args=self.save_args_cls(
self.pytree, custom_metadata=custom_metadata
),
)
metadata = self.handler.metadata(self.directory)
self.assertEqual(metadata.custom, custom_metadata)
20 changes: 19 additions & 1 deletion checkpoint/orbax/checkpoint/_src/metadata/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import enum
import functools
import inspect
import json
import operator
import typing
from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple, TypeAlias, TypeVar, Union
Expand All @@ -37,6 +38,7 @@
from orbax.checkpoint._src.metadata import value_metadata_entry
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils


Expand All @@ -58,6 +60,7 @@
_USE_ZARR3 = 'use_zarr3'
_STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE = 'store_array_data_equal_to_fill_value'
_VALUE_METADATA_TREE = 'value_metadata_tree'
_CUSTOM_FIELD = 'custom'


class KeyType(enum.Enum):
Expand Down Expand Up @@ -200,12 +203,13 @@ def jax_keypath(self) -> KeyPath:
return tuple(keypath)


@dataclasses.dataclass
@dataclasses.dataclass(kw_only=True)
class InternalTreeMetadata:
"""Metadata representation of a PyTree."""

tree_metadata_entries: List[InternalTreeMetadataEntry]
use_zarr3: bool
custom: tree_types.JsonType | None
store_array_data_equal_to_fill_value: bool
pytree_metadata_options: PyTreeMetadataOptions
value_metadata_tree: PyTree | None = None
Expand All @@ -219,6 +223,14 @@ def __post_init__(self):
len(self.tree_metadata_entries),
self.value_metadata_tree is not None,
)
# Validate JSON-serializability of custom metadata.
try:
json.dumps(self.custom)
except TypeError as e:
raise TypeError(
'Failed to encode `custom` metadata as JSON object. Please ensure'
' your custom metadata is JSON-serializable.'
) from e

@classmethod
def build(
Expand All @@ -227,6 +239,7 @@ def build(
*,
save_args: Optional[PyTree] = None,
use_zarr3: bool = False,
custom: tree_types.JsonType | None = None,
pytree_metadata_options: PyTreeMetadataOptions = (
PYTREE_METADATA_OPTIONS
),
Expand Down Expand Up @@ -267,6 +280,7 @@ def build(
return InternalTreeMetadata(
tree_metadata_entries=tree_metadata_entries,
use_zarr3=use_zarr3,
custom=custom,
store_array_data_equal_to_fill_value=ts_utils.STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE,
pytree_metadata_options=pytree_metadata_options,
value_metadata_tree=value_metadata_tree,
Expand Down Expand Up @@ -295,6 +309,7 @@ def to_json(self) -> Dict[str, Any]:
},
_USE_ZARR3: True/False,
_STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE: True,
_CUSTOM_FIELD: ...,
_VALUE_METADATA_TREE: '{
"mu_nu": {
"category": "namedtuple",
Expand Down Expand Up @@ -353,6 +368,7 @@ def to_json(self) -> Dict[str, Any]:
_STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE: (
self.store_array_data_equal_to_fill_value
),
_CUSTOM_FIELD: self.custom,
}
# TODO: b/365169723 - Support versioned evolution of metadata storage.
if (
Expand All @@ -379,6 +395,7 @@ def from_json(
) -> InternalTreeMetadata:
"""Returns an InternalTreeMetadata instance from its JSON representation."""
use_zarr3 = json_dict.get(_USE_ZARR3, False)
custom = json_dict.get(_CUSTOM_FIELD, None)
store_array_data_equal_to_fill_value = json_dict.get(
_STORE_ARRAY_DATA_EQUAL_TO_FILL_VALUE, False
)
Expand All @@ -405,6 +422,7 @@ def from_json(
return InternalTreeMetadata(
tree_metadata_entries=tree_metadata_entries,
use_zarr3=use_zarr3,
custom=custom,
pytree_metadata_options=pytree_metadata_options,
value_metadata_tree=value_metadata_tree,
store_array_data_equal_to_fill_value=store_array_data_equal_to_fill_value,
Expand Down
Loading

0 comments on commit 86857c0

Please sign in to comment.