Skip to content

Commit

Permalink
fix: configure_dataset_settings to use workspace arg (argilla-io#…
Browse files Browse the repository at this point in the history
…3887)

# Description

This PR fixes a bug where the `workspace` arg was ignored when provided
to the `configure_dataset` and/or `configure_dataset_settings` methods;
the workspace was only taken into consideration when set by calling
`set_workspace` or when passed as an arg to the `init` method.

So now one can just provide the `workspace` arg, no matter what the
default value is.

Closes argilla-io#3505

**Type of change**

- [X] Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**

- [x] Add integration tests for `configure_dataset_settings`
- [x] Add missing `DeprecationWarning` test for `configure_dataset`
(deprecated in favour of `configure_dataset_settings`)

**Checklist**

- [x] I followed the style guidelines of this project
- [x] I did a self-review of my code
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: Paco Aranda <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 17, 2023
1 parent cc4cfdf commit b4ee857
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 19 deletions.
4 changes: 1 addition & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@ These are the section headers that we use:

- Updated active learning for text classification notebooks to pass ids of type int to `TextClassificationRecord` ([#3831](https://github.com/argilla-io/argilla/pull/3831)).
- Fixed record fields validation that was preventing records with optional fields (i.e. `required=False`) from being logged when the field value was `None` ([#3846](https://github.com/argilla-io/argilla/pull/3846)).
- Fixed `configure_dataset_settings` when providing the workspace via the arg `workspace` ([#3887](https://github.com/argilla-io/argilla/pull/3887)).
- The `inserted_at` and `updated_at` attributes are now created using the `utcnow` factory to avoid unexpected race conditions on timestamp creation ([#3945](https://github.com/argilla-io/argilla/pull/3945)).

### Fixed

- Fixed saving of models trained with `ArgillaTrainer` with a `peft_config` parameter ([#3795](https://github.com/argilla-io/argilla/pull/3795)).
- Fixed backwards compatibility on `from_huggingface` when loading a `FeedbackDataset` from the Hugging Face Hub that was previously dumped using another version of Argilla, starting at 1.8.0, when it was first introduced ([#3829](https://github.com/argilla-io/argilla/pull/3829)).

Expand Down
18 changes: 11 additions & 7 deletions src/argilla/client/apis/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ class _DatasetApiModel(BaseModel):
class _SettingsApiModel(BaseModel):
label_schema: Dict[str, Any]

def find_by_name(self, name: str) -> _DatasetApiModel:
dataset = get_dataset(self.http_client, name=name).parsed
def find_by_name(self, name: str, workspace: Optional[str] = None) -> _DatasetApiModel:
dataset = get_dataset(self.http_client, name=name, workspace=workspace).parsed
return self._DatasetApiModel.parse_obj(dataset)

def create(self, name: str, task: TaskType, workspace: str) -> _DatasetApiModel:
Expand Down Expand Up @@ -163,7 +163,7 @@ def configure(self, name: str, workspace: str, settings: Settings):
)
ds = self.create(name=name, task=task, workspace=workspace)
except AlreadyExistsApiError:
ds = self.find_by_name(name)
ds = self.find_by_name(name, workspace=workspace)
self._save_settings(dataset=ds, settings=settings)

def scan(
Expand Down Expand Up @@ -322,7 +322,7 @@ def _save_settings(self, dataset: _DatasetApiModel, settings: Settings):
try:
with api_compatibility(self, min_version="1.4"):
self.http_client.patch(
f"{self._API_PREFIX}/{dataset.name}/{dataset.task.value}/settings",
f"{self._API_PREFIX}/{dataset.name}/{dataset.task.value}/settings?workspace={dataset.workspace}",
json=settings_.dict(),
)
except ApiCompatibilityError:
Expand All @@ -332,20 +332,24 @@ def _save_settings(self, dataset: _DatasetApiModel, settings: Settings):
json=settings_.dict(),
)

def load_settings(self, name: str) -> Optional[Settings]:
def load_settings(self, name: str, workspace: Optional[str] = None) -> Optional[Settings]:
"""
Load the dataset settings
Args:
name: The dataset name
workspace: The workspace name where the dataset belongs to
Returns:
Settings defined for the dataset
"""
dataset = self.find_by_name(name)
dataset = self.find_by_name(name, workspace=workspace)
try:
with api_compatibility(self, min_version="1.0"):
response = self.http_client.get(f"{self._API_PREFIX}/{dataset.name}/{dataset.task.value}/settings")
params = {"workspace": dataset.workspace} if dataset.workspace else {}
response = self.http_client.get(
f"{self._API_PREFIX}/{dataset.name}/{dataset.task.value}/settings", params=params
)
return __TASK_TO_SETTINGS__.get(dataset.task).from_dict(response)
except NotFoundApiError:
return None
Expand Down
6 changes: 5 additions & 1 deletion src/argilla/client/sdk/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@


@lru_cache(maxsize=None)
def get_dataset(client: AuthenticatedClient, name: str) -> Response[Dataset]:
def get_dataset(client: AuthenticatedClient, name: str, workspace: Optional[str] = None) -> Response[Dataset]:
url = f"{client.base_url}/api/datasets/{name}"

params = {"workspace": workspace} if workspace else None

response = httpx.get(
url=url,
params=params,
headers=client.get_headers(),
cookies=client.get_cookies(),
timeout=client.get_timeout(),
Expand All @@ -40,6 +43,7 @@ def get_dataset(client: AuthenticatedClient, name: str) -> Response[Dataset]:
response_obj = Response.from_httpx_response(response)
response_obj.parsed = Dataset(**response.json())
return response_obj

handle_response_error(response)


Expand Down
7 changes: 3 additions & 4 deletions src/argilla/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_LOGGER = logging.getLogger(__name__)


def load_dataset_settings(name: str, workspace: Optional[str] = None) -> Settings:
def load_dataset_settings(name: str, workspace: Optional[str] = None) -> Optional[Settings]:
"""
Loads the settings of a dataset
Expand All @@ -34,10 +34,9 @@ def load_dataset_settings(name: str, workspace: Optional[str] = None) -> Setting
The dataset settings
"""
active_api = api.active_api()
if workspace is not None:
active_api.set_workspace(workspace)
datasets = active_api.datasets
settings = datasets.load_settings(name)

settings = datasets.load_settings(name, workspace=workspace)
if settings is None:
return None
else:
Expand Down
69 changes: 65 additions & 4 deletions tests/integration/test_datasets_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Union
from uuid import uuid4

import argilla as rg
import pytest
from argilla import Workspace
from argilla.client import api
from argilla.client.api import delete, get_workspace, init
from argilla.client.client import Argilla
from argilla.client.sdk.commons.errors import ForbiddenApiError
from argilla.datasets import TextClassificationSettings, TokenClassificationSettings
from argilla.datasets.__init__ import configure_dataset
from argilla.datasets import (
TextClassificationSettings,
TokenClassificationSettings,
configure_dataset,
configure_dataset_settings,
load_dataset_settings,
)
from argilla.server.contexts import accounts
from argilla.server.security.model import WorkspaceUserCreate

Expand All @@ -30,6 +36,8 @@
from argilla.server.models import User
from sqlalchemy.ext.asyncio import AsyncSession

from .helpers import SecuredClient


@pytest.mark.parametrize(
("settings_", "wrong_settings"),
Expand Down Expand Up @@ -75,6 +83,59 @@ def test_settings_workflow(
configure_dataset(dataset, wrong_settings, workspace=workspace)


# Regression test for https://github.com/argilla-io/argilla/issues/3505:
# calls `configure_dataset_settings` twice on the same dataset with an explicit
# `workspace` argument, then loads the settings back and compares the labels.
# Parametrized over both settings types, each with `workspace=None` and with
# a named workspace ("admin").
@pytest.mark.parametrize(
"settings, workspace",
[
(TextClassificationSettings(label_schema={"A", "B"}), None),
(TextClassificationSettings(label_schema={"D", "E"}), "admin"),
(TokenClassificationSettings(label_schema={"PER", "ORG"}), None),
(TokenClassificationSettings(label_schema={"CAT", "DOG"}), "admin"),
],
)
def test_configure_dataset_settings_twice(
owner: "User",
argilla_user: "User",
settings: Union[TextClassificationSettings, TokenClassificationSettings],
workspace: Optional[str],
) -> None:
# No workspace requested: target the workspace named after the user.
if not workspace:
workspace_name = argilla_user.username
else:
# Create the requested workspace while authenticated as `owner` and add
# `argilla_user` to it.
# NOTE(review): presumably only `owner` is allowed to create workspaces - confirm.
init(api_key=owner.api_key)
# `workspace` is rebound here from the name (str) to the created Workspace object.
workspace = Workspace.create(name=workspace)
workspace.add_user(argilla_user.id)
workspace_name = workspace.name

# The client's default workspace is always the user's own one, so in the "admin"
# cases the `workspace` arg passed below differs from the default workspace.
init(api_key=argilla_user.api_key, workspace=argilla_user.username)
dataset_name = f"test-dataset-{uuid4()}"
# This will create the dataset
configure_dataset_settings(dataset_name, settings=settings, workspace=workspace_name)
# This will update the existing dataset, which is the scenario described in
# the issue https://github.com/argilla-io/argilla/issues/3505
configure_dataset_settings(dataset_name, settings=settings, workspace=workspace_name)

# Load the settings back from the same workspace and compare the label sets;
# the configured labels are converted with `str` before the comparison.
found_settings = load_dataset_settings(dataset_name, workspace_name)
assert {label for label in found_settings.label_schema} == {str(label) for label in settings.label_schema}


# Checks that the `configure_dataset` call below emits a `DeprecationWarning`
# pointing to `configure_dataset_settings`, for both settings types.
@pytest.mark.parametrize(
"settings",
[
TextClassificationSettings(label_schema={"A", "B"}),
TokenClassificationSettings(label_schema={"PER", "ORG"}),
],
)
def test_configure_dataset_deprecation_warning(
argilla_user: "User", settings: Union[TextClassificationSettings, TokenClassificationSettings]
) -> None:
# Authenticate as the test user with the workspace named after the user as default.
init(api_key=argilla_user.api_key, workspace=argilla_user.username)

dataset_name = f"test-dataset-{uuid4()}"
# Use whatever workspace the active client reports.
workspace_name = get_workspace()

# The test fails unless a DeprecationWarning matching this message is raised.
with pytest.warns(DeprecationWarning, match="This method is deprecated. Use configure_dataset_settings instead."):
configure_dataset(dataset_name, settings=settings, workspace=workspace_name)


def test_list_dataset(mocked_client: "SecuredClient"):
from argilla.client.api import active_client

Expand Down

0 comments on commit b4ee857

Please sign in to comment.