Commit: better readme
louisabraham committed Sep 23, 2024
1 parent 1baa417 commit f195d64
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions README.md
@@ -19,7 +19,7 @@ pip install lassonet

We have designed the code to follow scikit-learn's standards to the extent possible (e.g. [linear_model.Lasso](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html)).

-```
+```python
from lassonet import LassoNetClassifierCV
model = LassoNetClassifierCV() # LassoNetRegressorCV
path = model.fit(X_train, y_train)
```

@@ -60,9 +60,8 @@ Here are some examples of what you can do with LassoNet. Note that you can switc

#### Using the base interface

-The base interface implements a `.fit()` method that is not very useful as it computes a path but does not store any intermediate result.
+The original paper describes how to train LassoNet along a regularization path. This required the user to manually select a model from the path and made the `.fit()` method useless, since the resulting model was always empty. This feature is still available through the `.path(return_state_dicts=True)` method on any base model; it returns a list of checkpoints that can be loaded with `.load()`.
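As an illustration of the checkpoint-selection step, here is a sketch with stand-in checkpoints (plain dicts; the `lambda_` and `selected` fields are hypothetical — real checkpoints come from `.path(return_state_dicts=True)` and are restored with `.load()`):

```python
# Stand-in checkpoints: each entry records the regularization value
# and how many features survive at that point on the path (hypothetical values).
path = [
    {"lambda_": 0.1, "selected": 10},
    {"lambda_": 1.0, "selected": 6},
    {"lambda_": 10.0, "selected": 3},
]

# Pick the first checkpoint with at most 6 surviving features.
chosen = next(ckpt for ckpt in path if ckpt["selected"] <= 6)
print(chosen["lambda_"])  # 1.0
```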

-Usually, you want to store the intermediate results (with `return_state_dicts=True`) and then load one of the models from the path into the model to inspect it.

```python
from lassonet import LassoNetRegressor, plot_path
```

@@ -88,7 +87,8 @@ You get a `model.feature_importances_` attribute that is the value of the L1 reg
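To illustrate how such an attribute could be used, here is a sketch with made-up importance values (the array and feature names below are illustrative, not produced by the library):

```python
import numpy as np

# Hypothetical importances: the regularization value at which each feature
# is removed from the model (higher = more important).
feature_importances_ = np.array([0.5, 3.2, 0.1, 1.8])
features = ["age", "bmi", "bp", "s1"]  # hypothetical feature names

# Rank features from most to least important.
order = np.argsort(feature_importances_)[::-1]
ranking = [features[i] for i in order]
print(ranking)  # ['bmi', 's1', 'age', 'bp']
```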

#### Using the cross-validation interface

-The cross-validation interface computes validation scores on multiple folds before running a final path on the whole training dataset with the best regularization parameter.

+We integrated support for cross-validation (5-fold by default) in the estimators whose names end with `CV`. For each fold, a path is trained. The best regularization value is then chosen to maximize the average score over all folds, and the model is retrained on the whole training dataset up to that regularization value.
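The fold-averaging rule can be sketched with a made-up score grid (the `lambdas` and `scores` values below are illustrative, not from the library):

```python
import numpy as np

# Hypothetical validation scores: one row per fold, one column per
# regularization value on the path.
lambdas = np.array([0.1, 1.0, 10.0])
scores = np.array([
    [0.70, 0.80, 0.60],
    [0.72, 0.78, 0.65],
    [0.68, 0.82, 0.55],
])

# Choose the lambda that maximizes the score averaged over folds;
# the model is then retrained on all the data up to that value.
best_lambda = lambdas[scores.mean(axis=0).argmax()]
print(best_lambda)  # 1.0
```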

```python
model = LassoNetRegressorCV()
```

@@ -152,12 +152,6 @@ Note that cross-validation, group feature selection and automatic `lambda_start`

We are currently working on, among other things, support for convolution layers, auto-encoders and online logging of experiments.

-## Cross-validation
-
-The original paper describes how to train LassoNet along a regularization path. This requires the user to manually select a model from the path and made the `.fit()` method useless since the resulting model is always empty. This feature is still available with the `.path()` method for any model or the `lassonet_path` function and returns a list of checkpoints that can be loaded with `.load()`.
-
-Since then, we integrated support for cross-validation (5-fold by default) in the estimators whose name ends with `CV`. For each fold, a path is trained. The best regularization value is then chosen to maximize the average performance over all folds. The model is then retrained on the whole training dataset to reach that regularization.

## Website

LassoNet's website is [lasso-net.github.io](https://lasso-net.github.io/). It contains many useful references including the paper, live talks and additional documentation.