Skip to content

Commit

Permalink
Update column selector
Browse files Browse the repository at this point in the history
  • Loading branch information
lixfz committed Feb 15, 2024
1 parent 6638a04 commit 903da91
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
15 changes: 8 additions & 7 deletions hypernets/tabular/column_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(self, pattern=None, *, dtype_include=None, dtype_exclude=None,
assert isinstance(word_count_threshold, int) and word_count_threshold >= 1

if dtype_include is None:
dtype_include = ['object']
dtype_include = ['object', 'string']

super(TextColumnSelector, self).__init__(pattern,
dtype_include=dtype_include,
Expand Down Expand Up @@ -241,19 +241,20 @@ def __call__(self, df):


column_all = ColumnSelector()
column_object_category_bool = ColumnSelector(dtype_include=['object', 'category', 'bool'])
column_object_category_bool_with_auto = AutoCategoryColumnSelector(dtype_include=['object', 'category', 'bool'],
cat_exponent=0.5)
column_text = TextColumnSelector(dtype_include=['object'])
column_object_category_bool = ColumnSelector(dtype_include=['object', 'string', 'category', 'bool'])
column_object_category_bool_with_auto = AutoCategoryColumnSelector(
dtype_include=['object', 'string', 'category', 'bool'],
cat_exponent=0.5)
column_text = TextColumnSelector(dtype_include=['object', 'string'])
column_latlong = LatLongColumnSelector()

column_object = ColumnSelector(dtype_include=['object'])
column_object = ColumnSelector(dtype_include=['object', 'string'])
column_category = ColumnSelector(dtype_include=['category'])
column_bool = ColumnSelector(dtype_include=['bool'])
column_number = ColumnSelector(dtype_include='number')
column_number_exclude_timedelta = ColumnSelector(dtype_include='number', dtype_exclude='timedelta')
column_object_category_bool_int = ColumnSelector(
dtype_include=['object', 'category', 'bool',
dtype_include=['object', 'string', 'category', 'bool',
'int', 'int8', 'int16', 'int32', 'int64',
'uint', 'uint8', 'uint16', 'uint32', 'uint64'])

Expand Down
4 changes: 3 additions & 1 deletion hypernets/tests/tabular/tb_dask/dask_transofromer_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
import pytest

from hypernets.tabular.datasets import dsutils
from hypernets.utils import const
Expand Down Expand Up @@ -130,10 +131,11 @@ def test_varlen_encoder_with_customized_data(self):
print(d_result_df)
assert all(d_result_df.values == result.values)

@pytest.mark.xfail # see: dask_ml ColumnTransformer
def test_dataframe_wrapper(self):
X = self.bank_data.copy()

cats = X.select_dtypes(['object', ]).columns.to_list()
cats = X.select_dtypes(['object', 'string']).columns.to_list()
continous = X.select_dtypes(['float', 'float64', 'int', 'int64']).columns.to_list()
transformers = [('cats',
dex.SimpleImputer(missing_values=np.nan, strategy='constant', fill_value=''),
Expand Down

0 comments on commit 903da91

Please sign in to comment.