Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom PyTree metadata. #1461

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ properties not included in any tree mapping operations.

### Added
- The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`.
- User-provided custom PyTree metadata.

### Fixed
- Fix a bug where snapshots are not released by `wait_for_new_checkpoint`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -436,6 +437,7 @@ async def async_save(
raise ValueError('Found empty item.')
save_args = args.save_args
ocdbt_target_data_file_size = args.ocdbt_target_data_file_size
user_metadata = args.user_metadata

save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save')
byte_limiter = serialization.get_byte_limiter(self._save_concurrent_bytes)
Expand Down Expand Up @@ -476,7 +478,11 @@ async def async_save(
if multihost.is_primary_host(self._primary_host):
commit_futures.append(
self._write_metadata_file(
directory, param_infos, save_args, self._use_zarr3
directory,
param_infos=param_infos,
save_args=save_args,
user_metadata=user_metadata,
use_zarr3=self._use_zarr3,
)
)

Expand Down Expand Up @@ -728,8 +734,10 @@ class TrainState:
def _write_metadata_file(
self,
directory: epath.Path,
*,
param_infos: PyTree,
save_args: PyTree,
user_metadata: tree_types.JsonType | None = None,
use_zarr3: bool = False,
) -> future.Future:
def _save_fn():
Expand All @@ -740,6 +748,7 @@ def _save_fn():
param_infos,
save_args=save_args,
use_zarr3=use_zarr3,
user_metadata=user_metadata,
pytree_metadata_options=self._pytree_metadata_options,
)
logging.vlog(
Expand Down Expand Up @@ -816,12 +825,14 @@ def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata:
tree containing metadata.
"""
is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory)
internal_tree_metadata = self._read_metadata_file(directory)
return tree_metadata.build_default_tree_metadata(
self._read_metadata_file(directory).as_user_metadata(
internal_tree_metadata.as_user_metadata(
directory,
self._type_handler_registry,
use_ocdbt=is_ocdbt_checkpoint,
),
user_metadata=internal_tree_metadata.user_metadata,
)

def finalize(self, directory: epath.Path) -> None:
Expand Down Expand Up @@ -873,12 +884,19 @@ class BasePyTreeSaveArgs(CheckpointArgs):
enable_pinned_host_transfer: True by default. If False, disables transfer to
pinned host when copying from device to host, regardless of the presence
of pinned host memory.
user_metadata: A JSON-serializable object (typically just a nested
dictionary containing string keys and basic type values) that stores
user-specified metadata. This metadata is stored along with the
Orbax-internal PyTree metadata. This can be used to supplement
information about the PyTree checkpoint with information about, e.g.,
the model used to generate the checkpoint.
"""

item: PyTree
save_args: Optional[PyTree] = None
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True
user_metadata: tree_types.JsonType | None = None


@register_with_handler(BasePyTreeCheckpointHandler, for_restore=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def test_metadata_no_save(self, use_handler_registry):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.user_metadata)

def test_metadata_handler_registry(self):
registry = handler_registration.DefaultCheckpointHandlerRegistry()
Expand Down Expand Up @@ -781,7 +781,7 @@ def test_metadata_handler_registry(self):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.user_metadata)

def test_metadata_after_step_metadata_write(self):
handler = CompositeCheckpointHandler(
Expand All @@ -798,7 +798,7 @@ def test_metadata_after_step_metadata_write(self):
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
self.assertEmpty(step_metadata.user_metadata)

metadata_to_write = checkpoint.StepMetadata(
format='orbax',
Expand All @@ -817,7 +817,7 @@ def test_metadata_after_step_metadata_write(self):
),
init_timestamp_nsecs=1000,
commit_timestamp_nsecs=2000,
custom={
user_metadata={
'custom_key': 'custom_value',
},
)
Expand All @@ -842,7 +842,9 @@ def test_metadata_after_step_metadata_write(self):
)
self.assertEqual(step_metadata.init_timestamp_nsecs, 1000)
self.assertEqual(step_metadata.commit_timestamp_nsecs, 2000)
self.assertEqual(step_metadata.custom, {'custom_key': 'custom_value'})
self.assertEqual(
step_metadata.user_metadata, {'custom_key': 'custom_value'}
)

def test_metadata_existing_items_updates_step_metadata(self):
handler = CompositeCheckpointHandler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from orbax.checkpoint._src.serialization import serialization
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils
import tensorstore as ts

Expand Down Expand Up @@ -431,6 +432,7 @@ def _get_impl_save_args(
save_args=args.save_args,
ocdbt_target_data_file_size=args.ocdbt_target_data_file_size,
enable_pinned_host_transfer=args.enable_pinned_host_transfer,
user_metadata=args.user_metadata,
)


Expand Down Expand Up @@ -1052,12 +1054,19 @@ class PyTreeSaveArgs(CheckpointArgs):
enable_pinned_host_transfer: True by default. If False, disables transfer to
pinned host when copying from device to host, regardless of the presence
of pinned host memory.
user_metadata: A JSON-serializable object (typically just a nested
dictionary containing string keys and basic type values) that stores
user-specified metadata. This metadata is stored along with the
Orbax-internal PyTree metadata. This can be used to supplement
information about the PyTree checkpoint with information about, e.g.,
the model used to generate the checkpoint.
"""

item: PyTree
save_args: Optional[PyTree] = None
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True
user_metadata: tree_types.JsonType | None = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.tree import types as tree_types
from orbax.checkpoint._src.tree import utils as tree_utils


Expand Down Expand Up @@ -144,15 +145,19 @@ async def async_save(
'Make sure to specify kwarg name `args=` when providing'
' `StandardSaveArgs`.'
)
user_metadata = None
if args is not None:
item = args.item
save_args = args.save_args
user_metadata = args.user_metadata

self._validate_save_state(item, save_args=save_args)
return await self._impl.async_save(
directory,
args=pytree_checkpoint_handler.PyTreeSaveArgs(
item=item, save_args=save_args
item=item,
save_args=save_args,
user_metadata=user_metadata,
),
)

Expand Down Expand Up @@ -266,10 +271,17 @@ class StandardSaveArgs(CheckpointArgs):
save_args: a PyTree with the same structure of `item`, which consists of
`ocp.SaveArgs` objects as values. `None` can be used for values where no
`SaveArgs` are specified.
user_metadata: A JSON-serializable object (typically just a nested
dictionary containing string keys and basic type values) that stores
user-specified metadata. This metadata is stored along with the
Orbax-internal PyTree metadata. This can be used to supplement
information about the PyTree checkpoint with information about, e.g.,
the model used to generate the checkpoint.
"""

item: PyTree
save_args: Optional[PyTree] = None
user_metadata: tree_types.JsonType | None = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Test for standard_checkpoint_handler.py."""

# pylint: disable=protected-access, missing-function-docstring

import functools
from typing import Any

Expand Down Expand Up @@ -127,7 +129,6 @@ def test_basic_no_item_arg(self):
test_utils.assert_tree_equal(self, self.pytree, restored)

def test_shape_dtype_struct(self):
"""Test case."""
self.handler.save(
self.directory, args=self.save_args_cls(self.mixed_pytree)
)
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_custom_layout(self):
custom_layout = Layout(
device_local_layout=DLL(
major_to_minor=arr.layout.device_local_layout.major_to_minor[::-1], # pytype: disable=attribute-error
_tiling=arr.layout.device_local_layout._tiling, # pylint: disable=protected-access # pytype: disable=attribute-error
_tiling=arr.layout.device_local_layout._tiling, # pytype: disable=attribute-error
),
sharding=arr.sharding,
)
Expand Down Expand Up @@ -210,7 +211,6 @@ def test_custom_layout(self):

@parameterized.parameters((True,), (False,))
def test_change_shape(self, strict: bool):
"""Test case."""
if not hasattr(self.restore_args_cls, 'strict'):
self.skipTest('strict option not supported for this handler')
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('x',))
Expand Down Expand Up @@ -255,7 +255,6 @@ def test_restore_unsupported_type(self):
self.handler.restore(self.directory, args=self.restore_args_cls(pytree))

def test_cast(self):
"""Test case."""
# TODO(dicentra): casting from int dtypes currently doesn't work
# in the model surgery context.
save_args = jax.tree.map(
Expand Down Expand Up @@ -289,7 +288,6 @@ def check_dtype(x, dtype):
jax.tree.map(lambda x: check_dtype(x, jnp.bfloat16), restored)

def test_flax_model(self):
"""Test case."""

@flax.struct.dataclass
class Params(flax.struct.PyTreeNode):
Expand Down Expand Up @@ -318,12 +316,10 @@ def make_params():
test_utils.assert_tree_equal(self, params, restored)

def test_empty_error(self):
"""Test case."""
with self.assertRaises(ValueError):
self.handler.save(self.directory, args=self.save_args_cls({}))

def test_empty_dict_node(self):
"""Test case."""
item = {'a': {}, 'b': 3}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand All @@ -332,7 +328,6 @@ def test_empty_dict_node(self):
self.assertDictEqual(restored, item)

def test_empty_none_node(self):
"""Test case."""
item = {'c': None, 'd': 2}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand All @@ -341,7 +336,6 @@ def test_empty_none_node(self):
self.assertDictEqual(restored, item)

def test_none_node_in_restore_args(self):
"""Test case."""
devices = np.asarray(jax.devices())
mesh = jax.sharding.Mesh(devices, ('x',))
mesh_axes = jax.sharding.PartitionSpec(
Expand All @@ -358,7 +352,6 @@ def test_none_node_in_restore_args(self):
test_utils.assert_tree_equal(self, restored, {'b': None})

def test_masked_shape_dtype_struct(self):
"""Test case."""

def _should_mask(keypath):
return keypath[0].key == 'a' or (
Expand Down Expand Up @@ -398,3 +391,12 @@ def _none(keypath, x):
# Restore it without any item.
restored = self.handler.restore(self.directory)
test_utils.assert_tree_equal(self, expected, restored)

def test_user_metadata(self):
user_metadata = {'foo': 1}
self.handler.save(
self.directory,
args=self.save_args_cls(self.pytree, user_metadata=user_metadata),
)
metadata = self.handler.metadata(self.directory)
self.assertEqual(metadata.user_metadata, user_metadata)
8 changes: 4 additions & 4 deletions checkpoint/orbax/checkpoint/_src/metadata/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class StepMetadata:
Specified as nanoseconds since epoch. default=None.
commit_timestamp_nsecs: commit timestamp of a checkpoint, specified as
nanoseconds since epoch. default=None.
custom: User-provided custom metadata.
user_metadata: User-provided custom metadata.
"""

format: str | None = None
Expand All @@ -94,7 +94,7 @@ class StepMetadata:
)
init_timestamp_nsecs: int | None = None
commit_timestamp_nsecs: int | None = None
custom: dict[str, Any] = dataclasses.field(default_factory=dict)
user_metadata: dict[str, Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
Expand All @@ -104,11 +104,11 @@ class RootMetadata:
Attributes:
format: The checkpoint file format. Users should specify the format
explicitly when using something non-standard.
custom: User-provided custom metadata.
user_metadata: User-provided custom metadata.
"""

format: str | None = None
custom: dict[str, Any] | None = dataclasses.field(default_factory=dict)
user_metadata: dict[str, Any] | None = dataclasses.field(default_factory=dict)


class MetadataStore(Protocol):
Expand Down
Loading
Loading