Skip to content

Commit

Permalink
Support custom PyTree metadata.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704424560
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Jan 8, 2025
1 parent c3a936a commit 9e58750
Show file tree
Hide file tree
Showing 15 changed files with 207 additions and 103 deletions.
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ properties not included in any tree mapping operations.

### Added
- The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`.
- User-provided custom PyTree metadata.

### Fixed
- Fix a bug where snapshots are not released by `wait_for_new_checkpoint`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -436,6 +437,7 @@ async def async_save(
raise ValueError('Found empty item.')
save_args = args.save_args
ocdbt_target_data_file_size = args.ocdbt_target_data_file_size
user_metadata = args.user_metadata

save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save')
byte_limiter = serialization.get_byte_limiter(self._save_concurrent_bytes)
Expand Down Expand Up @@ -476,7 +478,11 @@ async def async_save(
if multihost.is_primary_host(self._primary_host):
commit_futures.append(
self._write_metadata_file(
directory, param_infos, save_args, self._use_zarr3
directory,
param_infos=param_infos,
save_args=save_args,
user_metadata=user_metadata,
use_zarr3=self._use_zarr3,
)
)

Expand Down Expand Up @@ -728,8 +734,10 @@ class TrainState:
def _write_metadata_file(
self,
directory: epath.Path,
*,
param_infos: PyTree,
save_args: PyTree,
user_metadata: tree_types.JsonType | None = None,
use_zarr3: bool = False,
) -> future.Future:
def _save_fn():
Expand All @@ -740,6 +748,7 @@ def _save_fn():
param_infos,
save_args=save_args,
use_zarr3=use_zarr3,
user_metadata=user_metadata,
pytree_metadata_options=self._pytree_metadata_options,
)
logging.vlog(
Expand Down Expand Up @@ -816,12 +825,14 @@ def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata:
tree containing metadata.
"""
is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory)
internal_tree_metadata = self._read_metadata_file(directory)
return tree_metadata.build_default_tree_metadata(
self._read_metadata_file(directory).as_user_metadata(
internal_tree_metadata.as_user_metadata(
directory,
self._type_handler_registry,
use_ocdbt=is_ocdbt_checkpoint,
),
user_metadata=internal_tree_metadata.user_metadata,
)

def finalize(self, directory: epath.Path) -> None:
Expand Down Expand Up @@ -873,12 +884,19 @@ class BasePyTreeSaveArgs(CheckpointArgs):
enable_pinned_host_transfer: True by default. If False, disables transfer to
pinned host when copying from device to host, regardless of the presence
of pinned host memory.
user_metadata: A JSON-serializable object (typically just a nested
dictionary containing string keys and basic type values) that stores user-
specified metadata. This metadata is stored along with the Orbax-internal
PyTree metadata. This can be used to supplement information about the
PyTree checkpoint with information about e.g. the model used to generate
the checkpoint.
"""

item: PyTree
save_args: Optional[PyTree] = None
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True
user_metadata: tree_types.JsonType | None = None


@register_with_handler(BasePyTreeCheckpointHandler, for_restore=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def test_metadata_no_save(self, use_handler_registry):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.user_metadata)

def test_metadata_handler_registry(self):
registry = handler_registration.DefaultCheckpointHandlerRegistry()
Expand Down Expand Up @@ -781,7 +781,7 @@ def test_metadata_handler_registry(self):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.user_metadata)

def test_metadata_after_step_metadata_write(self):
handler = CompositeCheckpointHandler(
Expand All @@ -798,7 +798,7 @@ def test_metadata_after_step_metadata_write(self):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.user_metadata)

metadata_to_write = checkpoint.StepMetadata(
format='orbax',
Expand All @@ -817,7 +817,7 @@ def test_metadata_after_step_metadata_write(self):
),
init_timestamp_nsecs=1000,
commit_timestamp_nsecs=2000,
custom={
user_metadata={
'custom_key': 'custom_value',
},
)
Expand All @@ -842,7 +842,9 @@ def test_metadata_after_step_metadata_write(self):
)
self.assertEqual(step_metadata.init_timestamp_nsecs, 1000)
self.assertEqual(step_metadata.commit_timestamp_nsecs, 2000)
self.assertEqual(step_metadata.custom, {'custom_key': 'custom_value'})
self.assertEqual(
step_metadata.user_metadata, {'custom_key': 'custom_value'}
)

def test_metadata_existing_items_updates_step_metadata(self):
handler = CompositeCheckpointHandler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from orbax.checkpoint._src.serialization import serialization
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -431,6 +432,7 @@ def _get_impl_save_args(
save_args=args.save_args,
ocdbt_target_data_file_size=args.ocdbt_target_data_file_size,
enable_pinned_host_transfer=args.enable_pinned_host_transfer,
user_metadata=args.user_metadata,
)


Expand Down Expand Up @@ -1052,12 +1054,19 @@ class PyTreeSaveArgs(CheckpointArgs):
enable_pinned_host_transfer: True by default. If False, disables transfer to
pinned host when copying from device to host, regardless of the presence
of pinned host memory.
user_metadata: A JSON-serializable object (typically just a nested
dictionary containing string keys and basic type values) that stores user-
specified metadata. This metadata is stored along with the Orbax-internal
PyTree metadata. This can be used to supplement information about the
PyTree checkpoint with information about e.g. the model used to generate
the checkpoint.
"""

item: PyTree
save_args: Optional[PyTree] = None
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True
user_metadata: tree_types.JsonType | None = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils


Expand Down Expand Up @@ -144,15 +145,19 @@ async def async_save(
'Make sure to specify kwarg name `args=` when providing'
' `StandardSaveArgs`.'
)
user_metadata = None
if args is not None:
item = args.item
save_args = args.save_args
user_metadata = args.user_metadata

self._validate_save_state(item, save_args=save_args)
return await self._impl.async_save(
directory,
args=pytree_checkpoint_handler.PyTreeSaveArgs(
item=item, save_args=save_args
item=item,
save_args=save_args,
user_metadata=user_metadata,
),
)

Expand Down Expand Up @@ -266,10 +271,17 @@ class StandardSaveArgs(CheckpointArgs):
save_args: a PyTree with the same structure of `item`, which consists of
`ocp.SaveArgs` objects as values. `None` can be used for values where no
`SaveArgs` are specified.
user_metadata: A JSON-serializable object (typically just a nested
dictionary containing string keys and basic type values) that stores user-
specified metadata. This metadata is stored along with the Orbax-internal
PyTree metadata. This can be used to supplement information about the
PyTree checkpoint with information about e.g. the model used to generate
the checkpoint.
"""

item: PyTree
save_args: Optional[PyTree] = None
user_metadata: tree_types.JsonType | None = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Test for standard_checkpoint_handler.py."""

# pylint: disable=protected-access, missing-function-docstring

import functools
from typing import Any

Expand Down Expand Up @@ -127,7 +129,6 @@ def test_basic_no_item_arg(self):
test_utils.assert_tree_equal(self, self.pytree, restored)

def test_shape_dtype_struct(self):
"""Test case."""
self.handler.save(
self.directory, args=self.save_args_cls(self.mixed_pytree)
)
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_custom_layout(self):
custom_layout = Layout(
device_local_layout=DLL(
major_to_minor=arr.layout.device_local_layout.major_to_minor[::-1], # pytype: disable=attribute-error
_tiling=arr.layout.device_local_layout._tiling, # pylint: disable=protected-access # pytype: disable=attribute-error
_tiling=arr.layout.device_local_layout._tiling, # pytype: disable=attribute-error
),
sharding=arr.sharding,
)
Expand Down Expand Up @@ -210,7 +211,6 @@ def test_custom_layout(self):

@parameterized.parameters((True,), (False,))
def test_change_shape(self, strict: bool):
"""Test case."""
if not hasattr(self.restore_args_cls, 'strict'):
self.skipTest('strict option not supported for this handler')
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',))
Expand Down Expand Up @@ -255,7 +255,6 @@ def test_restore_unsupported_type(self):
self.handler.restore(self.directory, args=self.restore_args_cls(pytree))

def test_cast(self):
"""Test case."""
# TODO(dicentra): casting from int dtypes currently doesn't work
# in the model surgery context.
save_args = jax.tree.map(
Expand Down Expand Up @@ -289,7 +288,6 @@ def check_dtype(x, dtype):
jax.tree.map(lambda x: check_dtype(x, jnp.bfloat16), restored)

def test_flax_model(self):
"""Test case."""

@flax.struct.dataclass
class Params(flax.struct.PyTreeNode):
Expand Down Expand Up @@ -318,12 +316,10 @@ def make_params():
test_utils.assert_tree_equal(self, params, restored)

def test_empty_error(self):
"""Test case."""
with self.assertRaises(ValueError):
self.handler.save(self.directory, args=self.save_args_cls({}))

def test_empty_dict_node(self):
"""Test case."""
item = {'a': {}, 'b': 3}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand All @@ -332,7 +328,6 @@ def test_empty_dict_node(self):
self.assertDictEqual(restored, item)

def test_empty_none_node(self):
"""Test case."""
item = {'c': None, 'd': 2}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand All @@ -341,7 +336,6 @@ def test_empty_none_node(self):
self.assertDictEqual(restored, item)

def test_none_node_in_restore_args(self):
"""Test case."""
devices = np.asarray(jax.devices())
mesh = jax.sharding.Mesh(devices, ('x',))
mesh_axes = jax.sharding.PartitionSpec(
Expand All @@ -358,7 +352,6 @@ def test_none_node_in_restore_args(self):
test_utils.assert_tree_equal(self, restored, {'b': None})

def test_masked_shape_dtype_struct(self):
"""Test case."""

def _should_mask(keypath):
return keypath[0].key == 'a' or (
Expand Down Expand Up @@ -398,3 +391,12 @@ def _none(keypath, x):
# Restore it without any item.
restored = self.handler.restore(self.directory)
test_utils.assert_tree_equal(self, expected, restored)

def test_user_metadata(self):
user_metadata = {'foo': 1}
self.handler.save(
self.directory,
args=self.save_args_cls(self.pytree, user_metadata=user_metadata),
)
metadata = self.handler.metadata(self.directory)
self.assertEqual(metadata.user_metadata, user_metadata)
8 changes: 4 additions & 4 deletions checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class StepMetadata:
Specified as nanoseconds since epoch. default=None.
commit_timestamp_nsecs: commit timestamp of a checkpoint, specified as
nanoseconds since epoch. default=None.
custom: User-provided custom metadata.
user_metadata: User-provided custom metadata.
"""

format: str | None = None
Expand All @@ -94,7 +94,7 @@ class StepMetadata:
)
init_timestamp_nsecs: int | None = None
commit_timestamp_nsecs: int | None = None
custom: dict[str, Any] = dataclasses.field(default_factory=dict)
user_metadata: dict[str, Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
Expand All @@ -104,11 +104,11 @@ class RootMetadata:
Attributes:
format: The checkpoint file format. Users should specify the format
explicitly when using something non-standard.
custom: User-provided custom metadata.
user_metadata: User-provided custom metadata.
"""

format: str | None = None
custom: dict[str, Any] | None = dataclasses.field(default_factory=dict)
user_metadata: dict[str, Any] | None = dataclasses.field(default_factory=dict)


class MetadataStore(Protocol):
Expand Down
Loading

0 comments on commit 9e58750

Please sign in to comment.