Skip to content

Commit

Permalink
Merge pull request #287 from DoubleML/s-fix-rdrobust-test
Browse files Browse the repository at this point in the history
Update Importerror for conda testing
  • Loading branch information
SvenKlaassen authored Jan 9, 2025
2 parents e5308be + 570e5ca commit 92b057d
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 13 deletions.
5 changes: 1 addition & 4 deletions doubleml/rdd/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,4 @@ def _is_rdrobust_available():
rdrobust = importlib.import_module("rdrobust")
return rdrobust
except ImportError:
msg = (
"rdrobust is not installed. "
"Please install it using 'pip install DoubleML[rdd]'")
raise ImportError(msg)
return None
5 changes: 5 additions & 0 deletions doubleml/rdd/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def __init__(self,
fs_kernel="triangular",
**kwargs):

if rdrobust is None:
msg = ("rdrobust is not installed. "
"Please install it using 'pip install DoubleML[rdd]'")
raise ImportError(msg)

self._check_data(obj_dml_data, cutoff)
self._dml_data = obj_dml_data

Expand Down
5 changes: 5 additions & 0 deletions doubleml/rdd/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def _predict_dummy(data: DoubleMLData, cutoff, alpha, n_rep, p, fs_specification
dml_rdflex.fit(n_iterations=1)
ci_manual = dml_rdflex.confint(level=1-alpha)

if rdrobust is None:
msg = ("rdrobust is not installed. "
"Please install it using 'pip install DoubleML[rdd]'")
raise ImportError(msg)

rdrobust_model = rdrobust.rdrobust(
y=data.y,
x=data.s,
Expand Down
3 changes: 1 addition & 2 deletions doubleml/rdd/tests/test_rdd_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
)
dml_data = dml.DoubleMLData(df, y_col='y', d_cols='d', s_col='score')

dml_rdflex = RDFlex(dml_data, ml_g=LogisticRegression(), ml_m=LogisticRegression(), fuzzy=True)


@pytest.mark.ci_rdd
def test_rdd_classifier():
dml_rdflex = RDFlex(dml_data, ml_g=LogisticRegression(), ml_m=LogisticRegression(), fuzzy=True)
dml_rdflex.fit()
3 changes: 1 addition & 2 deletions doubleml/rdd/tests/test_rdd_default_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
)
dml_data = dml.DoubleMLData(df, y_col='y', d_cols='d', s_col='score')

dml_rdflex = RDFlex(dml_data, ml_g=Lasso(), ml_m=LogisticRegression())


def _assert_resampling_default_settings(dml_obj):
assert dml_obj.n_folds == 5
Expand All @@ -32,4 +30,5 @@ def _assert_resampling_default_settings(dml_obj):

@pytest.mark.ci_rdd
def test_rdd_defaults():
dml_rdflex = RDFlex(dml_data, ml_g=Lasso(), ml_m=LogisticRegression())
_assert_resampling_default_settings(dml_rdflex)
7 changes: 4 additions & 3 deletions doubleml/rdd/tests/test_rdd_not_installed.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import pytest
from unittest.mock import patch

from doubleml.rdd._utils import _is_rdrobust_available
import doubleml as dml


@pytest.mark.ci
def test_rdrobust_import_error():
with patch('importlib.import_module', side_effect=ImportError):
with patch('doubleml.rdd.rdd.rdrobust', None):
msg = r"rdrobust is not installed. Please install it using 'pip install DoubleML\[rdd\]'"
with pytest.raises(ImportError, match=msg):
_is_rdrobust_available()
dml.rdd.RDFlex(None, None)
4 changes: 2 additions & 2 deletions doubleml/rdd/tests/test_rdd_return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
)
dml_data = dml.DoubleMLData(df, y_col='y', d_cols='d', s_col='score')

dml_rdflex = RDFlex(dml_data, ml_g=Lasso(), ml_m=LogisticRegression())


def _assert_return_types(dml_obj):
assert isinstance(dml_obj.n_folds, int)
Expand Down Expand Up @@ -52,5 +50,7 @@ def _assert_return_types_after_fit(dml_obj):

@pytest.mark.ci_rdd
def test_rdd_returntypes():
dml_rdflex = RDFlex(dml_data, ml_g=Lasso(), ml_m=LogisticRegression())

_assert_return_types(dml_rdflex)
_assert_return_types_after_fit(dml_rdflex)

0 comments on commit 92b057d

Please sign in to comment.