|
12 | 12 |
|
13 | 13 | from numpy.testing import assert_almost_equal, assert_array_equal
|
14 | 14 |
|
15 |
| -from sklearn.datasets import load_digits, load_boston |
| 15 | +from sklearn.datasets import load_digits, load_boston, load_iris |
16 | 16 | from sklearn.datasets import make_regression, make_multilabel_classification
|
17 | 17 | from sklearn.exceptions import ConvergenceWarning
|
18 | 18 | from sklearn.externals.six.moves import cStringIO as StringIO
|
|
24 | 24 | from scipy.sparse import csr_matrix
|
25 | 25 | from sklearn.utils.testing import (assert_raises, assert_greater, assert_equal,
|
26 | 26 | assert_false, ignore_warnings)
|
| 27 | +from sklearn.utils.testing import assert_raise_message |
27 | 28 |
|
28 | 29 |
|
29 | 30 | np.seterr(all='warn')
|
|
49 | 50 | Xboston = StandardScaler().fit_transform(boston.data)[: 200]
|
50 | 51 | yboston = boston.target[:200]
|
51 | 52 |
|
| 53 | +iris = load_iris() |
| 54 | + |
| 55 | +X_iris = iris.data |
| 56 | +y_iris = iris.target |
| 57 | + |
52 | 58 |
|
53 | 59 | def test_alpha():
|
54 | 60 | # Test that larger alpha yields weights closer to zero
|
@@ -556,3 +562,29 @@ def test_adaptive_learning_rate():
|
556 | 562 | clf.fit(X, y)
|
557 | 563 | assert_greater(clf.max_iter, clf.n_iter_)
|
558 | 564 | assert_greater(1e-6, clf._optimizer.learning_rate)
|
| 565 | + |
| 566 | + |
| 567 | +@ignore_warnings(RuntimeError) |
| 568 | +def test_warm_start(): |
| 569 | + X = X_iris |
| 570 | + y = y_iris |
| 571 | + |
| 572 | + y_2classes = np.array([0] * 75 + [1] * 75) |
| 573 | + y_3classes = np.array([0] * 40 + [1] * 40 + [2] * 70) |
| 574 | + y_3classes_alt = np.array([0] * 50 + [1] * 50 + [3] * 50) |
| 575 | + y_4classes = np.array([0] * 37 + [1] * 37 + [2] * 38 + [3] * 38) |
| 576 | + y_5classes = np.array([0] * 30 + [1] * 30 + [2] * 30 + [3] * 30 + [4] * 30) |
| 577 | + |
| 578 | + # No error raised |
| 579 | + clf = MLPClassifier(hidden_layer_sizes=2, solver='lbfgs', |
| 580 | + warm_start=True).fit(X, y) |
| 581 | + clf.fit(X, y) |
| 582 | + clf.fit(X, y_3classes) |
| 583 | + |
| 584 | + for y_i in (y_2classes, y_3classes_alt, y_4classes, y_5classes): |
| 585 | + clf = MLPClassifier(hidden_layer_sizes=2, solver='lbfgs', |
| 586 | + warm_start=True).fit(X, y) |
| 587 | + message = ('warm_start can only be used where `y` has the same ' |
| 588 | + 'classes as in the previous call to fit.' |
| 589 | + ' Previously got [0 1 2], `y` has %s' % np.unique(y_i)) |
| 590 | + assert_raise_message(ValueError, message, clf.fit, X, y_i) |
0 commit comments