From cc4cfdfae24c9c9da86360ff0ab3ea6d395a4d93 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 16 Oct 2023 18:10:25 +0200 Subject: [PATCH] feat: Some feedback dataset improvements (#3937) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description The main improvements that this PR brings are: 1. **Feedback dataset class method unification**: All expected methods are defined in the base class, so they must be available for each dataset implementation. For those cases where the method has less sense or is not implemented yet, the user will be notified with a warning (@davidberenstein1957 review and improve, please) 2. **More general workflow with response unification**: The unification workflow support dataset connected to Argilla. This means that the `prepare_for_training` can be applied with remote datasets. The `unify_responses` returns a dataset where responses are unified. As a common practice, returning data is preferable to modifying values internally. We can avoid weird side effects. So, the unification workflow should be as: ```python from argilla import MultiLabelQuestionStrategy, FeedbackDataset ​ dataset = FeedbackDataset.from_argilla(name="my-dataset") strategy = MultiLabelQuestionStrategy("majority") # "disagreement", "majority_weighted (WIP)" unified_dataset = dataset.unify_responses( question=dataset.question_by_name("tags"), strategy=strategy, ) unified_dataset... ``` **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [X] Refactor (change restructuring the codebase without changing functionality) - [X] Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) Some of those flows have been tests locally **Checklist** - [ ] I added relevant documentation - [x] I followed the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [x] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: davidberenstein1957 --- CHANGELOG.md | 2 + src/argilla/cli/datasets/__main__.py | 2 +- src/argilla/cli/datasets/list.py | 2 +- src/argilla/cli/datasets/push.py | 2 +- .../client/feedback/dataset/__init__.py | 2 +- src/argilla/client/feedback/dataset/base.py | 164 ++------ src/argilla/client/feedback/dataset/local.py | 158 -------- .../client/feedback/dataset/local/dataset.py | 370 ++++++++++++++++++ .../feedback/dataset/{ => local}/mixins.py | 62 +-- .../client/feedback/dataset/remote/base.py | 29 +- .../client/feedback/dataset/remote/dataset.py | 84 ++++ .../feedback/dataset/remote/filtered.py | 36 +- .../client/feedback/schemas/remote/shared.py | 2 +- .../dataset/{ => local}/test_dataset.py | 47 ++- .../feedback/dataset/remote/test_dataset.py | 25 ++ .../feedback/dataset/remote/test_filtered.py | 20 +- tests/unit/cli/datasets/test_delete.py | 4 +- tests/unit/cli/datasets/test_list.py | 8 +- tests/unit/cli/datasets/test_push.py | 4 +- .../unit/client/feedback/dataset/test_base.py | 27 +- 20 files changed, 673 insertions(+), 377 deletions(-) delete mode 100644 src/argilla/client/feedback/dataset/local.py create mode 100644 src/argilla/client/feedback/dataset/local/dataset.py rename src/argilla/client/feedback/dataset/{ => local}/mixins.py (85%) rename tests/integration/client/feedback/dataset/{ => local}/test_dataset.py (94%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1479b11a08..a5ccc02363 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ These are the section headers that we use: - Updated `Dockerfile` to use multi stage build ([#3221](https://github.com/argilla-io/argilla/pull/3221) and [#3793](https://github.com/argilla-io/argilla/pull/3793)). - Updated active learning for text classification notebooks to use the most recent small-text version ([#3831](https://github.com/argilla-io/argilla/pull/3831)). - Changed argilla dataset name in the active learning for text classification notebooks to be consistent with the default names in the huggingface spaces ([#3831](https://github.com/argilla-io/argilla/pull/3831)). +- FeedbackDataset API methods have been aligned to be accessible through the several implementations ([#3937](https://github.com/argilla-io/argilla/pull/3937)). +- The `unify_responses` support for remote datasets ([#3937](https://github.com/argilla-io/argilla/pull/3937)). ### Fixed diff --git a/src/argilla/cli/datasets/__main__.py b/src/argilla/cli/datasets/__main__.py index e6b4b60086..c6c022609f 100644 --- a/src/argilla/cli/datasets/__main__.py +++ b/src/argilla/cli/datasets/__main__.py @@ -32,7 +32,7 @@ def callback( init_callback() from argilla.cli.rich import echo_in_panel - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset if ctx.invoked_subcommand not in _COMMANDS_REQUIRING_DATASET: return diff --git a/src/argilla/cli/datasets/list.py b/src/argilla/cli/datasets/list.py index 8b9b804138..88bd75c209 100644 --- a/src/argilla/cli/datasets/list.py +++ b/src/argilla/cli/datasets/list.py @@ -31,7 +31,7 @@ def list_datasets( from argilla.cli.rich import echo_in_panel, get_argilla_themed_table from argilla.client.api import list_datasets as list_datasets_api - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.workspaces import Workspace console = Console() diff --git a/src/argilla/cli/datasets/push.py b/src/argilla/cli/datasets/push.py index 93a64fd54e..5b1c386a25 100644 --- a/src/argilla/cli/datasets/push.py +++ b/src/argilla/cli/datasets/push.py @@ -29,7 +29,7 @@ def push_to_huggingface( from rich.spinner import Spinner from argilla.cli.rich import echo_in_panel - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset dataset: "FeedbackDataset" = ctx.obj diff --git a/src/argilla/client/feedback/dataset/__init__.py b/src/argilla/client/feedback/dataset/__init__.py index 564e19262f..0f9646d49b 100644 --- a/src/argilla/client/feedback/dataset/__init__.py +++ b/src/argilla/client/feedback/dataset/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla.client.feedback.dataset.local import FeedbackDataset +from argilla.client.feedback.dataset.local.dataset import FeedbackDataset __all__ = ["FeedbackDataset"] diff --git a/src/argilla/client/feedback/dataset/base.py b/src/argilla/client/feedback/dataset/base.py index 7a5edbc06a..cbe891d9ae 100644 --- a/src/argilla/client/feedback/dataset/base.py +++ b/src/argilla/client/feedback/dataset/base.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -from abc import ABC, abstractproperty +from abc import ABC, ABCMeta, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union from pydantic import ValidationError @@ -21,23 +20,10 @@ from argilla.client.feedback.integrations.huggingface import HuggingFaceDatasetMixin from argilla.client.feedback.schemas import ( FeedbackRecord, - FieldSchema, ) from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes -from argilla.client.feedback.training.schemas import ( - TrainingTaskForChatCompletion, - TrainingTaskForDPO, - TrainingTaskForPPO, - TrainingTaskForQuestionAnswering, - TrainingTaskForRM, - TrainingTaskForSentenceSimilarity, - TrainingTaskForSFT, - TrainingTaskForTextClassification, - TrainingTaskTypes, -) from argilla.client.feedback.utils import generate_pydantic_schema -from argilla.client.models import Framework -from argilla.utils.dependency import require_dependencies, requires_dependencies +from argilla.utils.dependency import requires_dependencies if TYPE_CHECKING: from datasets import Dataset @@ -48,10 +34,7 @@ ) -_LOGGER = logging.getLogger(__name__) - - -class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin): +class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin, metaclass=ABCMeta): """Base class with shared functionality for `FeedbackDataset` and `RemoteFeedbackDataset`.""" def __init__( @@ -135,7 +118,7 @@ def __init__( self._guidelines = guidelines @property - @abstractproperty + @abstractmethod def records(self) -> Any: """Returns the records of the dataset.""" pass @@ -189,6 +172,11 @@ def question_by_name(self, name: str) -> Union[AllowedQuestionTypes, "AllowedRem f" {', '.join(q.name for q in self._questions)}" ) + @abstractmethod + def add_records(self, *args, **kwargs) -> None: + """Adds the given records to the `FeedbackDataset`.""" + pass + def _parse_records( self, records: Union[FeedbackRecord, Dict[str, Any], List[Union[FeedbackRecord, Dict[str, Any]]]] ) -> List[FeedbackRecord]: @@ -275,110 +263,32 @@ def format_as(self, format: Literal["datasets"]) -> "Dataset": return self._huggingface_format(self) raise ValueError(f"Unsupported format '{format}'.") - # TODO(alvarobartt,davidberenstein1957): we should consider having something like - # `export(..., training=True)` to export the dataset records in any format, replacing - # both `format_as` and `prepare_for_training` - def prepare_for_training( - self, - framework: Union[Framework, str], - task: TrainingTaskTypes, - train_size: Optional[float] = 1, - test_size: Optional[float] = None, - seed: Optional[int] = None, - lang: Optional[str] = None, - ) -> Any: - """ - Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + @abstractmethod + def pull(self): + """Pulls the dataset from Argilla and returns a local instance of it.""" + pass - Args: - framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`, - `setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`, `sentence-transformers`. - task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`, - `TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`, `TrainingTaskForSentenceSimilarity`. - train_size: the size of the train set. If `None`, the whole dataset will be used for training. - test_size: the size of the test set. If `None`, the whole dataset will be used for testing. - seed: the seed to use for splitting the dataset into train and test sets. - lang: the spaCy language to use for training. If `None`, the language of the dataset will be used. - """ - if isinstance(framework, str): - framework = Framework(framework) - - # validate train and test sizes - if train_size is None: - train_size = 1 - if test_size is None: - test_size = 1 - train_size - - # check if all numbers are larger than 0 - if not [abs(train_size), abs(test_size)] == [train_size, test_size]: - raise ValueError("`train_size` and `test_size` must be larger than 0.") - # check if train sizes sum up to 1 - if not (train_size + test_size) == 1: - raise ValueError("`train_size` and `test_size` must sum to 1.") - - if test_size == 0: - test_size = None - - if len(self.records) < 1: - raise ValueError( - "No records found in the dataset. Make sure you add records to the" - " dataset via the `FeedbackDataset.add_records` method first." - ) - - if isinstance(task, (TrainingTaskForTextClassification, TrainingTaskForSentenceSimilarity)): - if task.formatting_func is None: - # in sentence-transformer models we can train without labels - if task.label: - self.unify_responses(question=task.label.question, strategy=task.label.strategy) - elif isinstance(task, TrainingTaskForQuestionAnswering): - if task.formatting_func is None: - self.unify_responses(question=task.answer.name, strategy="disagreement") - elif not isinstance( - task, - ( - TrainingTaskForSFT, - TrainingTaskForRM, - TrainingTaskForPPO, - TrainingTaskForDPO, - TrainingTaskForChatCompletion, - ), - ): - raise ValueError(f"Training data {type(task)} is not supported yet") - - data = task._format_data(self) - if framework in [ - Framework.TRANSFORMERS, - Framework.SETFIT, - Framework.SPAN_MARKER, - Framework.PEFT, - ]: - return task._prepare_for_training_with_transformers( - data=data, train_size=train_size, seed=seed, framework=framework - ) - elif framework in [Framework.SPACY, Framework.SPACY_TRANSFORMERS]: - require_dependencies("spacy") - import spacy - - if lang is None: - _LOGGER.warning("spaCy `lang` is not provided. Using `en`(English) as default language.") - lang = spacy.blank("en") - elif lang.isinstance(str): - if len(lang) == 2: - lang = spacy.blank(lang) - else: - lang = spacy.load(lang) - return task._prepare_for_training_with_spacy(data=data, train_size=train_size, seed=seed, lang=lang) - elif framework is Framework.SPARK_NLP: - return task._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) - elif framework is Framework.OPENAI: - return task._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) - elif framework is Framework.TRL: - return task._prepare_for_training_with_trl(data=data, train_size=train_size, seed=seed) - elif framework is Framework.TRLX: - return task._prepare_for_training_with_trlx(data=data, train_size=train_size, seed=seed) - elif framework is Framework.SENTENCE_TRANSFORMERS: - return task._prepare_for_training_with_sentence_transformers(data=data, train_size=train_size, seed=seed) - else: - raise NotImplementedError( - f"Framework {framework} is not supported. Choose from: {[e.value for e in Framework]}" - ) + @abstractmethod + def filter_by(self, *args, **kwargs): + """Filters the current `FeedbackDataset`.""" + pass + + @abstractmethod + def delete(self): + """Deletes the `FeedbackDataset` from Argilla.""" + pass + + @abstractmethod + def prepare_for_training(self, *args, **kwargs) -> Any: + """Prepares the `FeedbackDataset` for training by creating the training.""" + pass + + @abstractmethod + def push_to_argilla(self, *args, **kwargs) -> "FeedbackDatasetBase": + """Pushes the `FeedbackDataset` to Argilla.""" + pass + + @abstractmethod + def unify_responses(self, *args, **kwargs): + """Unifies the responses for a given question.""" + pass diff --git a/src/argilla/client/feedback/dataset/local.py b/src/argilla/client/feedback/dataset/local.py deleted file mode 100644 index ce3480ae72..0000000000 --- a/src/argilla/client/feedback/dataset/local.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union - -from argilla.client.feedback.constants import FETCHING_BATCH_SIZE -from argilla.client.feedback.dataset.base import FeedbackDatasetBase -from argilla.client.feedback.dataset.mixins import ArgillaMixin, UnificationMixin -from argilla.client.feedback.schemas.fields import TextField -from argilla.client.feedback.schemas.types import AllowedQuestionTypes - -if TYPE_CHECKING: - from argilla.client.feedback.schemas.records import FeedbackRecord - from argilla.client.feedback.schemas.types import AllowedFieldTypes - - -class FeedbackDataset(FeedbackDatasetBase, ArgillaMixin, UnificationMixin): - def __init__( - self, - *, - fields: List["AllowedFieldTypes"], - questions: List[AllowedQuestionTypes], - guidelines: Optional[str] = None, - ) -> None: - """Initializes a `FeedbackDataset` instance locally. - - Args: - fields: contains the fields that will define the schema of the records in the dataset. - questions: contains the questions that will be used to annotate the dataset. - guidelines: contains the guidelines for annotating the dataset. Defaults to `None`. - - Raises: - TypeError: if `fields` is not a list of `FieldSchema`. - ValueError: if `fields` does not contain at least one required field. - TypeError: if `questions` is not a list of `TextQuestion`, `RatingQuestion`, - `LabelQuestion`, and/or `MultiLabelQuestion`. - ValueError: if `questions` does not contain at least one required question. - TypeError: if `guidelines` is not None and not a string. - ValueError: if `guidelines` is an empty string. - - Examples: - >>> import argilla as rg - >>> rg.init(api_url="...", api_key="...") - >>> dataset = rg.FeedbackDataset( - ... fields=[ - ... rg.TextField(name="text", required=True), - ... rg.TextField(name="label", required=True), - ... ], - ... questions=[ - ... rg.TextQuestion( - ... name="question-1", - ... description="This is the first question", - ... required=True, - ... ), - ... rg.RatingQuestion( - ... name="question-2", - ... description="This is the second question", - ... required=True, - ... values=[1, 2, 3, 4, 5], - ... ), - ... rg.LabelQuestion( - ... name="question-3", - ... description="This is the third question", - ... required=True, - ... labels=["positive", "negative"], - ... ), - ... rg.MultiLabelQuestion( - ... name="question-4", - ... description="This is the fourth question", - ... required=True, - ... labels=["category-1", "category-2", "category-3"], - ... ), - ... ], - ... guidelines="These are the annotation guidelines.", - ... ) - """ - super().__init__(fields=fields, questions=questions, guidelines=guidelines) - - self._records = [] - - @property - def records(self) -> List["FeedbackRecord"]: - """Returns the records in the dataset.""" - return self._records - - def __repr__(self) -> str: - """Returns a string representation of the dataset.""" - return f"" - - def __len__(self) -> int: - """Returns the number of records in the dataset.""" - return len(self._records) - - def __getitem__(self, key: Union[slice, int]) -> Union["FeedbackRecord", List["FeedbackRecord"]]: - """Returns the record(s) at the given index(es). - - Args: - key: the index(es) of the record(s) to return. Can either be a single index or a slice. - - Returns: - Either the record of the given index, or a list with the records at the given indexes. - """ - if len(self._records) < 1: - raise RuntimeError( - "In order to get items from `FeedbackDataset` you need to add them first" " with `add_records`." - ) - if isinstance(key, int) and len(self._records) < key: - raise IndexError(f"This dataset contains {len(self)} records, so index {key} is out of range.") - return self._records[key] - - def iter(self, batch_size: Optional[int] = FETCHING_BATCH_SIZE) -> Iterator[List["FeedbackRecord"]]: - """Returns an iterator over the records in the dataset. - - Args: - batch_size: the size of the batches to return. Defaults to 100. - """ - for i in range(0, len(self._records), batch_size): - yield self._records[i : i + batch_size] - - def add_records( - self, - records: Union["FeedbackRecord", Dict[str, Any], List[Union["FeedbackRecord", Dict[str, Any]]]], - ) -> None: - """Adds the given records to the dataset, and stores them locally. If you are - planning to push those to Argilla, you will need to call `push_to_argilla` afterwards, - to both create the dataset in Argilla and push the records to it. Then, from a - `FeedbackDataset` pushed to Argilla, you'll just need to call `add_records` and - those will be automatically uploaded to Argilla. - - Args: - records: can be a single `FeedbackRecord`, a list of `FeedbackRecord`, - a single dictionary, or a list of dictionaries. If a dictionary is provided, - it will be converted to a `FeedbackRecord` internally. - - Raises: - ValueError: if the given records are an empty list. - ValueError: if the given records are neither: `FeedbackRecord`, list of `FeedbackRecord`, - list of dictionaries as a record or dictionary as a record. - ValueError: if the given records do not match the expected schema. - """ - records = self._parse_records(records) - self._validate_records(records) - - if len(self._records) > 0: - self._records += records - else: - self._records = records diff --git a/src/argilla/client/feedback/dataset/local/dataset.py b/src/argilla/client/feedback/dataset/local/dataset.py new file mode 100644 index 0000000000..10fa9ad299 --- /dev/null +++ b/src/argilla/client/feedback/dataset/local/dataset.py @@ -0,0 +1,370 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import textwrap +import warnings +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union + +from argilla.client.feedback.constants import FETCHING_BATCH_SIZE +from argilla.client.feedback.dataset.base import FeedbackDatasetBase +from argilla.client.feedback.dataset.local.mixins import ArgillaMixin +from argilla.client.feedback.schemas.questions import ( + LabelQuestion, + MultiLabelQuestion, + RankingQuestion, + RatingQuestion, + TextQuestion, +) +from argilla.client.feedback.schemas.types import AllowedQuestionTypes +from argilla.client.feedback.training.schemas import ( + TrainingTaskForChatCompletion, + TrainingTaskForDPO, + TrainingTaskForPPO, + TrainingTaskForQuestionAnswering, + TrainingTaskForRM, + TrainingTaskForSentenceSimilarity, + TrainingTaskForSFT, + TrainingTaskForTextClassification, + TrainingTaskTypes, +) +from argilla.client.feedback.unification import ( + LabelQuestionStrategy, + MultiLabelQuestionStrategy, + RankingQuestionStrategy, + RatingQuestionStrategy, + TextQuestionStrategy, +) +from argilla.client.models import Framework +from argilla.utils.dependency import require_dependencies + +if TYPE_CHECKING: + from argilla.client.feedback.schemas.records import FeedbackRecord + from argilla.client.feedback.schemas.types import AllowedFieldTypes + + +_LOGGER = logging.getLogger(__name__) + + +class FeedbackDataset(ArgillaMixin, FeedbackDatasetBase): + def __init__( + self, + *, + fields: List["AllowedFieldTypes"], + questions: List[AllowedQuestionTypes], + guidelines: Optional[str] = None, + ) -> None: + """Initializes a `FeedbackDataset` instance locally. + + Args: + fields: contains the fields that will define the schema of the records in the dataset. + questions: contains the questions that will be used to annotate the dataset. + guidelines: contains the guidelines for annotating the dataset. Defaults to `None`. + + Raises: + TypeError: if `fields` is not a list of `FieldSchema`. + ValueError: if `fields` does not contain at least one required field. + TypeError: if `questions` is not a list of `TextQuestion`, `RatingQuestion`, + `LabelQuestion`, and/or `MultiLabelQuestion`. + ValueError: if `questions` does not contain at least one required question. + TypeError: if `guidelines` is not None and not a string. + ValueError: if `guidelines` is an empty string. + + Examples: + >>> import argilla as rg + >>> rg.init(api_url="...", api_key="...") + >>> dataset = rg.FeedbackDataset( + ... fields=[ + ... rg.TextField(name="text", required=True), + ... rg.TextField(name="label", required=True), + ... ], + ... questions=[ + ... rg.TextQuestion( + ... name="question-1", + ... description="This is the first question", + ... required=True, + ... ), + ... rg.RatingQuestion( + ... name="question-2", + ... description="This is the second question", + ... required=True, + ... values=[1, 2, 3, 4, 5], + ... ), + ... rg.LabelQuestion( + ... name="question-3", + ... description="This is the third question", + ... required=True, + ... labels=["positive", "negative"], + ... ), + ... rg.MultiLabelQuestion( + ... name="question-4", + ... description="This is the fourth question", + ... required=True, + ... labels=["category-1", "category-2", "category-3"], + ... ), + ... ], + ... guidelines="These are the annotation guidelines.", + ... ) + """ + super().__init__(fields=fields, questions=questions, guidelines=guidelines) + + self._records = [] + + @property + def records(self) -> List["FeedbackRecord"]: + """Returns the records in the dataset.""" + return self._records + + def __repr__(self) -> str: + """Returns a string representation of the dataset.""" + return ( + "FeedbackDataset(" + + textwrap.indent( + f"\nfields={self.fields}\nquestions={self.questions}\nguidelines={self.guidelines})", " " + ) + + "\n)" + ) + + def __len__(self) -> int: + """Returns the number of records in the dataset.""" + return len(self._records) + + def __getitem__(self, key: Union[slice, int]) -> Union["FeedbackRecord", List["FeedbackRecord"]]: + """Returns the record(s) at the given index(es). + + Args: + key: the index(es) of the record(s) to return. Can either be a single index or a slice. + + Returns: + Either the record of the given index, or a list with the records at the given indexes. + """ + if len(self._records) < 1: + raise RuntimeError( + "In order to get items from `FeedbackDataset` you need to add them first" " with `add_records`." + ) + if isinstance(key, int) and len(self._records) < key: + raise IndexError(f"This dataset contains {len(self)} records, so index {key} is out of range.") + return self._records[key] + + def iter(self, batch_size: Optional[int] = FETCHING_BATCH_SIZE) -> Iterator[List["FeedbackRecord"]]: + """Returns an iterator over the records in the dataset. + + Args: + batch_size: the size of the batches to return. Defaults to 100. + """ + for i in range(0, len(self._records), batch_size): + yield self._records[i : i + batch_size] + + def add_records( + self, records: Union["FeedbackRecord", Dict[str, Any], List[Union["FeedbackRecord", Dict[str, Any]]]] + ) -> None: + """Adds the given records to the dataset, and stores them locally. If you are + planning to push those to Argilla, you will need to call `push_to_argilla` afterwards, + to both create the dataset in Argilla and push the records to it. Then, from a + `FeedbackDataset` pushed to Argilla, you'll just need to call `add_records` and + those will be automatically uploaded to Argilla. + + Args: + records: can be a single `FeedbackRecord`, a list of `FeedbackRecord`, + a single dictionary, or a list of dictionaries. If a dictionary is provided, + it will be converted to a `FeedbackRecord` internally. + + Raises: + ValueError: if the given records are an empty list. + ValueError: if the given records are neither: `FeedbackRecord`, list of `FeedbackRecord`, + list of dictionaries as a record or dictionary as a record. + ValueError: if the given records do not match the expected schema. + """ + records = self._parse_records(records) + self._validate_records(records) + + if len(self._records) > 0: + self._records += records + else: + self._records = records + + def pull(self) -> "FeedbackDataset": + warnings.warn( + "`pull` method is not supported for local datasets and won't take any effect." + "First, you need to push the dataset to Argilla with `FeedbackDataset.push_to_argilla()`." + "After, use `FeedbackDataset.from_argilla(...).pull()`.", + UserWarning, + ) + return self + + def filter_by(self, *args, **kwargs) -> "FeedbackDataset": + warnings.warn( + "`filter_by` method is not supported for local datasets and won't take any effect. " + "First, you need to push the dataset to Argilla with `FeedbackDataset.push_to_argilla()`." + "After, use `FeedbackDataset.from_argilla(...).filter_by()`.", + UserWarning, + ) + return self + + def delete(self): + warnings.warn( + "`delete` method is not supported for local datasets and won't take any effect. " + "First, you need to push the dataset to Argilla with `FeedbackDataset.push_to_argilla`." + "After, use `FeedbackDataset.from_argilla(...).delete()`", + UserWarning, + ) + return self + + def unify_responses( + self: "FeedbackDatasetBase", + question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion], + strategy: Union[ + str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy + ], + ) -> "FeedbackDataset": + """ + The `unify_responses` function takes a question and a strategy as input and applies the strategy + to unify the responses for that question. + + Args: + question The `question` parameter can be either a string representing the name of the + question, or an instance of one of the question classes (`LabelQuestion`, `MultiLabelQuestion`, + `RatingQuestion`, `RankingQuestion`). + strategy The `strategy` parameter is used to specify the strategy to be used for unifying + responses for a given question. It can be either a string or an instance of a strategy class. + """ + if isinstance(question, str): + question = self.question_by_name(question) + + if isinstance(strategy, str): + if isinstance(question, LabelQuestion): + strategy = LabelQuestionStrategy(strategy) + elif isinstance(question, MultiLabelQuestion): + strategy = MultiLabelQuestionStrategy(strategy) + elif isinstance(question, RatingQuestion): + strategy = RatingQuestionStrategy(strategy) + elif isinstance(question, RankingQuestion): + strategy = RankingQuestionStrategy(strategy) + elif isinstance(question, TextQuestion): + strategy = TextQuestionStrategy(strategy) + else: + raise ValueError(f"Question {question} is not supported yet") + + strategy.unify_responses(self.records, question) + return self + + # TODO(alvarobartt,davidberenstein1957): we should consider having something like + # `export(..., training=True)` to export the dataset records in any format, replacing + # both `format_as` and `prepare_for_training` + def prepare_for_training( + self, + framework: Union[Framework, str], + task: TrainingTaskTypes, + train_size: Optional[float] = 1, + test_size: Optional[float] = None, + seed: Optional[int] = None, + lang: Optional[str] = None, + ) -> Any: + """ + Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + + Args: + framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`, + `setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`, `sentence-transformers`. + task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`, + `TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`, `TrainingTaskForSentenceSimilarity`. + train_size: the size of the train set. If `None`, the whole dataset will be used for training. + test_size: the size of the test set. If `None`, the whole dataset will be used for testing. + seed: the seed to use for splitting the dataset into train and test sets. + lang: the spaCy language to use for training. If `None`, the language of the dataset will be used. + """ + if isinstance(framework, str): + framework = Framework(framework) + + # validate train and test sizes + if train_size is None: + train_size = 1 + if test_size is None: + test_size = 1 - train_size + + # check if all numbers are larger than 0 + if not [abs(train_size), abs(test_size)] == [train_size, test_size]: + raise ValueError("`train_size` and `test_size` must be larger than 0.") + # check if train sizes sum up to 1 + if not (train_size + test_size) == 1: + raise ValueError("`train_size` and `test_size` must sum to 1.") + + if test_size == 0: + test_size = None + + if len(self.records) < 1: + raise ValueError( + "No records found in the dataset. Make sure you add records to the" + " dataset via the `FeedbackDataset.add_records()` method first." + ) + + local_dataset = self.pull() + if isinstance(task, (TrainingTaskForTextClassification, TrainingTaskForSentenceSimilarity)): + if task.formatting_func is None: + # in sentence-transformer models we can train without labels + if task.label: + local_dataset = local_dataset.unify_responses( + question=task.label.question, strategy=task.label.strategy + ) + elif isinstance(task, TrainingTaskForQuestionAnswering): + if task.formatting_func is None: + local_dataset = self.unify_responses(question=task.answer.name, strategy="disagreement") + elif not isinstance( + task, + ( + TrainingTaskForSFT, + TrainingTaskForRM, + TrainingTaskForPPO, + TrainingTaskForDPO, + TrainingTaskForChatCompletion, + ), + ): + raise ValueError(f"Training data {type(task)} is not supported yet") + + data = task._format_data(local_dataset) + if framework in [ + Framework.TRANSFORMERS, + Framework.SETFIT, + Framework.SPAN_MARKER, + Framework.PEFT, + ]: + return task._prepare_for_training_with_transformers( + data=data, train_size=train_size, seed=seed, framework=framework + ) + elif framework in [Framework.SPACY, Framework.SPACY_TRANSFORMERS]: + require_dependencies("spacy") + import spacy + + if lang is None: + _LOGGER.warning("spaCy `lang` is not provided. Using `en`(English) as default language.") + lang = spacy.blank("en") + elif lang.isinstance(str): + if len(lang) == 2: + lang = spacy.blank(lang) + else: + lang = spacy.load(lang) + return task._prepare_for_training_with_spacy(data=data, train_size=train_size, seed=seed, lang=lang) + elif framework is Framework.SPARK_NLP: + return task._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) + elif framework is Framework.OPENAI: + return task._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) + elif framework is Framework.TRL: + return task._prepare_for_training_with_trl(data=data, train_size=train_size, seed=seed) + elif framework is Framework.TRLX: + return task._prepare_for_training_with_trlx(data=data, train_size=train_size, seed=seed) + elif framework is Framework.SENTENCE_TRANSFORMERS: + return task._prepare_for_training_with_sentence_transformers(data=data, train_size=train_size, seed=seed) + else: + raise NotImplementedError( + f"Framework {framework} is not supported. Choose from: {[e.value for e in Framework]}" + ) diff --git a/src/argilla/client/feedback/dataset/mixins.py b/src/argilla/client/feedback/dataset/local/mixins.py similarity index 85% rename from src/argilla/client/feedback/dataset/mixins.py rename to src/argilla/client/feedback/dataset/local/mixins.py index 4994795086..6788e57785 100644 --- a/src/argilla/client/feedback/dataset/mixins.py +++ b/src/argilla/client/feedback/dataset/local/mixins.py @@ -15,19 +15,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union from uuid import UUID -from tqdm import trange - from argilla.client.api import ArgillaSingleton from argilla.client.feedback.constants import PUSHING_BATCH_SIZE from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset from argilla.client.feedback.schemas.enums import FieldTypes, QuestionTypes -from argilla.client.feedback.schemas.questions import ( - LabelQuestion, - MultiLabelQuestion, - RankingQuestion, - RatingQuestion, - TextQuestion, -) from argilla.client.feedback.schemas.remote.fields import RemoteTextField from argilla.client.feedback.schemas.remote.questions import ( RemoteLabelQuestion, @@ -36,28 +27,22 @@ RemoteRatingQuestion, RemoteTextQuestion, ) -from argilla.client.feedback.unification import ( - LabelQuestionStrategy, - MultiLabelQuestionStrategy, - RankingQuestionStrategy, - RatingQuestionStrategy, - TextQuestionStrategy, -) from argilla.client.feedback.utils import feedback_dataset_in_argilla from argilla.client.sdk.v1.datasets import api as datasets_api_v1 from argilla.client.workspaces import Workspace +from tqdm import trange if TYPE_CHECKING: import httpx - from argilla.client.client import Argilla as ArgillaClient - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes from argilla.client.sdk.v1.datasets.models import FeedbackDatasetModel class ArgillaMixin: - def __delete_dataset(self: "FeedbackDataset", client: "httpx.Client", id: UUID) -> None: + @staticmethod + def __delete_dataset(client: "httpx.Client", id: Union[str, UUID]) -> None: try: datasets_api_v1.delete_dataset(client=client, id=id) except Exception as e: @@ -342,42 +327,3 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[ ) for dataset in datasets ] - - -class UnificationMixin: - def unify_responses( - self, - question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion], - strategy: Union[ - str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy - ], - ) -> None: - """ - The `unify_responses` function takes a question and a strategy as input and applies the strategy - to unify the responses for that question. - - Args: - question The `question` parameter can be either a string representing the name of the - question, or an instance of one of the question classes (`LabelQuestion`, `MultiLabelQuestion`, - `RatingQuestion`, `RankingQuestion`). - strategy The `strategy` parameter is used to specify the strategy to be used for unifying - responses for a given question. It can be either a string or an instance of a strategy class. - """ - if isinstance(question, str): - question = self.question_by_name(question) - - if isinstance(strategy, str): - if isinstance(question, LabelQuestion): - strategy = LabelQuestionStrategy(strategy) - elif isinstance(question, MultiLabelQuestion): - strategy = MultiLabelQuestionStrategy(strategy) - elif isinstance(question, RatingQuestion): - strategy = RatingQuestionStrategy(strategy) - elif isinstance(question, RankingQuestion): - strategy = RankingQuestionStrategy(strategy) - elif isinstance(question, TextQuestion): - strategy = TextQuestionStrategy(strategy) - else: - raise ValueError(f"Question {question} is not supported yet") - - strategy.unify_responses(self.records, question) diff --git a/src/argilla/client/feedback/dataset/remote/base.py b/src/argilla/client/feedback/dataset/remote/base.py index 4098b4c7cf..9830c41814 100644 --- a/src/argilla/client/feedback/dataset/remote/base.py +++ b/src/argilla/client/feedback/dataset/remote/base.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import textwrap import warnings from abc import ABC, abstractmethod from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, Generic, Iterator, List, Optional, Type, TypeVar, Union +from argilla import Workspace from argilla.client.feedback.dataset.base import FeedbackDatasetBase from argilla.client.feedback.dataset.remote.mixins import ArgillaRecordsMixin from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord @@ -28,7 +30,7 @@ import httpx - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.schemas.records import FeedbackRecord from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel @@ -91,7 +93,7 @@ def delete(self) -> None: pass -class RemoteFeedbackDatasetBase(Generic[T], FeedbackDatasetBase): +class RemoteFeedbackDatasetBase(FeedbackDatasetBase, Generic[T]): records_cls: Type[T] def __init__( @@ -186,10 +188,17 @@ def updated_at(self) -> datetime: def __repr__(self) -> str: """Returns a string representation of the dataset.""" + indent = " " return ( - f"" + "RemoteFeedbackDataset(" + + textwrap.indent(f"\nid={self.id}", indent) + + textwrap.indent(f"\nname={self.name}", indent) + + textwrap.indent(f"\nworkspace={self.workspace}", indent) + + textwrap.indent(f"\nurl={self.url}", indent) + + textwrap.indent(f"\nfields={self.fields}", indent) + + textwrap.indent(f"\nquestions={self.questions}", indent) + + textwrap.indent(f"\nguidelines={self.guidelines}", indent) + + ")" ) def __len__(self) -> int: @@ -252,7 +261,7 @@ def pull(self) -> "FeedbackDataset": A local instance of the dataset which is a `FeedbackDataset` object. """ # Importing here to avoid circular imports - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset instance = FeedbackDataset( fields=self.fields, @@ -263,3 +272,11 @@ def pull(self) -> "FeedbackDataset": records=[record.to_local() for record in self._records], ) return instance + + def push_to_argilla( + self, name: str, workspace: Optional[Union[str, "Workspace"]] = None, show_progress: bool = False + ) -> "RemoteFeedbackDatasetBase": + warnings.warn( + "Already pushed datasets cannot be pushed to Argilla again because they are synced automatically." + ) + return self diff --git a/src/argilla/client/feedback/dataset/remote/dataset.py b/src/argilla/client/feedback/dataset/remote/dataset.py index d6b2eac040..4d1a7d69b3 100644 --- a/src/argilla/client/feedback/dataset/remote/dataset.py +++ b/src/argilla/client/feedback/dataset/remote/dataset.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -20,8 +21,23 @@ from argilla.client.feedback.constants import DELETE_DATASET_RECORDS_MAX_NUMBER, PUSHING_BATCH_SIZE from argilla.client.feedback.dataset.remote.base import RemoteFeedbackDatasetBase, RemoteFeedbackRecordsBase from argilla.client.feedback.dataset.remote.filtered import FilteredRemoteFeedbackDataset +from argilla.client.feedback.schemas.questions import ( + LabelQuestion, + MultiLabelQuestion, + RatingQuestion, +) from argilla.client.feedback.schemas.records import FeedbackRecord from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord +from argilla.client.feedback.training.schemas import ( + TrainingTaskTypes, +) +from argilla.client.feedback.unification import ( + LabelQuestionStrategy, + MultiLabelQuestionStrategy, + RankingQuestionStrategy, + RatingQuestionStrategy, +) +from argilla.client.models import Framework from argilla.client.sdk.users.models import UserRole from argilla.client.sdk.v1.datasets import api as datasets_api_v1 from argilla.client.utils import allowed_for_roles @@ -31,6 +47,7 @@ import httpx + from argilla.client.feedback.dataset.local import FeedbackDataset from argilla.client.feedback.schemas.enums import ResponseStatusFilter from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel @@ -188,3 +205,70 @@ def delete(self) -> None: datasets_api_v1.delete_dataset(client=self._client, id=self.id) except Exception as e: raise RuntimeError(f"Failed while deleting the `FeedbackDataset` from Argilla with exception: {e}") from e + + def unify_responses( + self, + question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion], + strategy: Union[ + str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy + ], + ) -> "FeedbackDataset": + """ + The `unify_responses` function takes a question and a strategy as input and applies the strategy + to unify the responses for that question. + + Args: + question The `question` parameter can be either a string representing the name of the + question, or an instance of one of the question classes (`LabelQuestion`, `MultiLabelQuestion`, + `RatingQuestion`, `RankingQuestion`). + strategy The `strategy` parameter is used to specify the strategy to be used for unifying + responses for a given question. It can be either a string or an instance of a strategy class. + """ + warnings.warn( + "A local `FeedbackDataset` returned because " + "`unify_responses` is not supported for `RemoteFeedbackDataset`. " + "`RemoteFeedbackDataset`.pull().unify_responses(*args, **kwargs)` is applied.", + UserWarning, + ) + local = self.pull() + return local.unify_responses(question=question, strategy=strategy) + + def prepare_for_training( + self, + framework: Union[Framework, str], + task: TrainingTaskTypes, + train_size: Optional[float] = 1, + test_size: Optional[float] = None, + seed: Optional[int] = None, + lang: Optional[str] = None, + ) -> Any: + """ + Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + + Args: + framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`, + `setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`, `sentence-transformers`. + task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`, + `TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`, `TrainingTaskForSentenceSimilarity`. + train_size: the size of the train set. If `None`, the whole dataset will be used for training. + test_size: the size of the test set. If `None`, the whole dataset will be used for testing. + seed: the seed to use for splitting the dataset into train and test sets. + lang: the spaCy language to use for training. If `None`, the language of the dataset will be used. + """ + warnings.warn( + ( + "A local `FeedbackDataset` returned because " + "`prepare_for_training` is not supported for `RemoteFeedbackDataset`. " + "`RemoteFeedbackDataset`.pull().prepare_for_training(*args, **kwargs)` is applied." + ), + UserWarning, + ) + local = self.pull() + return local.prepare_for_training( + framework=framework, + task=task, + train_size=train_size, + test_size=test_size, + seed=seed, + lang=lang, + ) diff --git a/src/argilla/client/feedback/dataset/remote/filtered.py b/src/argilla/client/feedback/dataset/remote/filtered.py index 0a5175918b..109417eada 100644 --- a/src/argilla/client/feedback/dataset/remote/filtered.py +++ b/src/argilla/client/feedback/dataset/remote/filtered.py @@ -24,6 +24,7 @@ import httpx from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset + from argilla.client.feedback.schemas.enums import ResponseStatusFilter from argilla.client.feedback.schemas.records import FeedbackRecord from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes @@ -38,7 +39,7 @@ def __init__(self, dataset: "RemoteFeedbackDataset", filters: Dict[str, Any]) -> self._filters = filters def __len__(self) -> None: - raise NotImplementedError("`__len__` does not work for filtered datasets.") + raise NotImplementedError("`__len__` does not work for `FilteredRemoteFeedbackDataset`.") def _fetch_records(self, offset: int, limit: int) -> "FeedbackRecordsModel": """Fetches a batch of records from Argilla.""" @@ -55,10 +56,10 @@ def add( records: Union["FeedbackRecord", Dict[str, Any], List[Union["FeedbackRecord", Dict[str, Any]]]], show_progress: bool = True, ) -> None: - raise NotImplementedError("`records.add` does not work for filtered datasets.") + raise NotImplementedError("`records.add` does not work for `FilteredRemoteFeedbackDataset`.") def delete(self, records: List["RemoteFeedbackRecord"]) -> None: - raise NotImplementedError("`records.delete` does not work for filtered datasets.") + raise NotImplementedError("`records.delete` does not work for `FilteredRemoteFeedbackDataset`.") class FilteredRemoteFeedbackDataset(RemoteFeedbackDatasetBase[FilteredRemoteFeedbackRecords]): @@ -93,4 +94,31 @@ def __init__( ) def delete(self) -> None: - raise NotImplementedError("`delete` does not work for filtered datasets.") + raise NotImplementedError("`delete` does not work for `FilteredRemoteFeedbackDataset`.") + + def prepare_for_training(self, *args, **kwargs) -> Any: + raise NotImplementedError("`prepare_for_training` does not work for `FilteredRemoteFeedbackDataset`.") + + def unify_responses(self, *args, **kwargs): + raise NotImplementedError("`unify_responses` does not work for `FilteredRemoteFeedbackDataset`.") + + def filter_by( + self, response_status: Union["ResponseStatusFilter", List["ResponseStatusFilter"]] + ) -> "FilteredRemoteFeedbackDataset": + if not isinstance(response_status, list): + response_status = [response_status] + + return self.__class__( + client=self._client, + id=self.id, + name=self.name, + workspace=self.workspace, + created_at=self.created_at, + updated_at=self.updated_at, + fields=self.fields, + questions=self.questions, + guidelines=self.guidelines, + filters={ + "response_status": [status.value if hasattr(status, "value") else status for status in response_status] + }, + ) diff --git a/src/argilla/client/feedback/schemas/remote/shared.py b/src/argilla/client/feedback/schemas/remote/shared.py index ce3ecb952a..cbb0933f91 100644 --- a/src/argilla/client/feedback/schemas/remote/shared.py +++ b/src/argilla/client/feedback/schemas/remote/shared.py @@ -42,7 +42,7 @@ def to_local(self) -> BaseModel: @classmethod @abstractmethod - def from_api(cls) -> Type["RemoteSchema"]: + def from_api(cls, payload: "BaseModel") -> Type["RemoteSchema"]: """Abstract method to be implemented by subclasses to convert the API payload into a remote schema.""" raise NotImplementedError diff --git a/tests/integration/client/feedback/dataset/test_dataset.py b/tests/integration/client/feedback/dataset/local/test_dataset.py similarity index 94% rename from tests/integration/client/feedback/dataset/test_dataset.py rename to tests/integration/client/feedback/dataset/local/test_dataset.py index 980c2a6dbe..6348b3c145 100644 --- a/tests/integration/client/feedback/dataset/test_dataset.py +++ b/tests/integration/client/feedback/dataset/local/test_dataset.py @@ -17,6 +17,7 @@ import datasets import pytest +from argilla import Workspace from argilla.client import api from argilla.client.feedback.config import DatasetConfig from argilla.client.feedback.dataset import FeedbackDataset @@ -654,6 +655,7 @@ def test_push_to_huggingface_and_from_huggingface( "feedback_dataset_records", ) def test_prepare_for_training_text_classification( + owner: "ServerUser", framework: Union[Framework, str], question: str, feedback_dataset_guidelines: str, @@ -667,7 +669,48 @@ def test_prepare_for_training_text_classification( questions=feedback_dataset_questions, ) dataset.add_records(feedback_dataset_records) - label = dataset.question_by_name(question) + + api.init(api_key=owner.api_key) + ws = Workspace.create(name="test-workspace") + + remote = dataset.push_to_argilla(name="test-dataset", workspace=ws) + + label = remote.question_by_name(question) task = TrainingTask.for_text_classification(text=dataset.fields[0], label=label) - dataset.prepare_for_training(framework=framework, task=task) + data = remote.prepare_for_training(framework=framework, task=task) + assert data is not None + + +@pytest.mark.usefixtures( + "feedback_dataset_guidelines", + "feedback_dataset_fields", + "feedback_dataset_questions", + "feedback_dataset_records", +) +def test_warning_remote_dataset_methods( + feedback_dataset_guidelines: str, + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_records: List[FeedbackRecord], +): + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + + with pytest.warns( + UserWarning, match="`pull` method is not supported for local datasets and won't take any effect." + ): + dataset.pull() + + with pytest.warns( + UserWarning, match="`filter_by` method is not supported for local datasets and won't take any effect." + ): + dataset.filter_by() + + with pytest.warns( + UserWarning, match="`delete` method is not supported for local datasets and won't take any effect." + ): + dataset.delete() diff --git a/tests/integration/client/feedback/dataset/remote/test_dataset.py b/tests/integration/client/feedback/dataset/remote/test_dataset.py index 8b53b8207c..181d92ffee 100644 --- a/tests/integration/client/feedback/dataset/remote/test_dataset.py +++ b/tests/integration/client/feedback/dataset/remote/test_dataset.py @@ -147,3 +147,28 @@ async def test_attributes(self, role: UserRole) -> None: assert isinstance(remote_dataset.url, str) assert isinstance(remote_dataset.created_at, datetime) assert isinstance(remote_dataset.updated_at, datetime) + + @pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin]) + async def test_warning_local_methods(self, role: UserRole) -> None: + dataset = await DatasetFactory.create() + await TextFieldFactory.create(dataset=dataset, required=True) + await TextQuestionFactory.create(dataset=dataset, required=True) + await RecordFactory.create_batch(dataset=dataset, size=10) + user = await UserFactory.create(role=role, workspaces=[dataset.workspace]) + + api.init(api_key=user.api_key) + ds = FeedbackDataset.from_argilla(id=dataset.id) + + with pytest.raises(ValueError, match="`FeedbackRecord.fields` does not match the expected schema"): + with pytest.warns( + UserWarning, + match="A local `FeedbackDataset` returned because `unify_responses` is not supported for `RemoteFeedbackDataset`. ", + ): + ds.unify_responses(question=None, strategy=None) + + with pytest.raises(ValueError, match="`FeedbackRecord.fields` does not match the expected schema"): + with pytest.warns( + UserWarning, + match="A local `FeedbackDataset` returned because `prepare_for_training` is not supported for `RemoteFeedbackDataset`. ", + ): + ds.prepare_for_training(framework=None, task=None) diff --git a/tests/integration/client/feedback/dataset/remote/test_filtered.py b/tests/integration/client/feedback/dataset/remote/test_filtered.py index b109a97acb..4e798e8dcf 100644 --- a/tests/integration/client/feedback/dataset/remote/test_filtered.py +++ b/tests/integration/client/feedback/dataset/remote/test_filtered.py @@ -18,7 +18,7 @@ import pytest from argilla.client import api -from argilla.client.feedback.dataset.local import FeedbackDataset +from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.dataset.remote.filtered import FilteredRemoteFeedbackDataset, FilteredRemoteFeedbackRecords from argilla.client.feedback.schemas.enums import ResponseStatusFilter from argilla.client.feedback.schemas.records import FeedbackRecord @@ -90,19 +90,27 @@ async def test_not_implemented_methods( filtered_dataset = remote_dataset.filter_by(response_status=status) assert isinstance(filtered_dataset, FilteredRemoteFeedbackDataset) - with pytest.raises(NotImplementedError, match="`records.delete` does not work for filtered datasets."): + with pytest.raises( + NotImplementedError, match="`records.delete` does not work for `FilteredRemoteFeedbackDataset`." + ): filtered_dataset.delete_records(remote_dataset.records[0]) - with pytest.raises(NotImplementedError, match="`records.delete` does not work for filtered datasets."): + with pytest.raises( + NotImplementedError, match="`records.delete` does not work for `FilteredRemoteFeedbackDataset`." + ): filtered_dataset.records.delete(remote_dataset.records[0]) - with pytest.raises(NotImplementedError, match="`records.add` does not work for filtered datasets."): + with pytest.raises( + NotImplementedError, match="`records.add` does not work for `FilteredRemoteFeedbackDataset`." + ): filtered_dataset.add_records(FeedbackRecord(fields={text_field.name: "test"})) - with pytest.raises(NotImplementedError, match="`records.add` does not work for filtered datasets."): + with pytest.raises( + NotImplementedError, match="`records.add` does not work for `FilteredRemoteFeedbackDataset`." + ): filtered_dataset.records.add(FeedbackRecord(fields={text_field.name: "test"})) - with pytest.raises(NotImplementedError, match="`delete` does not work for filtered datasets."): + with pytest.raises(NotImplementedError, match="`delete` does not work for `FilteredRemoteFeedbackDataset`."): filtered_dataset.delete() @pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin]) diff --git a/tests/unit/cli/datasets/test_delete.py b/tests/unit/cli/datasets/test_delete.py index 0b9b819f87..25b58c367e 100644 --- a/tests/unit/cli/datasets/test_delete.py +++ b/tests/unit/cli/datasets/test_delete.py @@ -33,7 +33,7 @@ def test_delete_dataset( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: dataset_from_argilla_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) remote_feedback_dataset_delete_mock = mocker.patch( @@ -55,7 +55,7 @@ def test_delete_dataset_runtime_error( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: dataset_from_argilla_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) remote_feedback_dataset_delete_mock = mocker.patch( diff --git a/tests/unit/cli/datasets/test_list.py b/tests/unit/cli/datasets/test_list.py index 9c0ba99d9d..c2b33680aa 100644 --- a/tests/unit/cli/datasets/test_list.py +++ b/tests/unit/cli/datasets/test_list.py @@ -38,7 +38,7 @@ def test_list_datasets( ) -> None: add_row_spy = mocker.spy(Table, "add_row") feedback_dataset_list_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.list", return_value=[remote_feedback_dataset] + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list", return_value=[remote_feedback_dataset] ) list_datasets_mock = mocker.patch("argilla.client.api.list_datasets", return_value=[dataset]) @@ -74,7 +74,7 @@ def test_list_datasets( def test_list_datasets_with_workspace(self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture") -> None: workspace_from_name_mock = mocker.patch("argilla.client.workspaces.Workspace.from_name") - feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list") list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") result = cli_runner.invoke(cli, "datasets list --workspace unit-test") @@ -98,7 +98,7 @@ def test_list_datasets_with_non_existing_workspace( def test_list_datasets_using_type_feedback_filter( self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture" ) -> None: - feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list") list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") result = cli_runner.invoke(cli, "datasets list --type feedback") @@ -110,7 +110,7 @@ def test_list_datasets_using_type_feedback_filter( def test_list_datasets_using_type_other_filter( self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture" ) -> None: - feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list") list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") result = cli_runner.invoke(cli, "datasets list --type other") diff --git a/tests/unit/cli/datasets/test_push.py b/tests/unit/cli/datasets/test_push.py index 30179b5924..02fdfbe173 100644 --- a/tests/unit/cli/datasets/test_push.py +++ b/tests/unit/cli/datasets/test_push.py @@ -33,7 +33,7 @@ def test_push_to_huggingface( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: dataset_from_argilla_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) push_to_huggingface_mock = mocker.patch( @@ -58,7 +58,7 @@ def test_push_to_huggingface_missing_repo_id_arg( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) diff --git a/tests/unit/client/feedback/dataset/test_base.py b/tests/unit/client/feedback/dataset/test_base.py index 97ad8abcdb..401bd848c1 100644 --- a/tests/unit/client/feedback/dataset/test_base.py +++ b/tests/unit/client/feedback/dataset/test_base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Any, List import pytest from argilla.client.feedback.dataset.base import FeedbackDatasetBase @@ -27,10 +27,31 @@ class FeedbackDataset(FeedbackDatasetBase): - @property - def records(self) -> None: + def add_records(self, *args, **kwargs) -> None: + pass + + def pull(self): + return self + + def filter_by(self, *args, **kwargs): + return self + + def delete(self): pass + def prepare_for_training(self, *args, **kwargs) -> Any: + return [] + + def push_to_argilla(self, *args, **kwargs) -> "FeedbackDatasetBase": + return self + + def unify_responses(self, *args, **kwargs): + return self + + @property + def records(self) -> List[Any]: + return [] + def test_init( feedback_dataset_guidelines: str,