Cython implementation of GRF and CausalForestDML (py-why#341)
* added backend option in orf, adding verbosity, restructuring static functions

* added cython grf module that implements generalized random forests

* added cython version of causal forest and causal forest dml

* deprecating older CausalForest

* updates to CF and ORF notebook

* restructured dml into folder. Deprecated ForestDML in favor of CausalForestDML.

* Removed two legacy files in our main folder.

* deprecating ensemble.SubsampledHonestForest

* made drlearner use the non-deprecated regression forest.

* Enable setuptools build process

* fixed flaky random_state test

* fixed tests and api consistency

* updated tables and library flow chart

* enforce sklearn 0.24.

* fixed _cross_val_predict

* added option for max background samples to shap to make computation more reasonable

* fixed error_score param in gcvlist due to sklearn upgrade

* added shap cells in DML notebook

* added shap values to GRF notebook

* fixed bug in the way input_feature_names were used in summary. enabled shap to use input feature names

* updated readme. removed autoreload from notebooks

* added shap specific notebook

* updated dowhy notebook
vsyrgkanis authored Jan 9, 2021
1 parent 3df959d commit bb042d5
Showing 85 changed files with 14,660 additions and 3,772 deletions.
37 changes: 37 additions & 0 deletions LICENSE
@@ -19,3 +19,40 @@
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE


Parts of this software, in particular code contained in the modules econml.tree and
econml.grf contain files that are forks from the scikit-learn git repository, or code
snippets from that repository:
https://github.com/scikit-learn/scikit-learn
published under the following License.

BSD 3-Clause License

Copyright (c) 2007-2020 The scikit-learn developers.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
45 changes: 23 additions & 22 deletions README.md
@@ -118,20 +118,7 @@ To install from source, see [For Developers](#for-developers) section below.
treatment_effects = est.effect(X_test)
lb, ub = est.effect_interval(X_test, alpha=0.05) # Confidence intervals via debiased lasso
```

* Forest last stage

```Python
from econml.dml import ForestDML
from sklearn.ensemble import GradientBoostingRegressor

est = ForestDML(model_y=GradientBoostingRegressor(), model_t=GradientBoostingRegressor())
est.fit(Y, T, X=X, W=W)
treatment_effects = est.effect(X_test)
# Confidence intervals via Bootstrap-of-Little-Bags for forests
lb, ub = est.effect_interval(X_test, alpha=0.05)
```

* Generic Machine Learning last stage

```Python
@@ -152,16 +139,16 @@ To install from source, see [For Developers](#for-developers) section below.
<summary>Causal Forests (click to expand)</summary>

```Python
from econml.causal_forest import CausalForest
from econml.dml import CausalForestDML
from sklearn.linear_model import LassoCV
# Use defaults
est = CausalForest()
est = CausalForestDML()
# Or specify hyperparameters
est = CausalForest(n_trees=500, min_leaf_size=10,
max_depth=10, subsample_ratio=0.7,
lambda_reg=0.01,
discrete_treatment=False,
model_T=LassoCV(), model_Y=LassoCV())
est = CausalForestDML(criterion='het', n_estimators=500,
min_samples_leaf=10,
max_depth=10, max_samples=0.5,
discrete_treatment=False,
model_t=LassoCV(), model_y=LassoCV())
est.fit(Y, T, X=X, W=W)
treatment_effects = est.effect(X_test)
# Confidence intervals via Bootstrap-of-Little-Bags for forests
@@ -354,7 +341,7 @@ treatment_effects = est.effect(X_test)

<details>
<summary>Policy Interpreter of the CATE model (click to expand)</summary>

```Python
from econml.cate_interpreter import SingleTreePolicyInterpreter
# We find a tree-based treatment policy based on the CATE model
@@ -366,7 +353,21 @@ treatment_effects = est.effect(X_test)
plt.show()
```
![image](notebooks/images/dr_policy_tree.png)


</details>

<details>
<summary>SHAP values for the CATE model (click to expand)</summary>

```Python
import shap
from econml.dml import CausalForestDML
est = CausalForestDML()
est.fit(Y, T, X=X, W=W)
shap_values = est.shap_values(X)
shap.summary_plot(shap_values['Y0']['T0'])
```

</details>

### Inference
4 changes: 2 additions & 2 deletions azure-pipelines-steps.yml
@@ -5,7 +5,7 @@

parameters:
body: []
package: '.'
package: '-e .'

steps:
- task: UsePythonVersion@0
@@ -24,7 +24,7 @@ steps:
condition: and(succeeded(), eq(variables['Agent.OS'], 'Linux'))

# Install the package
- script: 'python -m pip install --upgrade pip && pip install --upgrade setuptools wheel && pip install ${{ parameters.package }}'
- script: 'python -m pip install --upgrade pip && pip install --upgrade setuptools wheel Cython && pip install ${{ parameters.package }}'
displayName: 'Install dependencies'

- ${{ parameters.body }}
5 changes: 1 addition & 4 deletions azure-pipelines.yml
@@ -69,9 +69,6 @@ jobs:
- script: 'pip install --force-reinstall --no-cache-dir shap'
displayName: 'Install public shap'

- script: 'pip install --force-reinstall scikit-learn==0.23.2'
displayName: 'Install public old sklearn'

- script: 'python setup.py build_sphinx -W'
displayName: 'Build documentation'

@@ -81,7 +78,7 @@ jobs:

- script: 'python setup.py build_sphinx -b doctest'
displayName: 'Run doctests'
package: '.[automl]'
package: '-e .[automl]'

- job: 'Notebooks'
dependsOn: 'EvalChanges'
2 changes: 1 addition & 1 deletion doc/conf.py
@@ -211,7 +211,7 @@
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
'sklearn': ('https://scikit-learn.org/0.23/', None),
'sklearn': ('https://scikit-learn.org/stable/', None),
'matplotlib': ('https://matplotlib.org/', None)}

# -- Options for todo extension ----------------------------------------------
28 changes: 14 additions & 14 deletions doc/map.svg
13 changes: 7 additions & 6 deletions doc/reference.rst
@@ -5,14 +5,11 @@ Public Module Reference
:toctree: _autosummary

econml.bootstrap
econml.cate_estimator
econml.cate_interpreter
econml.causal_forest
econml.causal_tree
econml.deepiv
econml.dgp
econml.dml
econml.drlearner
econml.grf
econml.inference
econml.metalearners
econml.ortho_forest
@@ -27,7 +24,12 @@ Private Module Reference
:toctree: _autosummary

econml._ortho_learner
econml._rlearner
econml._cate_estimator
econml._causal_tree
econml.dml._rlearner
econml.grf._base_grf
econml.grf._base_grftree
econml.grf._criterion

Scikit-Learn Extensions
=======================
@@ -37,4 +39,3 @@ Scikit-Learn Extensions

econml.sklearn_extensions.linear_model
econml.sklearn_extensions.model_selection
econml.sklearn_extensions.ensemble
17 changes: 13 additions & 4 deletions doc/spec/api.rst
@@ -27,10 +27,6 @@ The latter translates to estimating a local gradient around a treatment vector c
\partial\tau(\vec{t}, \vec{x}) = \E\left[\nabla_{\vec{t}} Y(\vec{t}) | X=\vec{x}\right] \tag{marginal CATE}
We will refer to the latter as the *heterogeneous marginal effect*. [1]_
Finally, we might not only be interested in the effect but also in the actual *counterfactual prediction*, i.e. estimating the quantity:

.. math ::
\mu(\vec{t}, \vec{x}) = \E\left[Y(\vec{t}) | X=\vec{x}\right] \tag{counterfactual prediction}

We assume we have data that are generated from some collection policy. In particular, we assume that we have data of the form:
:math:`\{Y_i(T_i), T_i, X_i, W_i, Z_i\}`, where :math:`Y_i(T_i)` is the observed outcome for the chosen treatment,
@@ -43,6 +39,19 @@ The variables :math:`X_i` can also be thought of as *control* variables, but the
they are a subset of the controls with respect to which we want to measure treatment effect heterogeneity.
We will refer to them as *features*.

Finally, sometimes we might be interested not only in the effect but also in the actual *counterfactual prediction*, i.e. estimating the quantity:

.. math ::
\mu(\vec{t}, \vec{x}) = \E\left[Y(\vec{t}) | X=\vec{x}\right] \tag{counterfactual prediction}

Our package does not offer support for counterfactual prediction. However, for most of our estimators (the ones
assuming a linear-in-treatment model), a counterfactual prediction can be easily constructed by combining any baseline predictive model
with our causal effect model, i.e. train any machine learning model :math:`b(\vec{t}, \vec{x})` to solve the regression/classification
problem :math:`\E[Y | T=\vec{t}, X=\vec{x}]`, and then set :math:`\mu(\vec{t}, \vec{x}) = \tau(\vec{t}, T, \vec{x}) + b(T, \vec{x})`,
where :math:`T` is either the observed treatment for that sample under the observational policy or the treatment
that the observational policy would have assigned to that sample. These auxiliary ML models can be trained
with any machine learning package outside of EconML.
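
The construction described here can be sketched with NumPy alone. This is an illustration under stated assumptions, not part of the package: the data are simulated, ``theta`` is a made-up heterogeneous effect that stands in for a fitted effect model (with a fitted EconML estimator one would presumably call its ``effect(X, T0=..., T1=...)`` method instead), the baseline ``b`` is an ordinary least-squares fit, and the helper names ``tau``, ``design``, ``b`` and ``mu`` are hypothetical. The sketch also assumes that ``tau(t0, t1, x)`` means the effect of moving the treatment from ``t0`` to ``t1``.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000

# Simulated observational data (illustrative only): a single feature x and a
# continuous treatment whose assignment depends on x.
x = rng.normal(size=n)
t_obs = 0.5 * x + rng.normal(size=n)


def theta(x):
    """Assumed heterogeneous effect per unit of treatment (made up for this sketch)."""
    return 1.0 + 2.0 * x


# Linear-in-treatment outcome model: Y = theta(x) * T + g(x) + noise.
y = theta(x) * t_obs + x + rng.normal(scale=0.1, size=n)


def tau(t0, t1, x):
    """Stand-in effect model: effect of moving the treatment from t0 to t1."""
    return theta(x) * (t1 - t0)


def design(t, x):
    """Feature matrix for the baseline regression of Y on (T, X)."""
    return np.column_stack([np.ones_like(x), x, t, t * x])


# Baseline predictive model b(t, x) ~ E[Y | T=t, X=x], fitted by least squares.
coef, *_ = np.linalg.lstsq(design(t_obs, x), y, rcond=None)


def b(t, x):
    return design(t, x) @ coef


def mu(t, x, t_obs):
    """Counterfactual prediction: baseline at the observed treatment plus the
    effect of moving from the observed treatment to the target treatment t."""
    return tau(t_obs, t, x) + b(t_obs, x)


# Predicted outcome for every sample had it received treatment 2.0.
t_new = np.full(n, 2.0)
mu_hat = mu(t_new, x, t_obs)
```

In this simulation the baseline is correctly specified, so ``mu_hat`` should stay close to the noise-free counterfactual ``theta(x) * t_new + x``; with real data the quality of the prediction depends on both the baseline model and the effect model.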

.. rubric::
Structural Equation Formulation

6 changes: 3 additions & 3 deletions doc/spec/comparison.rst
@@ -19,13 +19,13 @@ Detailed estimator comparison
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.LinearDRLearner` | Categorical | | Yes | | Projected | | Yes | |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ForestDML` | 1-d/Binary | | Yes | Yes | | Yes | | Yes |
| :class:`.CausalForestDML` | Any | | Yes | Yes | | Yes | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ForestDRLearner` | Categorical | | Yes | | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ContinuousTreatmentOrthoForest` | Continuous | | Yes | Yes | | | Yes | Yes |
| :class:`.DMLOrthoForest` | Any | | Yes | Yes | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.DiscreteTreatmentOrthoForest` | Categorical | | Yes | | | | Yes | Yes |
| :class:`.DROrthoForest` | Categorical | | Yes | | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :mod:`~econml.metalearners` | Categorical | | | | | Yes | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+