Skip to content

Commit

Permalink
Support paths as strings in public fetching API (skrub-data#453)
Browse files Browse the repository at this point in the history
* Add target directory as optional argument

* Add changelog entry

* Support str in public fetching API

* Update changelog entry

* Make paths absolute

* Remove unnecessary path conversion in test
  • Loading branch information
LilianBoulard authored Feb 17, 2023
1 parent 4d9f555 commit 72be1af
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ Minor changes
which can be used to specify where datasets are saved to and loaded from.
:pr:`432` by :user:`Lilian Boulard <LilianBoulard>`

* Fetching functions now have an additional argument ``directory``,
which can be used to specify where datasets are saved to and loaded from.
:pr:`432` and :pr:`453` by :user:`Lilian Boulard <LilianBoulard>`

* The :class:`TableVectorizer`'s default `OneHotEncoder` for low cardinality categorical variables now defaults
to `handle_unknown="ignore"` instead of `handle_unknown="error"` (for sklearn >= 1.0.0).
This means that categories seen only at test time will be encoded by a vector of zeroes instead of raising an error. :pr:`473` by :user:`Leo Grinsztajn <LeoGrin>`
Expand Down
24 changes: 14 additions & 10 deletions dirty_cat/datasets/_fetching.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, TextIO, Union
from urllib.error import URLError
from zipfile import BadZipFile, ZipFile

Expand Down Expand Up @@ -530,6 +530,7 @@ def _export_gz_data_to_csv(
atdata_found = False
with destination_file.open(mode="w", encoding="utf8") as csv:
with gzip.open(compressed_dir_path, mode="rt", encoding="utf8") as gz:
gz: TextIO # Clarify for IDEs
csv.write(_features_to_csv_format(features))
csv.write("\n")
# We will look at each line of the file until we find
Expand All @@ -552,7 +553,7 @@ def _fetch_dataset_as_dataclass(
dataset_id: Union[int, str],
target: Optional[str],
load_dataframe: bool,
data_directory: Optional[Path] = None,
data_directory: Optional[Union[Path, str]] = None,
read_csv_kwargs: Optional[dict] = None,
) -> Union[DatasetAll, DatasetInfoOnly]:
"""
Expand Down Expand Up @@ -582,6 +583,9 @@ def _fetch_dataset_as_dataclass(
If `load_dataframe=False`
"""
if isinstance(data_directory, str):
data_directory = Path(data_directory)

if source == "openml":
info = _fetch_openml_dataset(dataset_id, data_directory)
elif source == "world_bank":
Expand Down Expand Up @@ -635,7 +639,7 @@ def fetch_employee_salaries(
load_dataframe: bool = True,
drop_linked: bool = True,
drop_irrelevant: bool = True,
directory: Optional[Path] = None,
directory: Optional[Union[Path, str]] = None,
) -> Union[DatasetAll, DatasetInfoOnly]:
"""Fetches the employee_salaries dataset (regression), available at
https://openml.org/d/42125
Expand Down Expand Up @@ -689,7 +693,7 @@ def fetch_employee_salaries(

def fetch_road_safety(
load_dataframe: bool = True,
directory: Optional[Path] = None,
directory: Optional[Union[Path, str]] = None,
) -> Union[DatasetAll, DatasetInfoOnly]:
"""Fetches the road safety dataset (classification), available at
https://openml.org/d/42803
Expand Down Expand Up @@ -723,7 +727,7 @@ def fetch_road_safety(

def fetch_medical_charge(
load_dataframe: bool = True,
directory: Optional[Path] = None,
directory: Optional[Union[Path, str]] = None,
) -> Union[DatasetAll, DatasetInfoOnly]:
"""Fetches the medical charge dataset (regression), available at
https://openml.org/d/42720
Expand Down Expand Up @@ -762,7 +766,7 @@ def fetch_medical_charge(

def fetch_midwest_survey(
load_dataframe: bool = True,
directory: Optional[Path] = None,
directory: Optional[Union[Path, str]] = None,
) -> Union[DatasetAll, DatasetInfoOnly]:
"""Fetches the midwest survey dataset (classification), available at
https://openml.org/d/42805
Expand Down Expand Up @@ -794,7 +798,7 @@ def fetch_midwest_survey(

def fetch_open_payments(
load_dataframe: bool = True,
directory: Optional[Path] = None,
directory: Optional[Union[Path, str]] = None,
) -> Union[DatasetAll, DatasetInfoOnly]:
"""Fetches the open payments dataset (classification), available at
https://openml.org/d/42738
Expand Down Expand Up @@ -828,7 +832,7 @@ def fetch_open_payments(

def fetch_traffic_violations(
load_dataframe: bool = True,
directory: Optional[Path] = None,
directory: Optional[Union[Path, str]] = None,
) -> Union[DatasetAll, DatasetInfoOnly]:
"""Fetches the traffic violations dataset (classification), available at
https://openml.org/d/42132
Expand Down Expand Up @@ -864,7 +868,7 @@ def fetch_traffic_violations(

def fetch_drug_directory(
load_dataframe: bool = True,
directory: Optional[Path] = None,
directory: Optional[Union[Path, str]] = None,
) -> Union[DatasetAll, DatasetInfoOnly]:
"""Fetches the drug directory dataset (classification), available at
https://openml.org/d/43044
Expand Down Expand Up @@ -898,7 +902,7 @@ def fetch_drug_directory(
def fetch_world_bank_indicator(
indicator_id: str,
load_dataframe: bool = True,
directory: Optional[Path] = None,
directory: Optional[Union[Path, str]] = None,
) -> Union[DatasetAll, DatasetInfoOnly]:
"""Fetches a dataset of an indicator from the World Bank
open data platform.
Expand Down
6 changes: 0 additions & 6 deletions dirty_cat/datasets/tests/test_fetching.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, Tuple
from unittest import mock
Expand Down Expand Up @@ -41,10 +40,6 @@ def test_openml_fetching(
Test a function that loads data from OpenML.
"""
with TemporaryDirectory() as temp_dir_1, TemporaryDirectory() as temp_dir_2:
# Convert to path objects
temp_dir_1 = Path(temp_dir_1).absolute()
temp_dir_2 = Path(temp_dir_2).absolute()

# Fetch without loading into memory
try:
dataset_wo_load: _fetching.DatasetInfoOnly = fetching_function(
Expand Down Expand Up @@ -126,7 +121,6 @@ def test_fetch_world_bank_indicator():
}

with TemporaryDirectory() as temp_dir:
temp_dir = Path(temp_dir).absolute()
try:
# First, we want to purposefully test FileNotFoundError exceptions.
with pytest.raises(FileNotFoundError):
Expand Down

0 comments on commit 72be1af

Please sign in to comment.