feat: Some feedback dataset improvements (argilla-io#3937)
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

The main improvements that this PR brings are:

1. **Feedback dataset class method unification**: all expected methods are
defined in the base class, so they are guaranteed to be available on every
dataset implementation. For the cases where a method makes less sense or is
not implemented yet, the user is notified with a warning instead
(@davidberenstein1957 review and improve, please). A schematic sketch of
this pattern is included after the example below.
2. **More general workflow with response unification**: the unification
workflow now supports datasets connected to Argilla, which means that
`prepare_for_training` can be applied to remote datasets. `unify_responses`
returns a dataset in which the responses are unified; as a general practice,
returning new data is preferable to modifying values in place, since it
avoids surprising side effects. The unification workflow should therefore
look like this:
```python
from argilla import MultiLabelQuestionStrategy, FeedbackDataset

dataset = FeedbackDataset.from_argilla(name="my-dataset")
strategy = MultiLabelQuestionStrategy("majority")  # other strategies: "disagreement", "majority_weighted" (WIP)

unified_dataset = dataset.unify_responses(
    question=dataset.question_by_name("tags"),
    strategy=strategy,
)

unified_dataset  # ... continue working with the unified dataset
```
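
Since the unification workflow now talks to Argilla, `prepare_for_training`
can run on the same remote dataset. Below is a minimal sketch of that flow
for text classification; the `"text"` field and `"tags"` question names are
illustrative:

```python
from argilla import FeedbackDataset
from argilla.feedback import TrainingTask

dataset = FeedbackDataset.from_argilla(name="my-dataset")

# The task unifies the responses of the label question internally,
# using the question's default unification strategy.
task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),      # illustrative field name
    label=dataset.question_by_name("tags"),  # illustrative question name
)

train_data = dataset.prepare_for_training(framework="setfit", task=task)
```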
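
The method unification in point 1, schematically: the base class declares
every shared method as abstract, and an implementation that cannot honor one
of them warns instead of failing. This is a simplified sketch, not the
actual class bodies:

```python
import warnings
from abc import ABC, abstractmethod


class DatasetBaseSketch(ABC):
    """Schematic stand-in for `FeedbackDatasetBase`."""

    @abstractmethod
    def pull(self):
        """Pulls the dataset from Argilla and returns a local instance of it."""

    @abstractmethod
    def delete(self):
        """Deletes the dataset from Argilla."""


class LocalDatasetSketch(DatasetBaseSketch):
    """Schematic stand-in for the local `FeedbackDataset`."""

    def pull(self):
        # Nothing to pull for a purely local dataset: warn and return as-is.
        warnings.warn("`pull` has no effect on a local dataset.")
        return self

    def delete(self):
        # Deletion only applies to datasets that live on an Argilla server.
        warnings.warn("`delete` is only available for datasets pushed to Argilla.")
```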

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [X] Refactor (change restructuring the codebase without changing
functionality)
- [X] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

Some of these flows have been tested locally.

**Checklist**

- [ ] I added relevant documentation
- [x] I followed the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: davidberenstein1957 <[email protected]>
3 people authored Oct 16, 2023
1 parent 37b7074 commit cc4cfdf
Showing 20 changed files with 673 additions and 377 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,8 @@ These are the section headers that we use:
- Updated `Dockerfile` to use multi stage build ([#3221](https://github.com/argilla-io/argilla/pull/3221) and [#3793](https://github.com/argilla-io/argilla/pull/3793)).
- Updated active learning for text classification notebooks to use the most recent small-text version ([#3831](https://github.com/argilla-io/argilla/pull/3831)).
- Changed argilla dataset name in the active learning for text classification notebooks to be consistent with the default names in the huggingface spaces ([#3831](https://github.com/argilla-io/argilla/pull/3831)).
- Aligned the `FeedbackDataset` API methods so that they are accessible across the different dataset implementations ([#3937](https://github.com/argilla-io/argilla/pull/3937)).
- Added `unify_responses` support for remote datasets ([#3937](https://github.com/argilla-io/argilla/pull/3937)).

### Fixed

2 changes: 1 addition & 1 deletion src/argilla/cli/datasets/__main__.py
@@ -32,7 +32,7 @@ def callback(
init_callback()

from argilla.cli.rich import echo_in_panel
from argilla.client.feedback.dataset.local import FeedbackDataset
from argilla.client.feedback.dataset.local.dataset import FeedbackDataset

if ctx.invoked_subcommand not in _COMMANDS_REQUIRING_DATASET:
return
2 changes: 1 addition & 1 deletion src/argilla/cli/datasets/list.py
@@ -31,7 +31,7 @@ def list_datasets(

from argilla.cli.rich import echo_in_panel, get_argilla_themed_table
from argilla.client.api import list_datasets as list_datasets_api
from argilla.client.feedback.dataset.local import FeedbackDataset
from argilla.client.feedback.dataset.local.dataset import FeedbackDataset
from argilla.client.workspaces import Workspace

console = Console()
2 changes: 1 addition & 1 deletion src/argilla/cli/datasets/push.py
@@ -29,7 +29,7 @@ def push_to_huggingface(
from rich.spinner import Spinner

from argilla.cli.rich import echo_in_panel
from argilla.client.feedback.dataset.local import FeedbackDataset
from argilla.client.feedback.dataset.local.dataset import FeedbackDataset

dataset: "FeedbackDataset" = ctx.obj

2 changes: 1 addition & 1 deletion src/argilla/client/feedback/dataset/__init__.py
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla.client.feedback.dataset.local import FeedbackDataset
from argilla.client.feedback.dataset.local.dataset import FeedbackDataset

__all__ = ["FeedbackDataset"]
164 changes: 37 additions & 127 deletions src/argilla/client/feedback/dataset/base.py
@@ -12,32 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from abc import ABC, abstractproperty
from abc import ABC, ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

from pydantic import ValidationError

from argilla.client.feedback.integrations.huggingface import HuggingFaceDatasetMixin
from argilla.client.feedback.schemas import (
FeedbackRecord,
FieldSchema,
)
from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes
from argilla.client.feedback.training.schemas import (
TrainingTaskForChatCompletion,
TrainingTaskForDPO,
TrainingTaskForPPO,
TrainingTaskForQuestionAnswering,
TrainingTaskForRM,
TrainingTaskForSentenceSimilarity,
TrainingTaskForSFT,
TrainingTaskForTextClassification,
TrainingTaskTypes,
)
from argilla.client.feedback.utils import generate_pydantic_schema
from argilla.client.models import Framework
from argilla.utils.dependency import require_dependencies, requires_dependencies
from argilla.utils.dependency import requires_dependencies

if TYPE_CHECKING:
from datasets import Dataset
@@ -48,10 +34,7 @@
)


_LOGGER = logging.getLogger(__name__)


class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin):
class FeedbackDatasetBase(ABC, HuggingFaceDatasetMixin, metaclass=ABCMeta):
"""Base class with shared functionality for `FeedbackDataset` and `RemoteFeedbackDataset`."""

def __init__(
@@ -135,7 +118,7 @@ def __init__(
self._guidelines = guidelines

@property
@abstractproperty
@abstractmethod
def records(self) -> Any:
"""Returns the records of the dataset."""
pass
@@ -189,6 +172,11 @@ def question_by_name(self, name: str) -> Union[AllowedQuestionTypes, "AllowedRem
f" {', '.join(q.name for q in self._questions)}"
)

@abstractmethod
def add_records(self, *args, **kwargs) -> None:
"""Adds the given records to the `FeedbackDataset`."""
pass

def _parse_records(
self, records: Union[FeedbackRecord, Dict[str, Any], List[Union[FeedbackRecord, Dict[str, Any]]]]
) -> List[FeedbackRecord]:
@@ -275,110 +263,32 @@ def format_as(self, format: Literal["datasets"]) -> "Dataset":
return self._huggingface_format(self)
raise ValueError(f"Unsupported format '{format}'.")

# TODO(alvarobartt,davidberenstein1957): we should consider having something like
# `export(..., training=True)` to export the dataset records in any format, replacing
# both `format_as` and `prepare_for_training`
def prepare_for_training(
self,
framework: Union[Framework, str],
task: TrainingTaskTypes,
train_size: Optional[float] = 1,
test_size: Optional[float] = None,
seed: Optional[int] = None,
lang: Optional[str] = None,
) -> Any:
"""
Prepares the dataset for training for a specific training framework and NLP task by splitting the dataset into train and test sets.
@abstractmethod
def pull(self):
"""Pulls the dataset from Argilla and returns a local instance of it."""
pass

Args:
framework: the framework to use for training. Currently supported frameworks are: `transformers`, `peft`,
`setfit`, `spacy`, `spacy-transformers`, `span_marker`, `spark-nlp`, `openai`, `trl`, `sentence-transformers`.
task: the NLP task to use for training. Currently supported tasks are: `TrainingTaskForTextClassification`,
`TrainingTaskForSFT`, `TrainingTaskForRM`, `TrainingTaskForPPO`, `TrainingTaskForDPO`, `TrainingTaskForSentenceSimilarity`.
train_size: the size of the train set. If `None`, the whole dataset will be used for training.
test_size: the size of the test set. If `None`, the whole dataset will be used for testing.
seed: the seed to use for splitting the dataset into train and test sets.
lang: the spaCy language to use for training. If `None`, the language of the dataset will be used.
"""
if isinstance(framework, str):
framework = Framework(framework)

# validate train and test sizes
if train_size is None:
train_size = 1
if test_size is None:
test_size = 1 - train_size

# check if all numbers are larger than 0
if not [abs(train_size), abs(test_size)] == [train_size, test_size]:
raise ValueError("`train_size` and `test_size` must be larger than 0.")
# check if train sizes sum up to 1
if not (train_size + test_size) == 1:
raise ValueError("`train_size` and `test_size` must sum to 1.")

if test_size == 0:
test_size = None

if len(self.records) < 1:
raise ValueError(
"No records found in the dataset. Make sure you add records to the"
" dataset via the `FeedbackDataset.add_records` method first."
)

if isinstance(task, (TrainingTaskForTextClassification, TrainingTaskForSentenceSimilarity)):
if task.formatting_func is None:
# in sentence-transformer models we can train without labels
if task.label:
self.unify_responses(question=task.label.question, strategy=task.label.strategy)
elif isinstance(task, TrainingTaskForQuestionAnswering):
if task.formatting_func is None:
self.unify_responses(question=task.answer.name, strategy="disagreement")
elif not isinstance(
task,
(
TrainingTaskForSFT,
TrainingTaskForRM,
TrainingTaskForPPO,
TrainingTaskForDPO,
TrainingTaskForChatCompletion,
),
):
raise ValueError(f"Training data {type(task)} is not supported yet")

data = task._format_data(self)
if framework in [
Framework.TRANSFORMERS,
Framework.SETFIT,
Framework.SPAN_MARKER,
Framework.PEFT,
]:
return task._prepare_for_training_with_transformers(
data=data, train_size=train_size, seed=seed, framework=framework
)
elif framework in [Framework.SPACY, Framework.SPACY_TRANSFORMERS]:
require_dependencies("spacy")
import spacy

if lang is None:
_LOGGER.warning("spaCy `lang` is not provided. Using `en`(English) as default language.")
lang = spacy.blank("en")
elif lang.isinstance(str):
if len(lang) == 2:
lang = spacy.blank(lang)
else:
lang = spacy.load(lang)
return task._prepare_for_training_with_spacy(data=data, train_size=train_size, seed=seed, lang=lang)
elif framework is Framework.SPARK_NLP:
return task._prepare_for_training_with_spark_nlp(data=data, train_size=train_size, seed=seed)
elif framework is Framework.OPENAI:
return task._prepare_for_training_with_openai(data=data, train_size=train_size, seed=seed)
elif framework is Framework.TRL:
return task._prepare_for_training_with_trl(data=data, train_size=train_size, seed=seed)
elif framework is Framework.TRLX:
return task._prepare_for_training_with_trlx(data=data, train_size=train_size, seed=seed)
elif framework is Framework.SENTENCE_TRANSFORMERS:
return task._prepare_for_training_with_sentence_transformers(data=data, train_size=train_size, seed=seed)
else:
raise NotImplementedError(
f"Framework {framework} is not supported. Choose from: {[e.value for e in Framework]}"
)
@abstractmethod
def filter_by(self, *args, **kwargs):
"""Filters the current `FeedbackDataset`."""
pass

@abstractmethod
def delete(self):
"""Deletes the `FeedbackDataset` from Argilla."""
pass

@abstractmethod
def prepare_for_training(self, *args, **kwargs) -> Any:
"""Prepares the `FeedbackDataset` for training by creating the training."""
pass

@abstractmethod
def push_to_argilla(self, *args, **kwargs) -> "FeedbackDatasetBase":
"""Pushes the `FeedbackDataset` to Argilla."""
pass

@abstractmethod
def unify_responses(self, *args, **kwargs):
"""Unifies the responses for a given question."""
pass
