Skip to content

Commit

Permalink
feat: Validate metadata names for filtering and sorting (argilla-io#3993
Browse files Browse the repository at this point in the history
)

<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

Add missing metadata filter and sort naming validation in the Python
SDK.

Refs argilla-io#3748

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
frascuchon and pre-commit-ci[bot] authored Oct 19, 2023
1 parent ee2a8c6 commit 31a4d26
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ These are the section headers that we use:
- Implemented `__len__` method for filtered datasets to return the number of records matching the provided filters. ([#3916](https://github.com/argilla-io/argilla/pull/3916))
- Increase the default max result window for Elasticsearch created for Feedback datasets. ([#3929](https://github.com/argilla-io/argilla/pull/))
- Force elastic index refresh after records creation. ([#3929](https://github.com/argilla-io/argilla/pull/))
- Validate metadata fields for filtering and sorting in the Python SDK. ([#3993](https://github.com/argilla-io/argilla/pull/3993))

### Fixed

Expand Down
34 changes: 34 additions & 0 deletions src/argilla/client/feedback/dataset/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

if typing.TYPE_CHECKING:
from argilla.client.feedback.dataset.base import FeedbackDatasetBase


def validate_metadata_names(dataset: "FeedbackDatasetBase", names: typing.List[str]) -> None:
"""Validates that the metadata names used in the filters are valid."""

metadata_property_names = {metadata_property.name: True for metadata_property in dataset.metadata_properties}

if not metadata_property_names:
return

for name in set(names):
if not metadata_property_names.get(name):
raise ValueError(
f"The metadata property name `{name}` does not exist in the current `FeedbackDataset` in Argilla."
f" The existing metadata properties names are: {list(metadata_property_names.keys())}."
)
10 changes: 9 additions & 1 deletion src/argilla/client/feedback/dataset/remote/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

import httpx

from argilla.client.feedback.dataset import FeedbackDataset
from argilla.client.feedback.dataset import FeedbackDataset, helpers
from argilla.client.feedback.dataset.local import FeedbackDataset
from argilla.client.feedback.schemas.enums import ResponseStatusFilter
from argilla.client.feedback.schemas.metadata import MetadataFilters
Expand Down Expand Up @@ -441,6 +441,10 @@ def __getitem__(self, key: Union[slice, int]) -> Union[RemoteFeedbackRecord, Lis

def sort_by(self, sort: List[SortBy]) -> "RemoteFeedbackDataset":
"""Sorts the current `RemoteFeedbackDataset` based on the given sort fields and orders."""
helpers.validate_metadata_names(
dataset=self, names=[sort_.metadata_name for sort_ in sort if sort_.is_metadata_field]
)

sorted_dataset = self._create_from_dataset(self)
sorted_dataset._records = RemoteFeedbackRecords._create_from_dataset(sorted_dataset, sort_by=sort)

Expand Down Expand Up @@ -693,6 +697,10 @@ def filter_by(
if not response_status and not metadata_filters:
raise ValueError("At least one of `response_status` or `metadata_filters` must be provided.")

helpers.validate_metadata_names(
dataset=self, names=[metadata_filter.name for metadata_filter in metadata_filters]
)

filtered_dataset = RemoteFeedbackDataset._create_from_dataset(self)
filtered_dataset._records = RemoteFeedbackRecords._create_from_dataset(
filtered_dataset, response_status=response_status, metadata_filters=metadata_filters
Expand Down
6 changes: 6 additions & 0 deletions src/argilla/client/feedback/schemas/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,9 @@ def check_order(cls, order):
def is_metadata_field(self) -> bool:
"""Returns whether the field is a metadata field."""
return self.field.startswith("metadata.")

@property
def metadata_name(self) -> Optional[str]:
"""Returns the name of the metadata field."""
if self.field.startswith("metadata."):
return self.field.split("metadata.")[1]
1 change: 1 addition & 0 deletions src/argilla/server/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ class MetadataQueryParams(BaseModel):

@property
def metadata_parsed(self) -> List[MetadataParsedQueryParam]:
# TODO: Validate metadata fields names from query params
return [MetadataParsedQueryParam(q) for q in self.metadata]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,26 @@ def test_sort_by_overrides_previous_values(self, owner: "User", test_dataset: Fe

assert records == other_records

def test_sort_by_with_wrong_field(self, owner: "User", test_dataset: FeedbackDataset):
remote = self._create_test_dataset_with_records(owner, test_dataset)

with pytest.raises(
ValueError,
match="The metadata property name `unexpected-field` does not exist in the current `FeedbackDataset` "
"in Argilla. ",
):
remote.sort_by([SortBy(field="metadata.unexpected-field", order="desc")])

def test_filter_by_wrong_field(self, owner: "User", test_dataset: FeedbackDataset):
remote = self._create_test_dataset_with_records(owner, test_dataset)

with pytest.raises(
ValueError,
match="The metadata property name `unexpected-field` does not exist in the current `FeedbackDataset` "
"in Argilla. ",
):
remote.filter_by(metadata_filters=IntegerMetadataFilter(name="unexpected-field", ge=4, le=5))

def _create_test_dataset_with_records(self, owner, test_dataset):
api.init(api_key=owner.api_key)
ws = Workspace.create(name="test-workspace")
Expand Down

0 comments on commit 31a4d26

Please sign in to comment.