Skip to content

Commit

Permalink
fix: Do not use aspect ratio of 1 for residual plot (#1335)
Browse files Browse the repository at this point in the history
closes #1292 

We forced the aspect ratio of 1.0 for the residual plot. However the
residual is computing `actual vs predicted` and the scale of the
residual can be really different from the prediction if the error is
low, explaining the problem of the plot.

This PR only set the aspect ratio for the actual vs predicted plot only.
  • Loading branch information
glemaitre authored Feb 17, 2025
1 parent 035e4f6 commit ad1da20
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 0 additions & 1 deletion skore/src/skore/sklearn/_plot/prediction_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ def plot(
else:
self.line_ = ax.plot(x_range_perfect_pred, [0, 0], **perfect_line_kwargs)[0]
ax.set(
aspect="equal",
xlim=x_range_perfect_pred,
ylim=y_range_perfect_pred,
xticks=np.linspace(
Expand Down
6 changes: 6 additions & 0 deletions skore/tests/unit/sklearn/plot/test_prediction_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def test_prediction_error_display_regression(pyplot, regression_data, subsample)
assert display.ax_.get_xlabel() == "Predicted values"
assert display.ax_.get_ylabel() == "Residuals (actual - predicted)"

assert display.ax_.get_aspect() not in ("equal", 1.0)


def test_prediction_error_cross_validation_display_regression(
pyplot, regression_data_no_split
Expand Down Expand Up @@ -105,6 +107,8 @@ def test_prediction_error_cross_validation_display_regression(
assert display.ax_.get_xlabel() == "Predicted values"
assert display.ax_.get_ylabel() == "Residuals (actual - predicted)"

assert display.ax_.get_aspect() not in ("equal", 1.0)


def test_prediction_error_display_regression_kind(pyplot, regression_data):
"""Check the attributes when switching to the "actual_vs_predicted" kind."""
Expand Down Expand Up @@ -164,6 +168,8 @@ def test_prediction_error_cross_validation_display_regression_kind(
assert display.ax_.get_xlabel() == "Predicted values"
assert display.ax_.get_ylabel() == "Actual values"

assert display.ax_.get_aspect() in ("equal", 1.0)


def test_prediction_error_display_data_source(pyplot, regression_data):
"""Check that we can pass the `data_source` argument to the prediction error
Expand Down

0 comments on commit ad1da20

Please sign in to comment.