Commit eba004e

Internal change.

PiperOrigin-RevId: 713488575
niketkumar authored and Orbax Authors committed Jan 9, 2025
1 parent 1aa4c35 commit eba004e
Showing 5 changed files with 406 additions and 17 deletions.
5 changes: 5 additions & 0 deletions checkpoint/orbax/checkpoint/_src/asyncio_utils.py
@@ -48,3 +48,8 @@ def run_sync(
  except RuntimeError:
    pass
  return asyncio.run(coro)


async def chain(*awaitables):
  """Executes `awaitables` sequentially and returns the results."""
  return [await a for a in awaitables]
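For context, a minimal usage sketch of the new helper (the coroutines `fetch_a` and `fetch_b` are hypothetical, and `chain` is assumed to be in scope from the module above). Unlike `asyncio.gather`, which schedules its awaitables concurrently, `chain` awaits each one in order:

import asyncio

async def fetch_a():
  return 'a'

async def fetch_b():
  return 'b'

# fetch_a() runs to completion before fetch_b() starts.
results = asyncio.run(chain(fetch_a(), fetch_b()))
assert results == ['a', 'b']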
229 changes: 229 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/ts_array_metadata_store.py
@@ -0,0 +1,229 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Storage for `tensorstore_utils.ArrayMetadata` (not value.ArrayMetadata)."""

import dataclasses
import json
import threading
from typing import Any, Iterator, List, Sequence
from absl import logging
from etils import epath
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils


@dataclasses.dataclass(frozen=True, kw_only=True)
class SerializedArrayMetadata:
  """Serialized version of `tensorstore_utils.ArrayMetadata`.

  Not all fields of `tensorstore_utils.ArrayMetadata` are serialized.

  Used in subchunking-based checkpointing context.
  """

  param_name: str  # Unique full name of the parameter.
  write_shape: ts_utils.Shape
  chunk_shape: ts_utils.Shape


class PathResolver:
  """Resolves paths for the metadata store read and write."""

  _metadata_subdir = 'ts_array_metadatas'

  def _file_name(self, process_index: int) -> str:
    return f'process_{process_index}'

  def get_process_index(self, file_path: epath.Path) -> int:
    """Returns the process index from the file path."""
    process_index = file_path.name.removeprefix('process_')
    if process_index.isdigit():
      return int(process_index)
    raise ValueError(
        f'Invalid ts_array_metadata file path: {file_path}; expected file name'
        ' to start with "process_"'
    )

  def get_write_file_path(
      self, checkpoint_dir: epath.Path, process_index: int
  ) -> epath.Path:
    """Returns the file path to write."""
    return (
        checkpoint_dir / self._metadata_subdir / self._file_name(process_index)
    )

  def get_read_file_paths(
      self, checkpoint_dir: epath.Path, process_index: int | None = None
  ) -> Iterator[epath.Path] | epath.Path | None:
    """Returns the file paths to read.

    Args:
      checkpoint_dir: The base path containing metadata for each process.
      process_index: The process index to read. If None, then read all
        processes under `checkpoint_dir`.

    Returns:
      Iterator of file paths to read if `process_index` is None. A file path to
      read if `process_index` is not None. None if `process_index` is not None
      but metadata file does not exist.
    """
    if process_index is None:
      return checkpoint_dir.glob(f'{self._metadata_subdir}/*')
    file_path = (
        checkpoint_dir / self._metadata_subdir / self._file_name(process_index)
    )
    if file_path.exists():
      return file_path
    return None
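
# For illustration: with the resolver above, the metadata written by process 3
# for a checkpoint at <checkpoint_dir> lands at
#   <checkpoint_dir>/ts_array_metadatas/process_3
# and get_process_index() recovers 3 from that file name.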


class SerDeserializer:
  """Serializes and deserializes `tensorstore_utils.ArrayMetadata`."""

  def _to_dict(self, array_metadata: ts_utils.ArrayMetadata) -> dict[str, Any]:
    """Converts `array_metadata` to a dictionary."""
    return {
        'array_metadata': {
            'param_name': array_metadata.param_name,
            'write_shape': array_metadata.write_shape,
            'chunk_shape': array_metadata.chunk_shape,
        }
    }

  def _from_dict(self, obj: dict[str, Any]) -> Any:
    """Converts a json object to `SerializedArrayMetadata` or `obj`."""
    if 'array_metadata' in obj:
      array_metadata = obj['array_metadata']
      return SerializedArrayMetadata(
          param_name=array_metadata['param_name'],
          write_shape=tuple(array_metadata['write_shape']),
          chunk_shape=tuple(array_metadata['chunk_shape']),
      )
    return obj

  def serialize(self, array_metadatas: Sequence[ts_utils.ArrayMetadata]) -> str:
    """Serializes `array_metadatas` to string."""
    obj = {
        'array_metadatas': [
            self._to_dict(array_metadata) for array_metadata in array_metadatas
        ]
    }
    return json.dumps(obj)

  def deserialize(self, serialized: str) -> List[SerializedArrayMetadata]:
    """Deserializes `serialized` to `tensorstore_utils.ArrayMetadata`."""
    obj = json.loads(serialized, object_hook=self._from_dict)
    return obj['array_metadatas']
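
# For illustration, serializing a single entry yields JSON of the form
#   {"array_metadatas": [{"array_metadata": {"param_name": "a",
#     "write_shape": [1, 2, 3], "chunk_shape": [1, 2, 3]}}]}
# json.dumps() turns the shape tuples into JSON lists; _from_dict() restores
# them to tuples on deserialization.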


class Store:
  """Storage for `tensorstore_utils.ArrayMetadata` (not value.ArrayMetadata)."""

  def __init__(
      self,
      path_resolver: PathResolver = PathResolver(),
      ser_deser: SerDeserializer = SerDeserializer(),
  ):
    self._path_resolver = path_resolver
    self._ser_deser = ser_deser

  async def write(
      self,
      checkpoint_dir: epath.Path,
      array_metadatas: Sequence[ts_utils.ArrayMetadata],
      process_index: int | None = None,
  ) -> None:
    """Writes `array_metadatas` to a file under `checkpoint_dir`.

    See `PathResolver.get_write_file_path()` for the file path resolution.

    Args:
      checkpoint_dir: The base path containing metadata for each process.
      array_metadatas: The sequence of metadata to write.
      process_index: The Jax process index used to resolve the file path. If
        None, then the current process index is used.
    """
    if process_index is None:
      process_index = multihost.process_index()
    file_path = self._path_resolver.get_write_file_path(
        checkpoint_dir, process_index
    )
    file_path.parent.mkdir(parents=True, exist_ok=True)
    file_path.write_text(self._ser_deser.serialize(array_metadatas))
    logging.info(
        '[process=%s][thread=%s] Wrote %d tensorstore_utils.ArrayMetadata'
        ' to %s',
        multihost.process_index(),
        threading.current_thread().name,
        len(array_metadatas),
        file_path,
    )

  def read(
      self,
      checkpoint_dir: epath.Path,
      process_index: int | None = None,
  ) -> (
      dict[int, List[SerializedArrayMetadata]]
      | List[SerializedArrayMetadata]
      | None
  ):
    """Reads `SerializedArrayMetadata` from storage under `checkpoint_dir`.

    Args:
      checkpoint_dir: The base path containing metadata for each process.
      process_index: The process index to read. If None, then read all
        processes under `checkpoint_dir`.

    Returns:
      A dictionary of process index to list of metadata if `process_index`
      is None. A list of metadata if `process_index` is not None. None if
      metadata does not exist.
    """
    if not checkpoint_dir.exists():
      raise ValueError(
          f'Checkpoint directory does not exist: {checkpoint_dir}.'
      )
    file_paths = self._path_resolver.get_read_file_paths(
        checkpoint_dir, process_index
    )
    if file_paths is None:
      logging.warning(
          '[process=%s][thread=%s] No metadata found for process_index=%s,'
          ' checkpoint_dir=%s.',
          multihost.process_index(),
          threading.current_thread().name,
          process_index,
          checkpoint_dir,
      )
      return None
    if isinstance(file_paths, epath.Path):
      return self._ser_deser.deserialize(file_paths.read_text())
    result = {
        self._path_resolver.get_process_index(
            file_path
        ): self._ser_deser.deserialize(file_path.read_text())
        for file_path in file_paths
    }
    if not result:
      logging.warning(
          '[process=%s][thread=%s] No metadata found for any process_index,'
          ' checkpoint_dir=%s.',
          multihost.process_index(),
          threading.current_thread().name,
          checkpoint_dir,
      )
      return None
    return result
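For context, a minimal sketch of how the new `Store` and `asyncio_utils.chain` could be wired together to persist metadata for two processes from a single event loop. This is hypothetical usage, not shown in the commit: `checkpoint_dir`, `metadatas_0`, and `metadatas_1` are placeholder names, and both modules are assumed to be imported.

import asyncio

store = ts_array_metadata_store.Store()
# chain() awaits the writes one after another rather than concurrently.
asyncio.run(
    asyncio_utils.chain(
        store.write(checkpoint_dir, metadatas_0, process_index=0),
        store.write(checkpoint_dir, metadatas_1, process_index=1),
    )
)
# read(process_index=None) then returns {0: [...], 1: [...]} keyed by process.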
@@ -0,0 +1,135 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for `ts_array_metadata_store` module."""

import unittest
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
import numpy as np
from orbax.checkpoint._src.metadata import ts_array_metadata_store
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils


class StoreTest(parameterized.TestCase, unittest.IsolatedAsyncioTestCase):

  def setUp(self):
    super().setUp()
    self.checkpoint_dir = epath.Path(self.create_tempdir().full_path)
    self.store = ts_array_metadata_store.Store()

  def test_non_existing_checkpoint_dir(self):
    with self.assertRaisesRegex(
        ValueError, 'Checkpoint directory does not exist'
    ):
      _ = self.store.read(self.checkpoint_dir / 'unknown_dir')

  def test_non_existing_metadata_files(self):
    self.assertIsNone(self.store.read(self.checkpoint_dir))

    (self.checkpoint_dir / 'ts_array_metadatas').mkdir(
        parents=True, exist_ok=False
    )
    self.assertIsNone(self.store.read(self.checkpoint_dir))

  async def test_write_and_read_single_process(self):
    process_index = 0
    array_metadatas = [
        ts_utils.ArrayMetadata(
            param_name='a',
            shape=(1, 2, 3),
            dtype=np.dtype(int),
            write_shape=(1, 2, 3),
            chunk_shape=(1, 2, 3),
            use_ocdbt=False,
            use_zarr3=False,
        ),
        ts_utils.ArrayMetadata(
            param_name='b',
            shape=(1, 1, 1),
            dtype=np.dtype(int),
            write_shape=(1, 1, 1),
            chunk_shape=(1, 1, 1),
            use_ocdbt=False,
            use_zarr3=False,
        ),
    ]
    await self.store.write(
        self.checkpoint_dir, array_metadatas, process_index=process_index
    )

    self.assertEqual(
        self.store.read(self.checkpoint_dir, process_index=process_index),
        [
            ts_array_metadata_store.SerializedArrayMetadata(
                param_name='a',
                write_shape=(1, 2, 3),
                chunk_shape=(1, 2, 3),
            ),
            ts_array_metadata_store.SerializedArrayMetadata(
                param_name='b',
                write_shape=(1, 1, 1),
                chunk_shape=(1, 1, 1),
            ),
        ],
    )

  async def test_write_and_read_multiple_process(self):
    for process_index in [0, 1, 2]:
      array_metadatas = [
          ts_utils.ArrayMetadata(
              param_name=f'a_{process_index}',
              shape=(1, 2, 3),
              dtype=np.dtype(int),
              write_shape=(1, 2, 3),
              chunk_shape=(1, 2, 3),
              use_ocdbt=False,
              use_zarr3=False,
          ),
      ]
      await self.store.write(
          self.checkpoint_dir, array_metadatas, process_index=process_index
      )

    self.assertEqual(
        self.store.read(self.checkpoint_dir, process_index=None),
        {
            0: [
                ts_array_metadata_store.SerializedArrayMetadata(
                    param_name='a_0',
                    write_shape=(1, 2, 3),
                    chunk_shape=(1, 2, 3),
                )
            ],
            1: [
                ts_array_metadata_store.SerializedArrayMetadata(
                    param_name='a_1',
                    write_shape=(1, 2, 3),
                    chunk_shape=(1, 2, 3),
                )
            ],
            2: [
                ts_array_metadata_store.SerializedArrayMetadata(
                    param_name='a_2',
                    write_shape=(1, 2, 3),
                    chunk_shape=(1, 2, 3),
                )
            ],
        },
    )


if __name__ == '__main__':
  absltest.main()