diff --git a/CHANGELOG.md b/CHANGELOG.md index 1479b11a08..a5ccc02363 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ These are the section headers that we use: - Updated `Dockerfile` to use multi stage build ([#3221](https://github.com/argilla-io/argilla/pull/3221) and [#3793](https://github.com/argilla-io/argilla/pull/3793)). - Updated active learning for text classification notebooks to use the most recent small-text version ([#3831](https://github.com/argilla-io/argilla/pull/3831)). - Changed argilla dataset name in the active learning for text classification notebooks to be consistent with the default names in the huggingface spaces ([#3831](https://github.com/argilla-io/argilla/pull/3831)). +- Aligned the `FeedbackDataset` API methods so that they are shared across its local and remote implementations ([#3937](https://github.com/argilla-io/argilla/pull/3937)). +- Added `unify_responses` support for remote datasets ([#3937](https://github.com/argilla-io/argilla/pull/3937)). ### Fixed diff --git a/src/argilla/cli/datasets/__main__.py b/src/argilla/cli/datasets/__main__.py index e6b4b60086..c6c022609f 100644 --- a/src/argilla/cli/datasets/__main__.py +++ b/src/argilla/cli/datasets/__main__.py @@ -32,7 +32,7 @@ def callback( init_callback() from argilla.cli.rich import echo_in_panel - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset if ctx.invoked_subcommand not in _COMMANDS_REQUIRING_DATASET: return diff --git a/src/argilla/cli/datasets/list.py b/src/argilla/cli/datasets/list.py index 8b9b804138..88bd75c209 100644 --- a/src/argilla/cli/datasets/list.py +++ b/src/argilla/cli/datasets/list.py @@ -31,7 +31,7 @@ def list_datasets( from argilla.cli.rich import echo_in_panel, get_argilla_themed_table from argilla.client.api import list_datasets as list_datasets_api - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.workspaces import Workspace console = Console() diff --git a/src/argilla/cli/datasets/push.py b/src/argilla/cli/datasets/push.py index 93a64fd54e..5b1c386a25 100644 --- a/src/argilla/cli/datasets/push.py +++ b/src/argilla/cli/datasets/push.py @@ -29,7 +29,7 @@ def push_to_huggingface( from rich.spinner import Spinner from argilla.cli.rich import echo_in_panel - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset dataset: "FeedbackDataset" = ctx.obj diff --git a/src/argilla/client/feedback/dataset/__init__.py b/src/argilla/client/feedback/dataset/__init__.py index 564e19262f..0f9646d49b 100644 --- a/src/argilla/client/feedback/dataset/__init__.py +++ b/src/argilla/client/feedback/dataset/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla.client.feedback.dataset.local import FeedbackDataset +from argilla.client.feedback.dataset.local.dataset import FeedbackDataset __all__ = ["FeedbackDataset"] diff --git a/src/argilla/client/feedback/dataset/base.py b/src/argilla/client/feedback/dataset/base.py index 7a5edbc06a..cbe891d9ae 100644 --- a/src/argilla/client/feedback/dataset/base.py +++ b/src/argilla/client/feedback/dataset/base.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
-import logging -from abc import ABC, abstractproperty +from abc import ABC, ABCMeta, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union from pydantic import ValidationError @@ -21,23 +20,10 @@ from argilla.client.feedback.integrations.huggingface import HuggingFaceDatasetMixin from argilla.client.feedback.schemas import ( FeedbackRecord, - FieldSchema, ) from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes -from argilla.client.feedback.training.schemas import ( - TrainingTaskForChatCompletion, - TrainingTaskForDPO, - TrainingTaskForPPO, - TrainingTaskForQuestionAnswering, - TrainingTaskForRM, - TrainingTaskForSentenceSimilarity, - TrainingTaskForSFT, - TrainingTaskForTextClassification, - TrainingTaskTypes, -) from argilla.client.feedback.utils import generate_pydantic_schema -from argilla.client.models import Framework -from argilla.utils.dependency import require_dependencies, requires_dependencies +from argilla.utils.dependency import requires_dependencies if TYPE_CHECKING: from datasets import Dataset @@ -48,10 +34,7 @@ ) -_LOGGER = logging.getLogger(__name__) - - -class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin): +class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin, metaclass=ABCMeta): """Base class with shared functionality for `FeedbackDataset` and `RemoteFeedbackDataset`.""" def __init__( @@ -135,7 +118,7 @@ def __init__( self._guidelines = guidelines @property - @abstractproperty + @abstractmethod def records(self) -> Any: """Returns the records of the dataset.""" pass @@ -189,6 +172,11 @@ def question_by_name(self, name: str) -> Union[AllowedQuestionTypes, "AllowedRem f" {', '.join(q.name for q in self._questions)}" ) + @abstractmethod + def add_records(self, *args, **kwargs) -> None: + """Adds the given records to the `FeedbackDataset`.""" + pass + def _parse_records( self, records: Union[FeedbackRecord, Dict[str, Any], List[Union[FeedbackRecord, Dict[str, Any]]]] ) -> List[FeedbackRecord]: @@ -275,110 +263,32 @@ def format_as(self, format: Literal["datasets"]) -> "Dataset": return self._huggingface_format(self) raise ValueError(f"Unsupported format '{format}'.") - # TODO(alvarobartt,davidberenstein1957): we should consider having something like - # `export(..., training=True)` to export the dataset records in any format, replacing - # both `format_as` and `prepare_for_training` - def prepare_for_training( - self, - framework: Union[Framework, str], - task: TrainingTaskTypes, - train_size: Optional[float] = 1, - test_size: Optional[float] = None, - seed: Optional[int] = None, - lang: Optional[str] = None, - ) -> Any: - """ - Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + @abstractmethod + def pull(self): + """Pulls the dataset from Argilla and returns a local instance of it.""" + pass - Args: - framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`, - `setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`, `sentence-transformers`. - task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`, - `TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`, `TrainingTaskForSentenceSimilarity`. - train_size: the size of the train set. If `None`, the whole dataset will be used for training. - test_size: the size of the test set. 
If `None`, the whole dataset will be used for testing. - seed: the seed to use for splitting the dataset into train and test sets. - lang: the spaCy language to use for training. If `None`, the language of the dataset will be used. - """ - if isinstance(framework, str): - framework = Framework(framework) - - # validate train and test sizes - if train_size is None: - train_size = 1 - if test_size is None: - test_size = 1 - train_size - - # check if all numbers are larger than 0 - if not [abs(train_size), abs(test_size)] == [train_size, test_size]: - raise ValueError("`train_size` and `test_size` must be larger than 0.") - # check if train sizes sum up to 1 - if not (train_size + test_size) == 1: - raise ValueError("`train_size` and `test_size` must sum to 1.") - - if test_size == 0: - test_size = None - - if len(self.records) < 1: - raise ValueError( - "No records found in the dataset. Make sure you add records to the" - " dataset via the `FeedbackDataset.add_records` method first." - ) - - if isinstance(task, (TrainingTaskForTextClassification, TrainingTaskForSentenceSimilarity)): - if task.formatting_func is None: - # in sentence-transformer models we can train without labels - if task.label: - self.unify_responses(question=task.label.question, strategy=task.label.strategy) - elif isinstance(task, TrainingTaskForQuestionAnswering): - if task.formatting_func is None: - self.unify_responses(question=task.answer.name, strategy="disagreement") - elif not isinstance( - task, - ( - TrainingTaskForSFT, - TrainingTaskForRM, - TrainingTaskForPPO, - TrainingTaskForDPO, - TrainingTaskForChatCompletion, - ), - ): - raise ValueError(f"Training data {type(task)} is not supported yet") - - data = task._format_data(self) - if framework in [ - Framework.TRANSFORMERS, - Framework.SETFIT, - Framework.SPAN_MARKER, - Framework.PEFT, - ]: - return task._prepare_for_training_with_transformers( - data=data, train_size=train_size, seed=seed, framework=framework - ) - elif framework in [Framework.SPACY, Framework.SPACY_TRANSFORMERS]: - require_dependencies("spacy") - import spacy - - if lang is None: - _LOGGER.warning("spaCy `lang` is not provided. Using `en`(English) as default language.") - lang = spacy.blank("en") - elif lang.isinstance(str): - if len(lang) == 2: - lang = spacy.blank(lang) - else: - lang = spacy.load(lang) - return task._prepare_for_training_with_spacy(data=data, train_size=train_size, seed=seed, lang=lang) - elif framework is Framework.SPARK_NLP: - return task._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) - elif framework is Framework.OPENAI: - return task._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) - elif framework is Framework.TRL: - return task._prepare_for_training_with_trl(data=data, train_size=train_size, seed=seed) - elif framework is Framework.TRLX: - return task._prepare_for_training_with_trlx(data=data, train_size=train_size, seed=seed) - elif framework is Framework.SENTENCE_TRANSFORMERS: - return task._prepare_for_training_with_sentence_transformers(data=data, train_size=train_size, seed=seed) - else: - raise NotImplementedError( - f"Framework {framework} is not supported. 
Choose from: {[e.value for e in Framework]}" - ) + @abstractmethod + def filter_by(self, *args, **kwargs): + """Filters the current `FeedbackDataset`.""" + pass + + @abstractmethod + def delete(self): + """Deletes the `FeedbackDataset` from Argilla.""" + pass + + @abstractmethod + def prepare_for_training(self, *args, **kwargs) -> Any: + """Prepares the `FeedbackDataset` for training for a given framework and task.""" + pass + + @abstractmethod + def push_to_argilla(self, *args, **kwargs) -> "FeedbackDatasetBase": + """Pushes the `FeedbackDataset` to Argilla.""" + pass + + @abstractmethod + def unify_responses(self, *args, **kwargs): + """Unifies the responses for a given question.""" + pass diff --git a/src/argilla/client/feedback/dataset/local.py b/src/argilla/client/feedback/dataset/local.py deleted file mode 100644 index ce3480ae72..0000000000 --- a/src/argilla/client/feedback/dataset/local.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union - -from argilla.client.feedback.constants import FETCHING_BATCH_SIZE -from argilla.client.feedback.dataset.base import FeedbackDatasetBase -from argilla.client.feedback.dataset.mixins import ArgillaMixin, UnificationMixin -from argilla.client.feedback.schemas.fields import TextField -from argilla.client.feedback.schemas.types import AllowedQuestionTypes - -if TYPE_CHECKING: - from argilla.client.feedback.schemas.records import FeedbackRecord - from argilla.client.feedback.schemas.types import AllowedFieldTypes - - -class FeedbackDataset(FeedbackDatasetBase, ArgillaMixin, UnificationMixin): - def __init__( - self, - *, - fields: List["AllowedFieldTypes"], - questions: List[AllowedQuestionTypes], - guidelines: Optional[str] = None, - ) -> None: - """Initializes a `FeedbackDataset` instance locally. - - Args: - fields: contains the fields that will define the schema of the records in the dataset. - questions: contains the questions that will be used to annotate the dataset. - guidelines: contains the guidelines for annotating the dataset. Defaults to `None`. - - Raises: - TypeError: if `fields` is not a list of `FieldSchema`. - ValueError: if `fields` does not contain at least one required field. - TypeError: if `questions` is not a list of `TextQuestion`, `RatingQuestion`, - `LabelQuestion`, and/or `MultiLabelQuestion`. - ValueError: if `questions` does not contain at least one required question. - TypeError: if `guidelines` is not None and not a string. - ValueError: if `guidelines` is an empty string. - - Examples: - >>> import argilla as rg - >>> rg.init(api_url="...", api_key="...") - >>> dataset = rg.FeedbackDataset( - ... fields=[ - ... rg.TextField(name="text", required=True), - ... rg.TextField(name="label", required=True), - ... ], - ... questions=[ - ... rg.TextQuestion( - ... name="question-1", - ... description="This is the first question", - ... required=True, - ... ),
rg.RatingQuestion( - ... name="question-2", - ... description="This is the second question", - ... required=True, - ... values=[1, 2, 3, 4, 5], - ... ), - ... rg.LabelQuestion( - ... name="question-3", - ... description="This is the third question", - ... required=True, - ... labels=["positive", "negative"], - ... ), - ... rg.MultiLabelQuestion( - ... name="question-4", - ... description="This is the fourth question", - ... required=True, - ... labels=["category-1", "category-2", "category-3"], - ... ), - ... ], - ... guidelines="These are the annotation guidelines.", - ... ) - """ - super().__init__(fields=fields, questions=questions, guidelines=guidelines) - - self._records = [] - - @property - def records(self) -> List["FeedbackRecord"]: - """Returns the records in the dataset.""" - return self._records - - def __repr__(self) -> str: - """Returns a string representation of the dataset.""" - return f"" - - def __len__(self) -> int: - """Returns the number of records in the dataset.""" - return len(self._records) - - def __getitem__(self, key: Union[slice, int]) -> Union["FeedbackRecord", List["FeedbackRecord"]]: - """Returns the record(s) at the given index(es). - - Args: - key: the index(es) of the record(s) to return. Can either be a single index or a slice. - - Returns: - Either the record of the given index, or a list with the records at the given indexes. - """ - if len(self._records) < 1: - raise RuntimeError( - "In order to get items from `FeedbackDataset` you need to add them first" " with `add_records`." - ) - if isinstance(key, int) and len(self._records) < key: - raise IndexError(f"This dataset contains {len(self)} records, so index {key} is out of range.") - return self._records[key] - - def iter(self, batch_size: Optional[int] = FETCHING_BATCH_SIZE) -> Iterator[List["FeedbackRecord"]]: - """Returns an iterator over the records in the dataset. - - Args: - batch_size: the size of the batches to return. Defaults to 100. - """ - for i in range(0, len(self._records), batch_size): - yield self._records[i : i + batch_size] - - def add_records( - self, - records: Union["FeedbackRecord", Dict[str, Any], List[Union["FeedbackRecord", Dict[str, Any]]]], - ) -> None: - """Adds the given records to the dataset, and stores them locally. If you are - planning to push those to Argilla, you will need to call `push_to_argilla` afterwards, - to both create the dataset in Argilla and push the records to it. Then, from a - `FeedbackDataset` pushed to Argilla, you'll just need to call `add_records` and - those will be automatically uploaded to Argilla. - - Args: - records: can be a single `FeedbackRecord`, a list of `FeedbackRecord`, - a single dictionary, or a list of dictionaries. If a dictionary is provided, - it will be converted to a `FeedbackRecord` internally. - - Raises: - ValueError: if the given records are an empty list. - ValueError: if the given records are neither: `FeedbackRecord`, list of `FeedbackRecord`, - list of dictionaries as a record or dictionary as a record. - ValueError: if the given records do not match the expected schema. 
- """ - records = self._parse_records(records) - self._validate_records(records) - - if len(self._records) > 0: - self._records += records - else: - self._records = records diff --git a/src/argilla/client/feedback/dataset/local/dataset.py b/src/argilla/client/feedback/dataset/local/dataset.py new file mode 100644 index 0000000000..10fa9ad299 --- /dev/null +++ b/src/argilla/client/feedback/dataset/local/dataset.py @@ -0,0 +1,370 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import textwrap +import warnings +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union + +from argilla.client.feedback.constants import FETCHING_BATCH_SIZE +from argilla.client.feedback.dataset.base import FeedbackDatasetBase +from argilla.client.feedback.dataset.local.mixins import ArgillaMixin +from argilla.client.feedback.schemas.questions import ( + LabelQuestion, + MultiLabelQuestion, + RankingQuestion, + RatingQuestion, + TextQuestion, +) +from argilla.client.feedback.schemas.types import AllowedQuestionTypes +from argilla.client.feedback.training.schemas import ( + TrainingTaskForChatCompletion, + TrainingTaskForDPO, + TrainingTaskForPPO, + TrainingTaskForQuestionAnswering, + TrainingTaskForRM, + TrainingTaskForSentenceSimilarity, + TrainingTaskForSFT, + TrainingTaskForTextClassification, + TrainingTaskTypes, +) +from argilla.client.feedback.unification import ( + LabelQuestionStrategy, + MultiLabelQuestionStrategy, + RankingQuestionStrategy, + RatingQuestionStrategy, + TextQuestionStrategy, +) +from argilla.client.models import Framework +from argilla.utils.dependency import require_dependencies + +if TYPE_CHECKING: + from argilla.client.feedback.schemas.records import FeedbackRecord + from argilla.client.feedback.schemas.types import AllowedFieldTypes + + +_LOGGER = logging.getLogger(__name__) + + +class FeedbackDataset(ArgillaMixin, FeedbackDatasetBase): + def __init__( + self, + *, + fields: List["AllowedFieldTypes"], + questions: List[AllowedQuestionTypes], + guidelines: Optional[str] = None, + ) -> None: + """Initializes a `FeedbackDataset` instance locally. + + Args: + fields: contains the fields that will define the schema of the records in the dataset. + questions: contains the questions that will be used to annotate the dataset. + guidelines: contains the guidelines for annotating the dataset. Defaults to `None`. + + Raises: + TypeError: if `fields` is not a list of `FieldSchema`. + ValueError: if `fields` does not contain at least one required field. + TypeError: if `questions` is not a list of `TextQuestion`, `RatingQuestion`, + `LabelQuestion`, and/or `MultiLabelQuestion`. + ValueError: if `questions` does not contain at least one required question. + TypeError: if `guidelines` is not None and not a string. + ValueError: if `guidelines` is an empty string. + + Examples: + >>> import argilla as rg + >>> rg.init(api_url="...", api_key="...") + >>> dataset = rg.FeedbackDataset( + ... fields=[ + ... 
rg.TextField(name="text", required=True), + ... rg.TextField(name="label", required=True), + ... ], + ... questions=[ + ... rg.TextQuestion( + ... name="question-1", + ... description="This is the first question", + ... required=True, + ... ), + ... rg.RatingQuestion( + ... name="question-2", + ... description="This is the second question", + ... required=True, + ... values=[1, 2, 3, 4, 5], + ... ), + ... rg.LabelQuestion( + ... name="question-3", + ... description="This is the third question", + ... required=True, + ... labels=["positive", "negative"], + ... ), + ... rg.MultiLabelQuestion( + ... name="question-4", + ... description="This is the fourth question", + ... required=True, + ... labels=["category-1", "category-2", "category-3"], + ... ), + ... ], + ... guidelines="These are the annotation guidelines.", + ... ) + """ + super().__init__(fields=fields, questions=questions, guidelines=guidelines) + + self._records = [] + + @property + def records(self) -> List["FeedbackRecord"]: + """Returns the records in the dataset.""" + return self._records + + def __repr__(self) -> str: + """Returns a string representation of the dataset.""" + return ( + "FeedbackDataset(" + + textwrap.indent( + f"\nfields={self.fields}\nquestions={self.questions}\nguidelines={self.guidelines})", " " + ) + + "\n)" + ) + + def __len__(self) -> int: + """Returns the number of records in the dataset.""" + return len(self._records) + + def __getitem__(self, key: Union[slice, int]) -> Union["FeedbackRecord", List["FeedbackRecord"]]: + """Returns the record(s) at the given index(es). + + Args: + key: the index(es) of the record(s) to return. Can either be a single index or a slice. + + Returns: + Either the record of the given index, or a list with the records at the given indexes. + """ + if len(self._records) < 1: + raise RuntimeError( + "In order to get items from `FeedbackDataset` you need to add them first" " with `add_records`." + ) + if isinstance(key, int) and len(self._records) < key: + raise IndexError(f"This dataset contains {len(self)} records, so index {key} is out of range.") + return self._records[key] + + def iter(self, batch_size: Optional[int] = FETCHING_BATCH_SIZE) -> Iterator[List["FeedbackRecord"]]: + """Returns an iterator over the records in the dataset. + + Args: + batch_size: the size of the batches to return. Defaults to 100. + """ + for i in range(0, len(self._records), batch_size): + yield self._records[i : i + batch_size] + + def add_records( + self, records: Union["FeedbackRecord", Dict[str, Any], List[Union["FeedbackRecord", Dict[str, Any]]]] + ) -> None: + """Adds the given records to the dataset, and stores them locally. If you are + planning to push those to Argilla, you will need to call `push_to_argilla` afterwards, + to both create the dataset in Argilla and push the records to it. Then, from a + `FeedbackDataset` pushed to Argilla, you'll just need to call `add_records` and + those will be automatically uploaded to Argilla. + + Args: + records: can be a single `FeedbackRecord`, a list of `FeedbackRecord`, + a single dictionary, or a list of dictionaries. If a dictionary is provided, + it will be converted to a `FeedbackRecord` internally. + + Raises: + ValueError: if the given records are an empty list. + ValueError: if the given records are neither: `FeedbackRecord`, list of `FeedbackRecord`, + list of dictionaries as a record or dictionary as a record. + ValueError: if the given records do not match the expected schema. 
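+ + Examples: + A minimal usage sketch, assuming the "text" and "label" fields defined in the constructor example above: + >>> dataset.add_records( + ... rg.FeedbackRecord(fields={"text": "Hello world!", "label": "greeting"}) + ... )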
+ """ + records = self._parse_records(records) + self._validate_records(records) + + if len(self._records) > 0: + self._records += records + else: + self._records = records + + def pull(self) -> "FeedbackDataset": + warnings.warn( + "`pull` method is not supported for local datasets and won't take any effect." + "First, you need to push the dataset to Argilla with `FeedbackDataset.push_to_argilla()`." + "After, use `FeedbackDataset.from_argilla(...).pull()`.", + UserWarning, + ) + return self + + def filter_by(self, *args, **kwargs) -> "FeedbackDataset": + warnings.warn( + "`filter_by` method is not supported for local datasets and won't take any effect. " + "First, you need to push the dataset to Argilla with `FeedbackDataset.push_to_argilla()`." + "After, use `FeedbackDataset.from_argilla(...).filter_by()`.", + UserWarning, + ) + return self + + def delete(self): + warnings.warn( + "`delete` method is not supported for local datasets and won't take any effect. " + "First, you need to push the dataset to Argilla with `FeedbackDataset.push_to_argilla`." + "After, use `FeedbackDataset.from_argilla(...).delete()`", + UserWarning, + ) + return self + + def unify_responses( + self: "FeedbackDatasetBase", + question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion], + strategy: Union[ + str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy + ], + ) -> "FeedbackDataset": + """ + The `unify_responses` function takes a question and a strategy as input and applies the strategy + to unify the responses for that question. + + Args: + question The `question` parameter can be either a string representing the name of the + question, or an instance of one of the question classes (`LabelQuestion`, `MultiLabelQuestion`, + `RatingQuestion`, `RankingQuestion`). + strategy The `strategy` parameter is used to specify the strategy to be used for unifying + responses for a given question. It can be either a string or an instance of a strategy class. + """ + if isinstance(question, str): + question = self.question_by_name(question) + + if isinstance(strategy, str): + if isinstance(question, LabelQuestion): + strategy = LabelQuestionStrategy(strategy) + elif isinstance(question, MultiLabelQuestion): + strategy = MultiLabelQuestionStrategy(strategy) + elif isinstance(question, RatingQuestion): + strategy = RatingQuestionStrategy(strategy) + elif isinstance(question, RankingQuestion): + strategy = RankingQuestionStrategy(strategy) + elif isinstance(question, TextQuestion): + strategy = TextQuestionStrategy(strategy) + else: + raise ValueError(f"Question {question} is not supported yet") + + strategy.unify_responses(self.records, question) + return self + + # TODO(alvarobartt,davidberenstein1957): we should consider having something like + # `export(..., training=True)` to export the dataset records in any format, replacing + # both `format_as` and `prepare_for_training` + def prepare_for_training( + self, + framework: Union[Framework, str], + task: TrainingTaskTypes, + train_size: Optional[float] = 1, + test_size: Optional[float] = None, + seed: Optional[int] = None, + lang: Optional[str] = None, + ) -> Any: + """ + Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + + Args: + framework: the framework to use for training. 
+ """ + if isinstance(question, str): + question = self.question_by_name(question) + + if isinstance(strategy, str): + if isinstance(question, LabelQuestion): + strategy = LabelQuestionStrategy(strategy) + elif isinstance(question, MultiLabelQuestion): + strategy = MultiLabelQuestionStrategy(strategy) + elif isinstance(question, RatingQuestion): + strategy = RatingQuestionStrategy(strategy) + elif isinstance(question, RankingQuestion): + strategy = RankingQuestionStrategy(strategy) + elif isinstance(question, TextQuestion): + strategy = TextQuestionStrategy(strategy) + else: + raise ValueError(f"Question {question} is not supported yet") + + strategy.unify_responses(self.records, question) + return self + + # TODO(alvarobartt,davidberenstein1957): we should consider having something like + # `export(..., training=True)` to export the dataset records in any format, replacing + # both `format_as` and `prepare_for_training` + def prepare_for_training( + self, + framework: Union[Framework, str], + task: TrainingTaskTypes, + train_size: Optional[float] = 1, + test_size: Optional[float] = None, + seed: Optional[int] = None, + lang: Optional[str] = None, + ) -> Any: + """ + Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + + Args: + framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`, + `setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`, `sentence-transformers`. + task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`, + `TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`, `TrainingTaskForSentenceSimilarity`. + train_size: the size of the train set. If `None`, the whole dataset will be used for training. + test_size: the size of the test set. If `None`, it will default to `1 - train_size`. + seed: the seed to use for splitting the dataset into train and test sets. + lang: the spaCy language to use for training. If `None`, the language of the dataset will be used.
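+ + Examples: + A minimal sketch for text classification with the `setfit` framework (assumes the public `TrainingTask` helper and the fields and questions from the constructor example above): + >>> from argilla.feedback import TrainingTask + >>> task = TrainingTask.for_text_classification( + ... text=dataset.fields[0], label=dataset.question_by_name("question-3") + ... ) + >>> train_dataset = dataset.prepare_for_training(framework="setfit", task=task)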
+ """ + if isinstance(framework, str): + framework = Framework(framework) + + # validate train and test sizes + if train_size is None: + train_size = 1 + if test_size is None: + test_size = 1 - train_size + + # check if all numbers are larger than 0 + if not [abs(train_size), abs(test_size)] == [train_size, test_size]: + raise ValueError("`train_size` and `test_size` must be larger than 0.") + # check if train sizes sum up to 1 + if not (train_size + test_size) == 1: + raise ValueError("`train_size` and `test_size` must sum to 1.") + + if test_size == 0: + test_size = None + + if len(self.records) < 1: + raise ValueError( + "No records found in the dataset. Make sure you add records to the" + " dataset via the `FeedbackDataset.add_records()` method first." + ) + + local_dataset = self.pull() + if isinstance(task, (TrainingTaskForTextClassification, TrainingTaskForSentenceSimilarity)): + if task.formatting_func is None: + # in sentence-transformer models we can train without labels + if task.label: + local_dataset = local_dataset.unify_responses( + question=task.label.question, strategy=task.label.strategy + ) + elif isinstance(task, TrainingTaskForQuestionAnswering): + if task.formatting_func is None: + local_dataset = local_dataset.unify_responses(question=task.answer.name, strategy="disagreement") + elif not isinstance( + task, + ( + TrainingTaskForSFT, + TrainingTaskForRM, + TrainingTaskForPPO, + TrainingTaskForDPO, + TrainingTaskForChatCompletion, + ), + ): + raise ValueError(f"Training data {type(task)} is not supported yet") + + data = task._format_data(local_dataset) + if framework in [ + Framework.TRANSFORMERS, + Framework.SETFIT, + Framework.SPAN_MARKER, + Framework.PEFT, + ]: + return task._prepare_for_training_with_transformers( + data=data, train_size=train_size, seed=seed, framework=framework + ) + elif framework in [Framework.SPACY, Framework.SPACY_TRANSFORMERS]: + require_dependencies("spacy") + import spacy + + if lang is None: + _LOGGER.warning("spaCy `lang` is not provided. Using `en` (English) as default language.") + lang = spacy.blank("en") + elif isinstance(lang, str): + if len(lang) == 2: + lang = spacy.blank(lang) + else: + lang = spacy.load(lang) + return task._prepare_for_training_with_spacy(data=data, train_size=train_size, seed=seed, lang=lang) + elif framework is Framework.SPARK_NLP: + return task._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed) + elif framework is Framework.OPENAI: + return task._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed) + elif framework is Framework.TRL: + return task._prepare_for_training_with_trl(data=data, train_size=train_size, seed=seed) + elif framework is Framework.TRLX: + return task._prepare_for_training_with_trlx(data=data, train_size=train_size, seed=seed) + elif framework is Framework.SENTENCE_TRANSFORMERS: + return task._prepare_for_training_with_sentence_transformers(data=data, train_size=train_size, seed=seed) + else: + raise NotImplementedError( + f"Framework {framework} is not supported. Choose from: {[e.value for e in Framework]}" + ) diff --git a/src/argilla/client/feedback/dataset/mixins.py b/src/argilla/client/feedback/dataset/local/mixins.py similarity index 85% rename from src/argilla/client/feedback/dataset/mixins.py rename to src/argilla/client/feedback/dataset/local/mixins.py index 4994795086..6788e57785 100644 --- a/src/argilla/client/feedback/dataset/mixins.py +++ b/src/argilla/client/feedback/dataset/local/mixins.py @@ -15,19 +15,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union from uuid import UUID -from tqdm import trange - from argilla.client.api import ArgillaSingleton from argilla.client.feedback.constants import PUSHING_BATCH_SIZE from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset from argilla.client.feedback.schemas.enums import FieldTypes, QuestionTypes -from argilla.client.feedback.schemas.questions import ( - LabelQuestion, - MultiLabelQuestion, - RankingQuestion, - RatingQuestion, - TextQuestion, -) from argilla.client.feedback.schemas.remote.fields import RemoteTextField from argilla.client.feedback.schemas.remote.questions import ( RemoteLabelQuestion, @@ -36,28 +27,22 @@ RemoteRatingQuestion, RemoteTextQuestion, ) -from argilla.client.feedback.unification import ( - LabelQuestionStrategy, - MultiLabelQuestionStrategy, - RankingQuestionStrategy, - RatingQuestionStrategy, - TextQuestionStrategy, -) from argilla.client.feedback.utils import feedback_dataset_in_argilla from argilla.client.sdk.v1.datasets import api as datasets_api_v1 from argilla.client.workspaces import Workspace +from tqdm import trange if TYPE_CHECKING: import httpx - from argilla.client.client import Argilla as ArgillaClient - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes from argilla.client.sdk.v1.datasets.models import FeedbackDatasetModel class ArgillaMixin: - def __delete_dataset(self: "FeedbackDataset", client: "httpx.Client", id: UUID) -> None: + @staticmethod + def __delete_dataset(client: "httpx.Client", id: Union[str, UUID]) -> None: try: datasets_api_v1.delete_dataset(client=client, id=id) except Exception as e: @@ -342,42 +327,3 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[ ) for dataset in datasets ] - - -class UnificationMixin: - def unify_responses(
self, - question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion], - strategy: Union[ - str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy - ], - ) -> None: - """ - The `unify_responses` function takes a question and a strategy as input and applies the strategy - to unify the responses for that question. - - Args: - question The `question` parameter can be either a string representing the name of the - question, or an instance of one of the question classes (`LabelQuestion`, `MultiLabelQuestion`, - `RatingQuestion`, `RankingQuestion`). - strategy The `strategy` parameter is used to specify the strategy to be used for unifying - responses for a given question. It can be either a string or an instance of a strategy class. - """ - if isinstance(question, str): - question = self.question_by_name(question) - - if isinstance(strategy, str): - if isinstance(question, LabelQuestion): - strategy = LabelQuestionStrategy(strategy) - elif isinstance(question, MultiLabelQuestion): - strategy = MultiLabelQuestionStrategy(strategy) - elif isinstance(question, RatingQuestion): - strategy = RatingQuestionStrategy(strategy) - elif isinstance(question, RankingQuestion): - strategy = RankingQuestionStrategy(strategy) - elif isinstance(question, TextQuestion): - strategy = TextQuestionStrategy(strategy) - else: - raise ValueError(f"Question {question} is not supported yet") - - strategy.unify_responses(self.records, question) diff --git a/src/argilla/client/feedback/dataset/remote/base.py b/src/argilla/client/feedback/dataset/remote/base.py index 4098b4c7cf..9830c41814 100644 --- a/src/argilla/client/feedback/dataset/remote/base.py +++ b/src/argilla/client/feedback/dataset/remote/base.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import textwrap import warnings from abc import ABC, abstractmethod from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, Generic, Iterator, List, Optional, Type, TypeVar, Union +from argilla import Workspace from argilla.client.feedback.dataset.base import FeedbackDatasetBase from argilla.client.feedback.dataset.remote.mixins import ArgillaRecordsMixin from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord @@ -28,7 +30,7 @@ import httpx - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.schemas.records import FeedbackRecord from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel @@ -91,7 +93,7 @@ def delete(self) -> None: pass -class RemoteFeedbackDatasetBase(Generic[T], FeedbackDatasetBase): +class RemoteFeedbackDatasetBase(FeedbackDatasetBase, Generic[T]): records_cls: Type[T] def __init__( @@ -186,10 +188,17 @@ def updated_at(self) -> datetime: def __repr__(self) -> str: """Returns a string representation of the dataset.""" + indent = " " return ( - f"" + "RemoteFeedbackDataset(" + + textwrap.indent(f"\nid={self.id}", indent) + + textwrap.indent(f"\nname={self.name}", indent) + + textwrap.indent(f"\nworkspace={self.workspace}", indent) + + textwrap.indent(f"\nurl={self.url}", indent) + + textwrap.indent(f"\nfields={self.fields}", indent) + + textwrap.indent(f"\nquestions={self.questions}", indent) + + textwrap.indent(f"\nguidelines={self.guidelines}", indent) + + ")" ) def __len__(self) -> int: @@ -252,7 +261,7 @@ def pull(self) -> "FeedbackDataset": A local instance of the dataset which is a `FeedbackDataset` object. """ # Importing here to avoid circular imports - from argilla.client.feedback.dataset.local import FeedbackDataset + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset instance = FeedbackDataset( fields=self.fields, @@ -263,3 +272,11 @@ def pull(self) -> "FeedbackDataset": records=[record.to_local() for record in self._records], ) return instance + + def push_to_argilla( + self, name: str, workspace: Optional[Union[str, "Workspace"]] = None, show_progress: bool = False + ) -> "RemoteFeedbackDatasetBase": + warnings.warn( + "Already pushed datasets cannot be pushed to Argilla again because they are synced automatically." + ) + return self diff --git a/src/argilla/client/feedback/dataset/remote/dataset.py b/src/argilla/client/feedback/dataset/remote/dataset.py index d6b2eac040..4d1a7d69b3 100644 --- a/src/argilla/client/feedback/dataset/remote/dataset.py +++ b/src/argilla/client/feedback/dataset/remote/dataset.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -20,8 +21,23 @@ from argilla.client.feedback.constants import DELETE_DATASET_RECORDS_MAX_NUMBER, PUSHING_BATCH_SIZE from argilla.client.feedback.dataset.remote.base import RemoteFeedbackDatasetBase, RemoteFeedbackRecordsBase from argilla.client.feedback.dataset.remote.filtered import FilteredRemoteFeedbackDataset +from argilla.client.feedback.schemas.questions import ( + LabelQuestion, + MultiLabelQuestion, + RatingQuestion, +) from argilla.client.feedback.schemas.records import FeedbackRecord from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord +from argilla.client.feedback.training.schemas import ( + TrainingTaskTypes, +) +from argilla.client.feedback.unification import ( + LabelQuestionStrategy, + MultiLabelQuestionStrategy, + RankingQuestionStrategy, + RatingQuestionStrategy, +) +from argilla.client.models import Framework from argilla.client.sdk.users.models import UserRole from argilla.client.sdk.v1.datasets import api as datasets_api_v1 from argilla.client.utils import allowed_for_roles @@ -31,6 +47,7 @@ import httpx + from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.schemas.enums import ResponseStatusFilter from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel @@ -188,3 +205,70 @@ def delete(self) -> None: datasets_api_v1.delete_dataset(client=self._client, id=self.id) except Exception as e: raise RuntimeError(f"Failed while deleting the `FeedbackDataset` from Argilla with exception: {e}") from e + + def unify_responses( + self, + question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion], + strategy: Union[ + str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy + ], + ) -> "FeedbackDataset": + """ + Takes a question and a strategy as input, pulls a local copy of the dataset, and + applies the strategy to unify the responses for that question. + + Args: + question: either a string representing the name of the question, or an instance + of one of the question classes (`LabelQuestion`, `MultiLabelQuestion`, + `RatingQuestion`, `RankingQuestion`). + strategy: the strategy used to unify the responses for the given question, either + as a string or as an instance of a strategy class.
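+ + Examples: + A minimal sketch, assuming a hypothetical `LabelQuestion` named "sentiment" in the remote dataset; note that the returned value is a local `FeedbackDataset`: + >>> local_dataset = remote_dataset.unify_responses(question="sentiment", strategy="majority")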
+ """ + warnings.warn( + "A local `FeedbackDataset` is returned because " + "`unify_responses` is not supported for `RemoteFeedbackDataset`. " + "`RemoteFeedbackDataset.pull().unify_responses(*args, **kwargs)` is applied instead.", + UserWarning, + ) + local = self.pull() + return local.unify_responses(question=question, strategy=strategy) + + def prepare_for_training( + self, + framework: Union[Framework, str], + task: TrainingTaskTypes, + train_size: Optional[float] = 1, + test_size: Optional[float] = None, + seed: Optional[int] = None, + lang: Optional[str] = None, + ) -> Any: + """ + Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets. + + Args: + framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`, + `setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`, `sentence-transformers`. + task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`, + `TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`, `TrainingTaskForSentenceSimilarity`. + train_size: the size of the train set. If `None`, the whole dataset will be used for training. + test_size: the size of the test set. If `None`, it will default to `1 - train_size`. + seed: the seed to use for splitting the dataset into train and test sets. + lang: the spaCy language to use for training. If `None`, the language of the dataset will be used.
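+ + Examples: + A minimal sketch mirroring the integration tests below (the framework and question name are illustrative); the dataset is pulled and prepared locally: + >>> task = TrainingTask.for_text_classification( + ... text=remote_dataset.fields[0], label=remote_dataset.question_by_name("sentiment") + ... ) + >>> data = remote_dataset.prepare_for_training(framework="setfit", task=task)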
"FilteredRemoteFeedbackDataset": + if not isinstance(response_status, list): + response_status = [response_status] + + return self.__class__( + client=self._client, + id=self.id, + name=self.name, + workspace=self.workspace, + created_at=self.created_at, + updated_at=self.updated_at, + fields=self.fields, + questions=self.questions, + guidelines=self.guidelines, + filters={ + "response_status": [status.value if hasattr(status, "value") else status for status in response_status] + }, + ) diff --git a/src/argilla/client/feedback/schemas/remote/shared.py b/src/argilla/client/feedback/schemas/remote/shared.py index ce3ecb952a..cbb0933f91 100644 --- a/src/argilla/client/feedback/schemas/remote/shared.py +++ b/src/argilla/client/feedback/schemas/remote/shared.py @@ -42,7 +42,7 @@ def to_local(self) -> BaseModel: @classmethod @abstractmethod - def from_api(cls) -> Type["RemoteSchema"]: + def from_api(cls, payload: "BaseModel") -> Type["RemoteSchema"]: """Abstract method to be implemented by subclasses to convert the API payload into a remote schema.""" raise NotImplementedError diff --git a/tests/integration/client/feedback/dataset/test_dataset.py b/tests/integration/client/feedback/dataset/local/test_dataset.py similarity index 94% rename from tests/integration/client/feedback/dataset/test_dataset.py rename to tests/integration/client/feedback/dataset/local/test_dataset.py index 980c2a6dbe..6348b3c145 100644 --- a/tests/integration/client/feedback/dataset/test_dataset.py +++ b/tests/integration/client/feedback/dataset/local/test_dataset.py @@ -17,6 +17,7 @@ import datasets import pytest +from argilla import Workspace from argilla.client import api from argilla.client.feedback.config import DatasetConfig from argilla.client.feedback.dataset import FeedbackDataset @@ -654,6 +655,7 @@ def test_push_to_huggingface_and_from_huggingface( "feedback_dataset_records", ) def test_prepare_for_training_text_classification( + owner: "ServerUser", framework: Union[Framework, str], question: str, feedback_dataset_guidelines: str, @@ -667,7 +669,48 @@ def test_prepare_for_training_text_classification( questions=feedback_dataset_questions, ) dataset.add_records(feedback_dataset_records) - label = dataset.question_by_name(question) + + api.init(api_key=owner.api_key) + ws = Workspace.create(name="test-workspace") + + remote = dataset.push_to_argilla(name="test-dataset", workspace=ws) + + label = remote.question_by_name(question) task = TrainingTask.for_text_classification(text=dataset.fields[0], label=label) - dataset.prepare_for_training(framework=framework, task=task) + data = remote.prepare_for_training(framework=framework, task=task) + assert data is not None + + +@pytest.mark.usefixtures( + "feedback_dataset_guidelines", + "feedback_dataset_fields", + "feedback_dataset_questions", + "feedback_dataset_records", +) +def test_warning_remote_dataset_methods( + feedback_dataset_guidelines: str, + feedback_dataset_fields: List["AllowedFieldTypes"], + feedback_dataset_questions: List["AllowedQuestionTypes"], + feedback_dataset_records: List[FeedbackRecord], +): + dataset = FeedbackDataset( + guidelines=feedback_dataset_guidelines, + fields=feedback_dataset_fields, + questions=feedback_dataset_questions, + ) + + with pytest.warns( + UserWarning, match="`pull` method is not supported for local datasets and won't take any effect." + ): + dataset.pull() + + with pytest.warns( + UserWarning, match="`filter_by` method is not supported for local datasets and won't take any effect." 
+ ): + dataset.filter_by() + + with pytest.warns( + UserWarning, match="`delete` method is not supported for local datasets and won't have any effect." + ): + dataset.delete() diff --git a/tests/integration/client/feedback/dataset/remote/test_dataset.py b/tests/integration/client/feedback/dataset/remote/test_dataset.py index 8b53b8207c..181d92ffee 100644 --- a/tests/integration/client/feedback/dataset/remote/test_dataset.py +++ b/tests/integration/client/feedback/dataset/remote/test_dataset.py @@ -147,3 +147,28 @@ async def test_attributes(self, role: UserRole) -> None: assert isinstance(remote_dataset.url, str) assert isinstance(remote_dataset.created_at, datetime) assert isinstance(remote_dataset.updated_at, datetime) + + @pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin]) + async def test_warning_local_methods(self, role: UserRole) -> None: + dataset = await DatasetFactory.create() + await TextFieldFactory.create(dataset=dataset, required=True) + await TextQuestionFactory.create(dataset=dataset, required=True) + await RecordFactory.create_batch(dataset=dataset, size=10) + user = await UserFactory.create(role=role, workspaces=[dataset.workspace]) + + api.init(api_key=user.api_key) + ds = FeedbackDataset.from_argilla(id=dataset.id) + + with pytest.raises(ValueError, match="`FeedbackRecord.fields` does not match the expected schema"): + with pytest.warns( + UserWarning, + match="A local `FeedbackDataset` is returned because `unify_responses` is not supported for `RemoteFeedbackDataset`. ", + ): + ds.unify_responses(question=None, strategy=None) + + with pytest.raises(ValueError, match="`FeedbackRecord.fields` does not match the expected schema"): + with pytest.warns( + UserWarning, + match="A local `FeedbackDataset` is returned because `prepare_for_training` is not supported for `RemoteFeedbackDataset`. ", + ): + ds.prepare_for_training(framework=None, task=None) diff --git a/tests/integration/client/feedback/dataset/remote/test_filtered.py b/tests/integration/client/feedback/dataset/remote/test_filtered.py index b109a97acb..4e798e8dcf 100644 --- a/tests/integration/client/feedback/dataset/remote/test_filtered.py +++ b/tests/integration/client/feedback/dataset/remote/test_filtered.py @@ -18,7 +18,7 @@ import pytest from argilla.client import api -from argilla.client.feedback.dataset.local import FeedbackDataset +from argilla.client.feedback.dataset.local.dataset import FeedbackDataset from argilla.client.feedback.dataset.remote.filtered import FilteredRemoteFeedbackDataset, FilteredRemoteFeedbackRecords from argilla.client.feedback.schemas.enums import ResponseStatusFilter from argilla.client.feedback.schemas.records import FeedbackRecord @@ -90,19 +90,27 @@ async def test_not_implemented_methods( filtered_dataset = remote_dataset.filter_by(response_status=status) assert isinstance(filtered_dataset, FilteredRemoteFeedbackDataset) - with pytest.raises(NotImplementedError, match="`records.delete` does not work for filtered datasets."): + with pytest.raises( + NotImplementedError, match="`records.delete` does not work for `FilteredRemoteFeedbackDataset`."
+ ): filtered_dataset.records.delete(remote_dataset.records[0]) - with pytest.raises(NotImplementedError, match="`records.add` does not work for filtered datasets."): + with pytest.raises( + NotImplementedError, match="`records.add` does not work for `FilteredRemoteFeedbackDataset`." + ): filtered_dataset.add_records(FeedbackRecord(fields={text_field.name: "test"})) - with pytest.raises(NotImplementedError, match="`records.add` does not work for filtered datasets."): + with pytest.raises( + NotImplementedError, match="`records.add` does not work for `FilteredRemoteFeedbackDataset`." + ): filtered_dataset.records.add(FeedbackRecord(fields={text_field.name: "test"})) - with pytest.raises(NotImplementedError, match="`delete` does not work for filtered datasets."): + with pytest.raises(NotImplementedError, match="`delete` does not work for `FilteredRemoteFeedbackDataset`."): filtered_dataset.delete() @pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin]) diff --git a/tests/unit/cli/datasets/test_delete.py b/tests/unit/cli/datasets/test_delete.py index 0b9b819f87..25b58c367e 100644 --- a/tests/unit/cli/datasets/test_delete.py +++ b/tests/unit/cli/datasets/test_delete.py @@ -33,7 +33,7 @@ def test_delete_dataset( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: dataset_from_argilla_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) remote_feedback_dataset_delete_mock = mocker.patch( @@ -55,7 +55,7 @@ def test_delete_dataset_runtime_error( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: dataset_from_argilla_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) remote_feedback_dataset_delete_mock = mocker.patch( diff --git a/tests/unit/cli/datasets/test_list.py b/tests/unit/cli/datasets/test_list.py index 9c0ba99d9d..c2b33680aa 100644 --- a/tests/unit/cli/datasets/test_list.py +++ b/tests/unit/cli/datasets/test_list.py @@ -38,7 +38,7 @@ def test_list_datasets( ) -> None: add_row_spy = mocker.spy(Table, "add_row") feedback_dataset_list_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.list", return_value=[remote_feedback_dataset] + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list", return_value=[remote_feedback_dataset] ) list_datasets_mock = mocker.patch("argilla.client.api.list_datasets", return_value=[dataset]) @@ -74,7 +74,7 @@ def test_list_datasets( def test_list_datasets_with_workspace(self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture") -> None: workspace_from_name_mock = mocker.patch("argilla.client.workspaces.Workspace.from_name") - feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list") list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") result = cli_runner.invoke(cli, "datasets list --workspace unit-test") @@ -98,7 +98,7 @@ def test_list_datasets_with_non_existing_workspace( def test_list_datasets_using_type_feedback_filter( self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture" ) -> None: - feedback_dataset_list_mock = 
mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list") list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") result = cli_runner.invoke(cli, "datasets list --type feedback") @@ -110,7 +110,7 @@ def test_list_datasets_using_type_feedback_filter( def test_list_datasets_using_type_other_filter( self, cli_runner: "CliRunner", cli: "Typer", mocker: "MockerFixture" ) -> None: - feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.FeedbackDataset.list") + feedback_dataset_list_mock = mocker.patch("argilla.client.feedback.dataset.local.dataset.FeedbackDataset.list") list_datasets_mock = mocker.patch("argilla.client.api.list_datasets") result = cli_runner.invoke(cli, "datasets list --type other") diff --git a/tests/unit/cli/datasets/test_push.py b/tests/unit/cli/datasets/test_push.py index 30179b5924..02fdfbe173 100644 --- a/tests/unit/cli/datasets/test_push.py +++ b/tests/unit/cli/datasets/test_push.py @@ -33,7 +33,7 @@ def test_push_to_huggingface( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: dataset_from_argilla_mock = mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) push_to_huggingface_mock = mocker.patch( @@ -58,7 +58,7 @@ def test_push_to_huggingface_missing_repo_id_arg( remote_feedback_dataset: "RemoteFeedbackDataset", ) -> None: mocker.patch( - "argilla.client.feedback.dataset.local.FeedbackDataset.from_argilla", + "argilla.client.feedback.dataset.local.dataset.FeedbackDataset.from_argilla", return_value=remote_feedback_dataset, ) diff --git a/tests/unit/client/feedback/dataset/test_base.py b/tests/unit/client/feedback/dataset/test_base.py index 97ad8abcdb..401bd848c1 100644 --- a/tests/unit/client/feedback/dataset/test_base.py +++ b/tests/unit/client/feedback/dataset/test_base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Any, List import pytest from argilla.client.feedback.dataset.base import FeedbackDatasetBase @@ -27,10 +27,31 @@ class FeedbackDataset(FeedbackDatasetBase): - @property - def records(self) -> None: + def add_records(self, *args, **kwargs) -> None: + pass + + def pull(self): + return self + + def filter_by(self, *args, **kwargs): + return self + + def delete(self): pass + def prepare_for_training(self, *args, **kwargs) -> Any: + return [] + + def push_to_argilla(self, *args, **kwargs) -> "FeedbackDatasetBase": + return self + + def unify_responses(self, *args, **kwargs): + return self + + @property + def records(self) -> List[Any]: + return [] + def test_init( feedback_dataset_guidelines: str,