Skip to content

Commit

Permalink
refactor code, rename attributes categorical to _categorical and numerical to _numerical, update n_missing
Browse files Browse the repository at this point in the history
  • Loading branch information
yuenshingyan committed Dec 16, 2024
1 parent a785706 commit f034d13
Showing 1 changed file with 65 additions and 55 deletions.
120 changes: 65 additions & 55 deletions src/missforest/missforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_validate_empty_feature,
_validate_imputable,
_validate_verbose,
_validate_column_consistency,
)
from .metrics import pfc, nrmse
from ._array import SafeArray
Expand Down Expand Up @@ -55,9 +56,9 @@ class MissForest:
Maximum iterations of imputing.
early_stopping : bool
Determines if early stopping will be executed.
categorical : list
_categorical : list
All categorical columns of given dataframe `x`.
numerical : list
_numerical : list
All numerical columns of given dataframe `x`.
column_order : pd.Index
Sorting order of features.
Expand Down Expand Up @@ -139,11 +140,11 @@ def __init__(self, clf: Union[Any, BaseEstimator] = lgbm_clf,

self.classifier = clf
self.regressor = rgr
self.categorical = [] if categorical is None else categorical
self._categorical = [] if categorical is None else categorical
self.initial_guess = initial_guess
self.max_iter = max_iter
self.early_stopping = early_stopping
self.numerical = None
self._numerical = None
self.column_order = None
self.initial_imputations = None
self._is_fitted = False
Expand Down Expand Up @@ -268,34 +269,34 @@ def _is_stopping_criterion_satisfied(self, pfc_score: SafeArray,
- False, if stopping criterion not satisfied.
"""
is_pfc_increased = False
if any(self.categorical) and len(pfc_score) >= 2:
if any(self._categorical) and len(pfc_score) >= 2:
is_pfc_increased = pfc_score[-1] > pfc_score[-2]

is_nrmse_increased = False
if any(self.numerical) and len(nrmse_score) >= 2:
if any(self._numerical) and len(nrmse_score) >= 2:
is_nrmse_increased = nrmse_score[-1] > nrmse_score[-2]

if (
any(self.categorical) and
any(self.numerical) and
any(self._categorical) and
any(self._numerical) and
is_pfc_increased * is_nrmse_increased
):
if self._verbose >= 2:
warnings.warn("Both PFC and NRMSE have increased.")

return True
elif (
any(self.categorical) and
not any(self.numerical) and
any(self._categorical) and
not any(self._numerical) and
is_pfc_increased
):
if self._verbose >= 2:
warnings.warn("PFC have increased.")

return True
elif (
not any(self.categorical) and
any(self.numerical) and
not any(self._categorical) and
any(self._numerical) and
is_nrmse_increased
):
if self._verbose >= 2:
Expand Down Expand Up @@ -355,25 +356,25 @@ def fit(self, x: pd.DataFrame):
_validate_empty_feature(x)
_validate_feature_dtype_consistency(x)
_validate_imputable(x)
_validate_cat_var_consistency(x.columns, self.categorical)
_validate_cat_var_consistency(x.columns, self._categorical)

if any(self.categorical):
_validate_infinite(x.drop(self.categorical, axis=1))
if any(self._categorical):
_validate_infinite(x.drop(self._categorical, axis=1))
else:
_validate_infinite(x)

self.numerical = [c for c in x.columns if c not in self.categorical]
self._numerical = [c for c in x.columns if c not in self._categorical]

# Sort column order according to the amount of missing values
# starting with the lowest amount.
pct_missing = x.isnull().sum() / len(x)
self.column_order = pct_missing.sort_values().index
x = x[self.column_order].copy()

n_missing = self._get_n_missing(x)
n_missing = self._get_n_missing(x[self._categorical])
missing_indices = self._get_missing_indices(x)
self.initial_imputations = self._compute_initial_imputations(
x, self.categorical
x, self._categorical
)
x_imp = self._initial_impute(x, self.initial_imputations)

Expand All @@ -390,7 +391,7 @@ def fit(self, x: pd.DataFrame):
fitted_estimators = OrderedDict()

for c in missing_indices:
if c in self.categorical:
if c in self._categorical:
estimator = deepcopy(self.classifier)
else:
estimator = deepcopy(self.regressor)
Expand All @@ -416,27 +417,30 @@ def fit(self, x: pd.DataFrame):

# Store imputed categorical and numerical features after
# each iteration.
x_imp_cat.append(x_imp[self.categorical])
x_imp_num.append(x_imp[self.numerical])

# Compute and store PFC.
if any(self.categorical) and len(x_imp_cat) >= 2:
pfc_score.append(
pfc(
x_true=x_imp_cat[-1],
x_imp=x_imp_cat[-2],
n_missing=n_missing,
if any(self._categorical):
x_imp_cat.append(x_imp[self._categorical])

if len(x_imp_cat) >= 2:
pfc_score.append(
pfc(
x_true=x_imp_cat[-1],
x_imp=x_imp_cat[-2],
n_missing=n_missing,
)
)
)

# Compute and store NRMSE.
if any(self.numerical) and len(x_imp_num) >= 2:
nrmse_score.append(
nrmse(
x_true=x_imp_num[-1],
x_imp=x_imp_num[-2],
if any(self._numerical):
x_imp_num.append(x_imp[self._numerical])

if len(x_imp_num) >= 2:
nrmse_score.append(
nrmse(
x_true=x_imp_num[-1],
x_imp=x_imp_num[-2],
)
)
)

if (
self.early_stopping and
Expand Down Expand Up @@ -501,10 +505,12 @@ def transform(self, x: pd.DataFrame) -> pd.DataFrame:
_validate_empty_feature(x)
_validate_feature_dtype_consistency(x)
_validate_imputable(x)
_validate_cat_var_consistency(x.columns, self._categorical)
_validate_column_consistency(set(x.columns), set(self.column_order))

x = x[self.column_order].copy()

n_missing = self._get_n_missing(x)
n_missing = self._get_n_missing(x[self._categorical])
missing_indices = self._get_missing_indices(x)
x_imp = self._initial_impute(x, self.initial_imputations)

Expand All @@ -531,28 +537,32 @@ def transform(self, x: pd.DataFrame) -> pd.DataFrame:

# Store imputed categorical and numerical features after
# each iteration.
x_imp_cat.append(x_imp[self.categorical])
x_imp_num.append(x_imp[self.numerical])
x_imps.append(x_imp)

# Compute and store PFC.
if any(self.categorical) and len(x_imp_cat) >= 2:
pfc_score.append(
pfc(
x_true=x_imp_cat[-1],
x_imp=x_imp_cat[-2],
n_missing=n_missing,
if any(self._categorical):
x_imp_cat.append(x_imp[self._categorical])

# Compute and store PFC.
if len(x_imp_cat) >= 2:
pfc_score.append(
pfc(
x_true=x_imp_cat[-1],
x_imp=x_imp_cat[-2],
n_missing=n_missing,
)
)
)

# Compute and store NRMSE.
if any(self.numerical) and len(x_imp_num) >= 2:
nrmse_score.append(
nrmse(
x_true=x_imp_num[-1],
x_imp=x_imp_num[-2],
if any(self._numerical):
x_imp_num.append(x_imp[self._numerical])

# Compute and store NRMSE.
if len(x_imp_num) >= 2:
nrmse_score.append(
nrmse(
x_true=x_imp_num[-1],
x_imp=x_imp_num[-2],
)
)
)

x_imps.append(x_imp)

if (
self.early_stopping and
Expand Down

0 comments on commit f034d13

Please sign in to comment.