Skip to content

Commit

Permalink
add kwargs to cate and gate irm
Browse files Browse the repository at this point in the history
  • Loading branch information
SvenKlaassen committed Oct 10, 2024
1 parent 88c8ec8 commit cb02bca
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
15 changes: 11 additions & 4 deletions doubleml/irm/irm.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_

return res

def cate(self, basis, is_gate=False):
def cate(self, basis, is_gate=False, **kwargs):
"""
Calculate conditional average treatment effects (CATE) for a given basis.
Expand All @@ -440,10 +440,14 @@ def cate(self, basis, is_gate=False):
basis : :class:`pandas.DataFrame`
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
is_gate : bool
Indicates whether the basis is constructed for GATEs (dummy-basis).
Default is ``False``.
**kwargs: dict
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
Returns
-------
model : :class:`doubleML.DoubleMLBLP`
Expand All @@ -462,10 +466,10 @@ def cate(self, basis, is_gate=False):
orth_signal = self.psi_elements['psi_b'].reshape(-1)
# fit the best linear predictor
model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate)
model.fit()
model.fit(**kwargs)
return model

def gate(self, groups):
def gate(self, groups, **kwargs):
"""
Calculate group average treatment effects (GATE) for groups.
Expand All @@ -476,6 +480,9 @@ def gate(self, groups):
Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
**kwargs: dict
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
Returns
-------
model : :class:`doubleML.DoubleMLBLP`
Expand All @@ -495,7 +502,7 @@ def gate(self, groups):
if any(groups.sum(0) <= 5):
warnings.warn('At least one group effect is estimated with less than 6 observations.')

model = self.cate(groups, is_gate=True)
model = self.cate(groups, is_gate=True, **kwargs)
return model

def policy_tree(self, features, depth=2, **tree_params):
Expand Down
14 changes: 10 additions & 4 deletions doubleml/irm/tests/test_irm.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,14 @@ def test_dml_irm_sensitivity_rho0(dml_irm_fixture):
rtol=1e-9, atol=1e-4)


@pytest.fixture(scope='module',
params=["nonrobust", "HC0", "HC1", "HC2", "HC3"])
def cov_type(request):
return request.param


@pytest.mark.ci
def test_dml_irm_cate_gate():
def test_dml_irm_cate_gate(cov_type):
n = 9
# collect data
np.random.seed(42)
Expand All @@ -207,7 +213,7 @@ def test_dml_irm_cate_gate():
dml_irm_obj.fit()
# create a random basis
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
cate = dml_irm_obj.cate(random_basis)
cate = dml_irm_obj.cate(random_basis, cov_type=cov_type)
assert isinstance(cate, dml.utils.blp.DoubleMLBLP)
assert isinstance(cate.confint(), pd.DataFrame)

Expand All @@ -216,7 +222,7 @@ def test_dml_irm_cate_gate():
columns=['Group 1', 'Group 2'])
msg = ('At least one group effect is estimated with less than 6 observations.')
with pytest.warns(UserWarning, match=msg):
gate_1 = dml_irm_obj.gate(groups_1)
gate_1 = dml_irm_obj.gate(groups_1, cov_type=cov_type)
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
assert isinstance(gate_1.confint(), pd.DataFrame)
assert all(gate_1.confint().index == groups_1.columns.to_list())
Expand All @@ -225,7 +231,7 @@ def test_dml_irm_cate_gate():
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n))
msg = ('At least one group effect is estimated with less than 6 observations.')
with pytest.warns(UserWarning, match=msg):
gate_2 = dml_irm_obj.gate(groups_2)
gate_2 = dml_irm_obj.gate(groups_2, cov_type=cov_type)
assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP)
assert isinstance(gate_2.confint(), pd.DataFrame)
assert all(gate_2.confint().index == ["Group_1", "Group_2"])
Expand Down

0 comments on commit cb02bca

Please sign in to comment.