Skip to content

Commit

Permalink
Support updating table metadata (apache#139)
Browse files Browse the repository at this point in the history
* Implement table metadata updater first draft

* fix updater error and add tests

* implement apply_metadata_update which is simpler

* remove old implementation

* re-organize method place

* fix nit

* fix test

* add another test

* clear TODO

* add a combined test

* Fix merge conflict

* remove table requirement validation for PR simplification

* make context private and solve elif issue

* remove private field access

* push snapshot ref validation to its builder using pydantic

* fix comment

* remove unnecessary code for AddSchemaUpdate update

* replace if with elif

* enhance the set current schema update implementation and some other changes

* make apply_table_update private

* fix an error

* remove unnecessary last_added_schema_id
  • Loading branch information
HonahX authored Dec 4, 2023
1 parent 2ca2bb0 commit 8330610
Show file tree
Hide file tree
Showing 7 changed files with 535 additions and 101 deletions.
203 changes: 193 additions & 10 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
# under the License.
from __future__ import annotations

import datetime
import itertools
import uuid
from abc import ABC, abstractmethod
from copy import copy
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from functools import cached_property, singledispatch
from itertools import chain
from typing import (
TYPE_CHECKING,
Expand All @@ -41,6 +42,7 @@

from pydantic import Field, SerializeAsAny
from sortedcontainers import SortedList
from typing_extensions import Annotated

from pyiceberg.exceptions import ResolveError, ValidationError
from pyiceberg.expressions import (
Expand Down Expand Up @@ -69,8 +71,13 @@
promote,
visit,
)
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadata
from pyiceberg.table.refs import SnapshotRef
from pyiceberg.table.metadata import (
INITIAL_SEQUENCE_NUMBER,
SUPPORTED_TABLE_FORMAT_VERSION,
TableMetadata,
TableMetadataUtil,
)
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry
from pyiceberg.table.sorting import SortOrder
from pyiceberg.typedef import (
Expand All @@ -90,6 +97,7 @@
StructType,
)
from pyiceberg.utils.concurrent import ExecutorFactory
from pyiceberg.utils.datetime import datetime_to_millis

if TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -320,9 +328,9 @@ class SetSnapshotRefUpdate(TableUpdate):
ref_name: str = Field(alias="ref-name")
type: Literal["tag", "branch"]
snapshot_id: int = Field(alias="snapshot-id")
max_age_ref_ms: int = Field(alias="max-ref-age-ms")
max_snapshot_age_ms: int = Field(alias="max-snapshot-age-ms")
min_snapshots_to_keep: int = Field(alias="min-snapshots-to-keep")
max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)]
max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None)]
min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)]


class RemoveSnapshotsUpdate(TableUpdate):
Expand Down Expand Up @@ -350,6 +358,184 @@ class RemovePropertiesUpdate(TableUpdate):
removals: List[str]


class _TableMetadataUpdateContext:
_updates: List[TableUpdate]

def __init__(self) -> None:
self._updates = []

def add_update(self, update: TableUpdate) -> None:
self._updates.append(update)

def is_added_snapshot(self, snapshot_id: int) -> bool:
return any(
update.snapshot.snapshot_id == snapshot_id
for update in self._updates
if update.action == TableUpdateAction.add_snapshot
)

def is_added_schema(self, schema_id: int) -> bool:
return any(
update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema
)


@singledispatch
def _apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
"""Apply a table update to the table metadata.
Args:
update: The update to be applied.
base_metadata: The base metadata to be updated.
context: Contains previous updates and other change tracking information in the current transaction.
Returns:
The updated metadata.
"""
raise NotImplementedError(f"Unsupported table update: {update}")


@_apply_table_update.register(UpgradeFormatVersionUpdate)
def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION:
raise ValueError(f"Unsupported table format version: {update.format_version}")
elif update.format_version < base_metadata.format_version:
raise ValueError(f"Cannot downgrade v{base_metadata.format_version} table to v{update.format_version}")
elif update.format_version == base_metadata.format_version:
return base_metadata

updated_metadata_data = copy(base_metadata.model_dump())
updated_metadata_data["format-version"] = update.format_version

context.add_update(update)
return TableMetadataUtil.parse_obj(updated_metadata_data)


@_apply_table_update.register(AddSchemaUpdate)
def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
if update.last_column_id < base_metadata.last_column_id:
raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}")

updated_metadata_data = copy(base_metadata.model_dump())
updated_metadata_data["last-column-id"] = update.last_column_id
updated_metadata_data["schemas"].append(update.schema_.model_dump())

context.add_update(update)
return TableMetadataUtil.parse_obj(updated_metadata_data)


@_apply_table_update.register(SetCurrentSchemaUpdate)
def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
new_schema_id = update.schema_id
if new_schema_id == -1:
# The last added schema should be in base_metadata.schemas at this point
new_schema_id = max(schema.schema_id for schema in base_metadata.schemas)
if not context.is_added_schema(new_schema_id):
raise ValueError("Cannot set current schema to last added schema when no schema has been added")

if new_schema_id == base_metadata.current_schema_id:
return base_metadata

schema = base_metadata.schema_by_id(new_schema_id)
if schema is None:
raise ValueError(f"Schema with id {new_schema_id} does not exist")

updated_metadata_data = copy(base_metadata.model_dump())
updated_metadata_data["current-schema-id"] = new_schema_id

context.add_update(update)
return TableMetadataUtil.parse_obj(updated_metadata_data)


@_apply_table_update.register(AddSnapshotUpdate)
def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
if len(base_metadata.schemas) == 0:
raise ValueError("Attempting to add a snapshot before a schema is added")
elif len(base_metadata.partition_specs) == 0:
raise ValueError("Attempting to add a snapshot before a partition spec is added")
elif len(base_metadata.sort_orders) == 0:
raise ValueError("Attempting to add a snapshot before a sort order is added")
elif base_metadata.snapshot_by_id(update.snapshot.snapshot_id) is not None:
raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists")
elif (
base_metadata.format_version == 2
and update.snapshot.sequence_number is not None
and update.snapshot.sequence_number <= base_metadata.last_sequence_number
and update.snapshot.parent_snapshot_id is not None
):
raise ValueError(
f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} "
f"older than last sequence number {base_metadata.last_sequence_number}"
)

updated_metadata_data = copy(base_metadata.model_dump())
updated_metadata_data["last-updated-ms"] = update.snapshot.timestamp_ms
updated_metadata_data["last-sequence-number"] = update.snapshot.sequence_number
updated_metadata_data["snapshots"].append(update.snapshot.model_dump())
context.add_update(update)
return TableMetadataUtil.parse_obj(updated_metadata_data)


@_apply_table_update.register(SetSnapshotRefUpdate)
def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
snapshot_ref = SnapshotRef(
snapshot_id=update.snapshot_id,
snapshot_ref_type=update.type,
min_snapshots_to_keep=update.min_snapshots_to_keep,
max_snapshot_age_ms=update.max_snapshot_age_ms,
max_ref_age_ms=update.max_ref_age_ms,
)

existing_ref = base_metadata.refs.get(update.ref_name)
if existing_ref is not None and existing_ref == snapshot_ref:
return base_metadata

snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id)
if snapshot is None:
raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}")

update_metadata_data = copy(base_metadata.model_dump())
update_last_updated_ms = True
if context.is_added_snapshot(snapshot_ref.snapshot_id):
update_metadata_data["last-updated-ms"] = snapshot.timestamp_ms
update_last_updated_ms = False

if update.ref_name == MAIN_BRANCH:
update_metadata_data["current-snapshot-id"] = snapshot_ref.snapshot_id
if update_last_updated_ms:
update_metadata_data["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone())
update_metadata_data["snapshot-log"].append(
SnapshotLogEntry(
snapshot_id=snapshot_ref.snapshot_id,
timestamp_ms=update_metadata_data["last-updated-ms"],
).model_dump()
)

update_metadata_data["refs"][update.ref_name] = snapshot_ref.model_dump()
context.add_update(update)
return TableMetadataUtil.parse_obj(update_metadata_data)


def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata:
"""Update the table metadata with the given updates in one transaction.
Args:
base_metadata: The base metadata to be updated.
updates: The updates in one transaction.
Returns:
The metadata with the updates applied.
"""
context = _TableMetadataUpdateContext()
new_metadata = base_metadata

for update in updates:
new_metadata = _apply_table_update(update, new_metadata, context)

return new_metadata


class TableRequirement(IcebergBaseModel):
type: str

Expand Down Expand Up @@ -552,10 +738,7 @@ def current_snapshot(self) -> Optional[Snapshot]:

def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]:
"""Get the snapshot of this table with the given id, or None if there is no matching snapshot."""
try:
return next(snapshot for snapshot in self.metadata.snapshots if snapshot.snapshot_id == snapshot_id)
except StopIteration:
return None
return self.metadata.snapshot_by_id(snapshot_id)

def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
"""Return the snapshot referenced by the given name or null if no such reference exists."""
Expand Down
10 changes: 10 additions & 0 deletions pyiceberg/table/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
INITIAL_SPEC_ID = 0
DEFAULT_SCHEMA_ID = 0

SUPPORTED_TABLE_FORMAT_VERSION = 2


def cleanup_snapshot_id(data: Dict[str, Any]) -> Dict[str, Any]:
"""Run before validation."""
Expand Down Expand Up @@ -216,6 +218,14 @@ class TableMetadataCommonFields(IcebergBaseModel):
There is always a main branch reference pointing to the
current-snapshot-id even if the refs map is null."""

def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]:
"""Get the snapshot by snapshot_id."""
return next((snapshot for snapshot in self.snapshots if snapshot.snapshot_id == snapshot_id), None)

def schema_by_id(self, schema_id: int) -> Optional[Schema]:
"""Get the schema by schema_id."""
return next((schema for schema in self.schemas if schema.schema_id == schema_id), None)


class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel):
"""Represents version 1 of the Table Metadata.
Expand Down
22 changes: 18 additions & 4 deletions pyiceberg/table/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from enum import Enum
from typing import Optional

from pydantic import Field
from pydantic import Field, model_validator
from typing_extensions import Annotated

from pyiceberg.exceptions import ValidationError
from pyiceberg.typedef import IcebergBaseModel

MAIN_BRANCH = "main"
Expand All @@ -36,6 +38,18 @@ def __repr__(self) -> str:
class SnapshotRef(IcebergBaseModel):
snapshot_id: int = Field(alias="snapshot-id")
snapshot_ref_type: SnapshotRefType = Field(alias="type")
min_snapshots_to_keep: Optional[int] = Field(alias="min-snapshots-to-keep", default=None)
max_snapshot_age_ms: Optional[int] = Field(alias="max-snapshot-age-ms", default=None)
max_ref_age_ms: Optional[int] = Field(alias="max-ref-age-ms", default=None)
min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None, gt=0)]
max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None, gt=0)]
max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None, gt=0)]

@model_validator(mode='after')
def check_min_snapshots_to_keep(self) -> 'SnapshotRef':
if self.min_snapshots_to_keep is not None and self.snapshot_ref_type == SnapshotRefType.TAG:
raise ValidationError("Tags do not support setting minSnapshotsToKeep")
return self

@model_validator(mode='after')
def check_max_snapshot_age_ms(self) -> 'SnapshotRef':
if self.max_snapshot_age_ms is not None and self.snapshot_ref_type == SnapshotRefType.TAG:
raise ValidationError("Tags do not support setting maxSnapshotAgeMs")
return self
42 changes: 40 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
from pyiceberg.schema import Accessor, Schema
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2
from pyiceberg.typedef import UTF8
from pyiceberg.types import (
BinaryType,
Expand Down Expand Up @@ -354,6 +354,32 @@ def all_avro_types() -> Dict[str, Any]:
}


EXAMPLE_TABLE_METADATA_V1 = {
"format-version": 1,
"table-uuid": "d20125c8-7284-442c-9aea-15fee620737c",
"location": "s3://bucket/test/location",
"last-updated-ms": 1602638573874,
"last-column-id": 3,
"schema": {
"type": "struct",
"fields": [
{"id": 1, "name": "x", "required": True, "type": "long"},
{"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"},
{"id": 3, "name": "z", "required": True, "type": "long"},
],
},
"partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}],
"properties": {},
"current-snapshot-id": -1,
"snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}],
}


@pytest.fixture(scope="session")
def example_table_metadata_v1() -> Dict[str, Any]:
return EXAMPLE_TABLE_METADATA_V1


EXAMPLE_TABLE_METADATA_WITH_SNAPSHOT_V1 = {
"format-version": 1,
"table-uuid": "b55d9dda-6561-423a-8bfc-787980ce421f",
Expand Down Expand Up @@ -1780,7 +1806,19 @@ def example_task(data_file: str) -> FileScanTask:


@pytest.fixture
def table(example_table_metadata_v2: Dict[str, Any]) -> Table:
def table_v1(example_table_metadata_v1: Dict[str, Any]) -> Table:
table_metadata = TableMetadataV1(**example_table_metadata_v1)
return Table(
identifier=("database", "table"),
metadata=table_metadata,
metadata_location=f"{table_metadata.location}/uuid.metadata.json",
io=load_file_io(),
catalog=NoopCatalog("NoopCatalog"),
)


@pytest.fixture
def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table:
table_metadata = TableMetadataV2(**example_table_metadata_v2)
return Table(
identifier=("database", "table"),
Expand Down
Loading

0 comments on commit 8330610

Please sign in to comment.