diff --git a/CHANGES.txt b/CHANGES.txt index 27d4ea735..f9beede5e 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -24,4 +24,5 @@ v<0.4.9>, <06/09/2018> -- Add new utility functions and improve documentations. v<0.5.0>, <06/10/2018> -- Refactor models and improve documentation. v<0.5.1>, <06/12/2018> -- Add MCD detector and more Jupyter notebooks. v<0.5.2>, <06/13/2018> -- Incremental changes. -v<0.5.3>, <06/14/2018> -- Incremental changes. \ No newline at end of file +v<0.5.3>, <06/14/2018> -- Incremental changes. +v<0.5.4>, <06/18/2018> -- Add CBLOF model and incremental improvements. \ No newline at end of file diff --git a/README.md b/README.md index 297ab616b..93ab635cf 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ detection utility functions. 2. Proximity-Based Outlier Detection Models: 1. **LOF: Local Outlier Factor** [1] - 2. **CBLOF: Clustering-Based Local Outlier Factor** [15] (work in progress) + 2. **CBLOF: Clustering-Based Local Outlier Factor** [15] 3. **HBOS: Histogram-based Outlier Score** [5] 4. **kNN: k Nearest Neighbors** (use the distance to the kth nearest neighbor as the outlier score) [13] diff --git a/docs/index.rst b/docs/index.rst index 7e627c548..0a534631f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -63,13 +63,13 @@ detection utility functions. i. **LOF: Local Outlier Factor** :cite:`a-breunig2000lof`: :class:`pyod.models.lof.LOF` ii. **CBLOF: Clustering-Based Local Outlier Factor** :cite:`a-he2003discovering`: :class:`pyod.models.cblof.CBLOF` - ii. **kNN: k Nearest Neighbors** (use the distance to the kth nearest - neighbor as the outlier score) :cite:`a-ramaswamy2000efficient,a-angiulli2002fast`: :class:`pyod.models.knn.KNN` - iii. **Average kNN** (use the average distance to k nearest neighbors as - the outlier score): :class:`pyod.models.knn.KNN` - iv. **Median kNN** (use the median distance to k nearest neighbors - as the outlier score): :class:`pyod.models.knn.KNN` - v. **HBOS: Histogram-based Outlier Score** :cite:`a-goldstein2012histogram`: :class:`pyod.models.hbos.HBOS` + iii. **kNN: k Nearest Neighbors** (use the distance to the kth nearest + neighbor as the outlier score) :cite:`a-ramaswamy2000efficient,a-angiulli2002fast`: :class:`pyod.models.knn.KNN` + iv. **Average kNN** (use the average distance to k nearest neighbors as + the outlier score): :class:`pyod.models.knn.KNN` + v. **Median kNN** (use the median distance to k nearest neighbors + as the outlier score): :class:`pyod.models.knn.KNN` + vi. **HBOS: Histogram-based Outlier Score** :cite:`a-goldstein2012histogram`: :class:`pyod.models.hbos.HBOS` 3. Probabilistic Models for Outlier Detection: diff --git a/docs/pyod.models.rst b/docs/pyod.models.rst index 2ac21a153..2e590f7b3 100644 --- a/docs/pyod.models.rst +++ b/docs/pyod.models.rst @@ -23,7 +23,7 @@ pyod.models.base module :inherited-members: pyod.models.cblof module ------------------------ +------------------------ .. 
automodule:: pyod.models.cblof :members: diff --git a/examples/cblof_example.py b/examples/cblof_example.py new file mode 100644 index 000000000..245d04130 --- /dev/null +++ b/examples/cblof_example.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +"""Example of using Cluster-based Local Outlier Factor (CBLOF) for outlier +detection +""" +# Author: Yue Zhao +# License: BSD 2 clause + +from __future__ import division +from __future__ import print_function + +import os +import sys + +# temporary solution for relative imports in case pyod is not installed +# if pyod is installed, no need to use the following line +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from sklearn.utils import check_X_y +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D + +from pyod.models.cblof import CBLOF +from pyod.utils.data import generate_data +from pyod.utils.data import get_color_codes +from pyod.utils.data import evaluate_print + + +def visualize(clf_name, X_train, y_train, X_test, y_test, y_train_pred, + y_test_pred, show_figure=True, + save_figure=False): # pragma: no cover + """ + Utility function for visualizing the results in examples. + Internal use only. + + :param clf_name: The name of the detector + :type clf_name: str + + :param X_train: The training samples + :type X_train: numpy array of shape (n_samples, n_features) + + :param y_train: The ground truth of training samples + :type y_train: list or array of shape (n_samples,) + + :param X_test: The test samples + :type X_test: numpy array of shape (n_samples, n_features) + + :param y_test: The ground truth of test samples + :type y_test: list or array of shape (n_samples,) + + :param y_train_pred: The predicted outlier labels of the training samples + :type y_train_pred: numpy array of shape (n_samples,) + + :param y_test_pred: The predicted outlier labels of the test samples + :type y_test_pred: numpy array of shape (n_samples,) + + :param show_figure: If set to True, show the figure + :type show_figure: bool, optional (default=True) + + :param save_figure: If set to True, save the figure to the local directory + :type save_figure: bool, optional (default=False) + """ + + if X_train.shape[1] != 2 or X_test.shape[1] != 2: + raise ValueError("Input data has to be 2-d for visualization. 
The " + "input data has {shape}.".format(shape=X_train.shape)) + + X_train, y_train = check_X_y(X_train, y_train) + X_test, y_test = check_X_y(X_test, y_test) + c_train = get_color_codes(y_train) + c_test = get_color_codes(y_test) + + fig = plt.figure(figsize=(12, 10)) + plt.suptitle("Demo of {clf_name}".format(clf_name=clf_name)) + + fig.add_subplot(221) + plt.scatter(X_train[:, 0], X_train[:, 1], c=c_train) + plt.title('Train ground truth') + legend_elements = [Line2D([0], [0], marker='o', color='w', label='normal', + markerfacecolor='b', markersize=8), + Line2D([0], [0], marker='o', color='w', label='outlier', + markerfacecolor='r', markersize=8)] + + plt.legend(handles=legend_elements, loc=4) + + fig.add_subplot(222) + plt.scatter(X_test[:, 0], X_test[:, 1], c=c_test) + plt.title('Test ground truth') + plt.legend(handles=legend_elements, loc=4) + + fig.add_subplot(223) + plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train_pred) + plt.title('Train prediction by {clf_name}'.format(clf_name=clf_name)) + legend_elements = [Line2D([0], [0], marker='o', color='w', label='normal', + markerfacecolor='0', markersize=8), + Line2D([0], [0], marker='o', color='w', label='outlier', + markerfacecolor='yellow', markersize=8)] + plt.legend(handles=legend_elements, loc=4) + + fig.add_subplot(224) + plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test_pred) + plt.title('Test prediction by {clf_name}'.format(clf_name=clf_name)) + plt.legend(handles=legend_elements, loc=4) + + if save_figure: + plt.savefig('{clf_name}.png'.format(clf_name=clf_name), dpi=300) + if show_figure: + plt.show() + return + + +if __name__ == "__main__": + contamination = 0.1 # percentage of outliers + n_train = 200 # number of training points + n_test = 100 # number of testing points + + # Generate sample data + X_train, y_train, X_test, y_test = \ + generate_data(n_train=n_train, + n_test=n_test, + n_features=2, + contamination=contamination, + random_state=42) + + # train CBLOF detector + clf_name = 'CBLOF' + clf = CBLOF() + clf.fit(X_train) + + # get the prediction labels and outlier scores of the training data + y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers) + y_train_scores = clf.decision_scores_ # raw outlier scores + + # get the prediction on the test data + y_test_pred = clf.predict(X_test) # outlier labels (0 or 1) + y_test_scores = clf.decision_function(X_test) # outlier scores + + # evaluate and print the results + print("\nOn Training Data:") + evaluate_print(clf_name, y_train, y_train_scores) + print("\nOn Test Data:") + evaluate_print(clf_name, y_test, y_test_scores) + + # visualize the results + visualize(clf_name, X_train, y_train, X_test, y_test, y_train_pred, + y_test_pred, show_figure=True, save_figure=False) diff --git a/pyod/__init__.py b/pyod/__init__.py index fa893c86d..f2ff175dc 100644 --- a/pyod/__init__.py +++ b/pyod/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -__version__ = '0.5.3' +__version__ = '0.5.4' from . import models from . 
import utils diff --git a/pyod/models/__init__.py b/pyod/models/__init__.py index 24a09577d..e63b814e7 100644 --- a/pyod/models/__init__.py +++ b/pyod/models/__init__.py @@ -13,6 +13,7 @@ from .pca import PCA __all__ = ['ABOD', + 'CBLOF', 'clone', 'aom', 'moa', 'average', 'maximization', 'FeatureBagging', diff --git a/pyod/models/cblof.py b/pyod/models/cblof.py index de483964d..86f901c9a 100644 --- a/pyod/models/cblof.py +++ b/pyod/models/cblof.py @@ -7,6 +7,7 @@ from __future__ import division from __future__ import print_function +import warnings import numpy as np from scipy.spatial.distance import cdist from sklearn.cluster import MiniBatchKMeans @@ -33,9 +34,12 @@ class CBLOF(BaseDetector): Use weighting for outlier factor based on the sizes of the clusters as proposed in the original publication. Since this might lead to unexpected - behavior (outliers close to small clusters are not found), it can be - disabled and outliers scores are solely computed based on their distance to - the cluster center. + behavior (outliers close to small clusters are not found), it is disabled + by default. Outlier scores are then solely computed based on their distance + to the closest large cluster center. + + By default, MiniBatchKMeans is used as the clustering algorithm instead of + the Squeezer algorithm mentioned in the original paper, mainly for scalability. See :cite:`he2003discovering` for details. @@ -48,14 +52,28 @@ class CBLOF(BaseDetector): define the threshold on the decision function. :type contamination: float in (0., 0.5), optional (default=0.1) - :param alpha: Coefficient for deciding small and large clusters. + :param clustering_estimator: The base clustering algorithm for performing + data clustering. A valid clustering algorithm should be passed in. + The estimator should have standard sklearn APIs, fit() and predict(). + The estimator should have attributes labels\_ and cluster_centers\_. + If cluster_centers\_ is not in the attributes once the model is fit, it + is calculated as the mean of the samples in a cluster. + + If not set, CBLOF uses MiniBatchKMeans for scalability. See + http://scikit-learn.org/stable/modules/generated/sklearn.cluster.MiniBatchKMeans.html + :type clustering_estimator: Estimator, optional (default=None) + + :param alpha: Coefficient for deciding small and large clusters. The ratio + of the number of samples in large clusters to the total number of + samples. :type alpha: float in (0.5, 1), optional (default=0.9) - :param beta: Coefficient for deciding small and large clusters. + :param beta: Coefficient for deciding small and large clusters. For a list + of clusters sorted by size ``|C1|, |C2|, ..., |Cn|``, beta = |Ck|/|Ck-1|. :type beta: int or float in (1,), optional (default=5). :param use_weights: If set to True, the size of clusters are used as - weights. + weights in the outlier score calculation. :type use_weights: bool, optional (default=False) :param random_state: If int, random_state is the seed used by the random @@ -65,10 +83,30 @@ class CBLOF(BaseDetector): :type random_state: int, RandomState instance or None, optional (default=None) + :var clustering_estimator\_: Base estimator for clustering. 
+ :vartype clustering_estimator\_: Estimator + + :var cluster_labels\_: Cluster assignments for the training samples + :vartype cluster_labels\_: list of shape (n_samples,) + + :var cluster_sizes\_: The size of each cluster once fitted with the + training data + :vartype cluster_sizes\_: list of shape (n_clusters,) + + :var cluster_centers\_: The center of each cluster. + :vartype cluster_centers\_: numpy array of shape (n_clusters, n_features) + + :var small_cluster_labels\_: The cluster assignments belonging to small + clusters + :vartype small_cluster_labels\_: list of cluster numbers + + :var large_cluster_labels\_: The cluster assignments belonging to large + clusters + :vartype large_cluster_labels\_: list of cluster numbers + :var decision_scores\_: The outlier scores of the training data. The higher, the more abnormal. Outliers tend to have higher - scores. This value is available once the detector is - fitted. + scores. This value is available once the detector is fitted. :vartype decision_scores\_: numpy array of shape (n_samples,) :var threshold\_: The threshold is based on ``contamination``. It is the @@ -109,25 +147,72 @@ def fit(self, X, y=None): # Validate inputs X and y (optional) X = check_array(X) self._set_n_classes(y) + n_samples, n_features = X.shape + # check parameters # number of clusters are default to 8 self._validate_estimator( default=MiniBatchKMeans(n_clusters=self.n_clusters, random_state=self.random_state)) + self.clustering_estimator_.fit(X=X, y=y) + # Get the labels of the clustering results + # labels_ is consistent across sklearn clustering algorithms + self.cluster_labels_ = self.clustering_estimator_.labels_ + self.cluster_sizes_ = np.bincount(self.cluster_labels_) + self._set_cluster_centers(X, n_features) + self._set_small_large_clusters(n_samples) + + self.decision_scores_ = self._decision_function(X, + self.cluster_labels_) + + self._process_decision_scores() + return self + + def decision_function(self, X): + check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_']) + X = check_array(X) + labels = self.clustering_estimator_.predict(X) + return self._decision_function(X, labels) + + def _validate_estimator(self, default=None): + """Check the value of alpha and beta and clustering algorithm. 
+ """ + check_parameter(self.alpha, low=0, high=1, param_name='alpha', include_left=False, include_right=False) check_parameter(self.beta, low=1, param_name='beta', include_left=False) - self.clustering_estimator_.fit(X=X, y=y) - # Get the lables of the clustering results - # labels_ is consistent across sklearn clustering algorithms - self.cluster_labels_ = self.clustering_estimator_.labels_ - self.cluster_centers_ = self.clustering_estimator_.cluster_centers_ + if self.clustering_estimator is not None: + self.clustering_estimator_ = self.clustering_estimator + else: + self.clustering_estimator_ = default - # Sort the index of clusters by the number of elements + # make sure the base clustering algorithm is valid + if self.clustering_estimator_ is None: + raise ValueError("clustering algorithm cannot be None") + + check_estimator(self.clustering_estimator_) + + def _set_cluster_centers(self, X, n_features): + # Noted not all clustering algorithms have cluster_centers_ + if hasattr(self.clustering_estimator_, 'cluster_centers_'): + self.cluster_centers_ = self.clustering_estimator_.cluster_centers_ + else: + # Set the cluster center as the mean of all the samples within + # the cluster + warnings.warn("The chosen clustering for CBLOF does not have" + "the center of clusters. Calculate the center" + "as the mean of the clusters.") + self.cluster_centers_ = np.zeros([self.n_clusters, n_features]) + for i in range(self.n_clusters): + self.cluster_centers_[i, :] = np.mean( + X[np.where(self.cluster_labels_ == i)], axis=0) + + def _set_small_large_clusters(self, n_samples): + # Sort the index of clusters by the number of samples belonging to it size_clusters = np.bincount(self.cluster_labels_) sorted_cluster_indices = np.argsort(size_clusters) @@ -136,102 +221,63 @@ def fit(self, X, y=None): alpha_list = [] beta_list = [] for i in range(1, self.n_clusters): - print(i, size_clusters) - print(np.sum(size_clusters[sorted_cluster_indices[-1 * i:]])) temp_sum = np.sum(size_clusters[sorted_cluster_indices[-1 * i:]]) - if temp_sum >= X.shape[0] * self.alpha: - # print('stop', i, 'alpha') + if temp_sum >= n_samples * self.alpha: alpha_list.append(i) if size_clusters[sorted_cluster_indices[i]] / size_clusters[ sorted_cluster_indices[i - 1]] >= self.beta: - # print('stop', i, 'beta') beta_list.append(i) # Find the separation index fulfills both alpha and beta intersection = np.intersect1d(alpha_list, beta_list) if len(intersection) > 0: - self.clustering_threshold_ = intersection[0] + self._clustering_threshold = intersection[0] elif len(alpha_list) > 0: - self.clustering_threshold_ = alpha_list[0] + self._clustering_threshold = alpha_list[0] elif len(beta_list) > 0: - self.clustering_threshold_ = beta_list[0] + self._clustering_threshold = beta_list[0] else: raise ValueError("Could not form valid cluster separation. 
Try " "reset n_clusters or change clustering method") - # print(self.clustering_threshold_) - - # Weights are calculated as the number of elements in the cluster - self.weights_ = size_clusters[self.cluster_labels_] - self.small_cluster_labels_ = sorted_cluster_indices[ - 0:self.clustering_threshold_] + 0:self._clustering_threshold] self.large_cluster_labels_ = sorted_cluster_indices[ - self.clustering_threshold_:] + self._clustering_threshold:] - self.small_cluster_centers_ = self.cluster_centers_[ - self.small_cluster_labels_] + # No need to calculate samll cluster center + # self.small_cluster_centers_ = self.cluster_centers_[ + # self.small_cluster_labels_] - self.large_cluster_centers_ = self.cluster_centers_[ + self._large_cluster_centers = self.cluster_centers_[ self.large_cluster_labels_] - print(self.small_cluster_labels_) - print(self.large_cluster_labels_) - - self.decision_scores_ = np.zeros([X.shape[0], ]) + def _decision_function(self, X, labels): + # Initialize the score array + scores = np.zeros([X.shape[0], ]) small_indices = np.where( - np.isin(self.cluster_labels_, self.small_cluster_labels_))[0] + np.isin(labels, self.small_cluster_labels_))[0] large_indices = np.where( - np.isin(self.cluster_labels_, self.large_cluster_labels_))[0] - - print(len(small_indices), len(large_indices)) + np.isin(labels, self.large_cluster_labels_))[0] # Calculate the outlier factor for the samples in small clusters dist_to_large_center = cdist(X[small_indices, :], - self.large_cluster_centers_) + self._large_cluster_centers) - self.decision_scores_[small_indices] = np.min(dist_to_large_center, - axis=1) + scores[small_indices] = np.min(dist_to_large_center, axis=1) # Calculate the outlier factor for the samples in large clusters - large_centers = self.cluster_centers_[ - self.cluster_labels_[large_indices]] + large_centers = self.cluster_centers_[labels[large_indices]] - self.decision_scores_[large_indices] = pairwise_distances_no_broadcast( + scores[large_indices] = pairwise_distances_no_broadcast( X[large_indices, :], large_centers) if self.use_weights: - self.decision_scores_ = self.decision_scores_ * self.weights_ - - self._process_decision_scores() - return self - - def decision_function(self, X): - check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_']) - X = check_array(X) - pass - - def _validate_estimator(self, default=None): - """Check the value of alpha and beta and clustering algorithm. 
- """ - - check_parameter(self.alpha, 0, 1, param_name='alpha', - include_left=False, include_right=False) + # Weights are calculated as the number of elements in the cluster + scores = scores * self.cluster_sizes_[labels] - check_parameter(self.beta, 0, param_name='alpha', - include_left=False, include_right=False) - - if self.clustering_estimator is not None: - self.clustering_estimator_ = self.clustering_estimator - else: - self.clustering_estimator_ = default - - # make sure the base clustering algorithm is valid - if self.clustering_estimator_ is None: - raise ValueError("clustering algorithm cannot be None") - - check_estimator(self.clustering_estimator_) + return scores.ravel() diff --git a/pyod/test/test_cblof.py b/pyod/test/test_cblof.py index 7e76d9721..655e8ff38 100644 --- a/pyod/test/test_cblof.py +++ b/pyod/test/test_cblof.py @@ -38,95 +38,106 @@ def setUp(self): n_train=self.n_train, n_test=self.n_test, contamination=self.contamination, random_state=42) - self.clf = CBLOF(contamination=self.contamination) + self.clf = CBLOF(contamination=self.contamination, random_state=42) self.clf.fit(self.X_train) - def test_fit(self): - self.clf.fit(self.X_train) + def test_sklearn_estimator(self): + # TODO: sklearn examples are too small to form valid + # check_estimator(self.clf) + pass + + def test_parameters(self): + assert_true(hasattr(self.clf, 'decision_scores_') and + self.clf.decision_scores_ is not None) + assert_true(hasattr(self.clf, 'labels_') and + self.clf.labels_ is not None) + assert_true(hasattr(self.clf, 'threshold_') and + self.clf.threshold_ is not None) + assert_true(hasattr(self.clf, '_mu') and + self.clf._mu is not None) + assert_true(hasattr(self.clf, '_sigma') and + self.clf._sigma is not None) + assert_true(hasattr(self.clf, 'clustering_estimator_') and + self.clf.clustering_estimator_ is not None) + assert_true(hasattr(self.clf, 'cluster_labels_') and + self.clf.cluster_labels_ is not None) + assert_true(hasattr(self.clf, 'cluster_sizes_') and + self.clf.cluster_sizes_ is not None) + assert_true(hasattr(self.clf, 'cluster_centers_') and + self.clf.cluster_centers_ is not None) + assert_true(hasattr(self.clf, '_clustering_threshold') and + self.clf._clustering_threshold is not None) + assert_true(hasattr(self.clf, 'small_cluster_labels_') and + self.clf.small_cluster_labels_ is not None) + assert_true(hasattr(self.clf, 'large_cluster_labels_') and + self.clf.large_cluster_labels_ is not None) + assert_true(hasattr(self.clf, '_large_cluster_centers') and + self.clf._large_cluster_centers is not None) + + def test_train_scores(self): + assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) + + def test_prediction_scores(self): + pred_scores = self.clf.decision_function(self.X_test) + # check score shapes + assert_equal(pred_scores.shape[0], self.X_test.shape[0]) + # check performance + assert_greater(roc_auc_score(self.y_test, pred_scores), self.roc_floor) + + def test_prediction_labels(self): + pred_labels = self.clf.predict(self.X_test) + assert_equal(pred_labels.shape, self.y_test.shape) + + def test_prediction_proba(self): + pred_proba = self.clf.predict_proba(self.X_test) + assert_greater_equal(pred_proba.min(), 0) + assert_less_equal(pred_proba.max(), 1) + + def test_prediction_proba_linear(self): + pred_proba = self.clf.predict_proba(self.X_test, method='linear') + assert_greater_equal(pred_proba.min(), 0) + assert_less_equal(pred_proba.max(), 1) + + def test_prediction_proba_unify(self): + pred_proba = self.clf.predict_proba(self.X_test, 
method='unify') + assert_greater_equal(pred_proba.min(), 0) + assert_less_equal(pred_proba.max(), 1) + + def test_prediction_proba_parameter(self): + with assert_raises(ValueError): + self.clf.predict_proba(self.X_test, method='something') + + def test_fit_predict(self): + pred_labels = self.clf.fit_predict(self.X_train) + assert_equal(pred_labels.shape, self.y_train.shape) + + def test_fit_predict_score(self): + self.clf.fit_predict_score(self.X_test, self.y_test) + self.clf.fit_predict_score(self.X_test, self.y_test, + scoring='roc_auc_score') + self.clf.fit_predict_score(self.X_test, self.y_test, + scoring='prc_n_score') + with assert_raises(NotImplementedError): + self.clf.fit_predict_score(self.X_test, self.y_test, + scoring='something') + + def test_predict_rank(self): + pred_scores = self.clf.decision_function(self.X_test) + pred_ranks = self.clf._predict_rank(self.X_test) + + # assert the order is preserved + assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=2) + assert_array_less(pred_ranks, self.X_train.shape[0] + 1) + assert_array_less(-0.1, pred_ranks) + + def test_predict_rank_normalized(self): + pred_scores = self.clf.decision_function(self.X_test) + pred_ranks = self.clf._predict_rank(self.X_test, normalized=True) - # def test_sklearn_estimator(self): - # check_estimator(self.clf) - # - # def test_parameters(self): - # assert_true(hasattr(self.clf, 'decision_scores_') and - # self.clf.decision_scores_ is not None) - # assert_true(hasattr(self.clf, 'labels_') and - # self.clf.labels_ is not None) - # assert_true(hasattr(self.clf, 'threshold_') and - # self.clf.threshold_ is not None) - # assert_true(hasattr(self.clf, '_mu') and - # self.clf._mu is not None) - # assert_true(hasattr(self.clf, '_sigma') and - # self.clf._sigma is not None) - # assert_true(hasattr(self.clf, 'n_neighbors_') and - # self.clf.n_neighbors_ is not None) - # - # def test_train_scores(self): - # assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) - # - # def test_prediction_scores(self): - # pred_scores = self.clf.decision_function(self.X_test) - # - # # check score shapes - # assert_equal(pred_scores.shape[0], self.X_test.shape[0]) - # - # # check performance - # assert_greater(roc_auc_score(self.y_test, pred_scores), self.roc_floor) - # - # def test_prediction_labels(self): - # pred_labels = self.clf.predict(self.X_test) - # assert_equal(pred_labels.shape, self.y_test.shape) - # - # def test_prediction_proba(self): - # pred_proba = self.clf.predict_proba(self.X_test) - # assert_greater_equal(pred_proba.min(), 0) - # assert_less_equal(pred_proba.max(), 1) - # - # def test_prediction_proba_linear(self): - # pred_proba = self.clf.predict_proba(self.X_test, method='linear') - # assert_greater_equal(pred_proba.min(), 0) - # assert_less_equal(pred_proba.max(), 1) - # - # def test_prediction_proba_unify(self): - # pred_proba = self.clf.predict_proba(self.X_test, method='unify') - # assert_greater_equal(pred_proba.min(), 0) - # assert_less_equal(pred_proba.max(), 1) - # - # def test_prediction_proba_parameter(self): - # with assert_raises(ValueError): - # self.clf.predict_proba(self.X_test, method='something') - # - # def test_fit_predict(self): - # pred_labels = self.clf.fit_predict(self.X_train) - # assert_equal(pred_labels.shape, self.y_train.shape) - # - # def test_fit_predict_score(self): - # self.clf.fit_predict_score(self.X_test, self.y_test) - # self.clf.fit_predict_score(self.X_test, self.y_test, - # scoring='roc_auc_score') - # self.clf.fit_predict_score(self.X_test, 
self.y_test, - # scoring='prc_n_score') - # with assert_raises(NotImplementedError): - # self.clf.fit_predict_score(self.X_test, self.y_test, - # scoring='something') - # - # def test_predict_rank(self): - # pred_socres = self.clf.decision_function(self.X_test) - # pred_ranks = self.clf._predict_rank(self.X_test) - # - # # assert the order is reserved - # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) - # assert_array_less(pred_ranks, self.X_train.shape[0] + 1) - # assert_array_less(-0.1, pred_ranks) - # - # def test_predict_rank_normalized(self): - # pred_socres = self.clf.decision_function(self.X_test) - # pred_ranks = self.clf._predict_rank(self.X_test, normalized=True) - # - # # assert the order is reserved - # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) - # assert_array_less(pred_ranks, 1.01) - # assert_array_less(-0.1, pred_ranks) + # assert the order is preserved + assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=2) + assert_array_less(pred_ranks, 1.01) + assert_array_less(-0.1, pred_ranks) def tearDown(self): pass
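
The bundled examples/cblof_example.py exercises CBLOF with its default MiniBatchKMeans clusterer. As a complement, the sketch below illustrates the new clustering_estimator parameter documented in the cblof.py docstring above. This is a hedged illustration rather than part of the PR: the choice of sklearn's KMeans, n_clusters=10, and the variable names are assumptions made only for the demo, based on the documented requirements that the base clusterer expose sklearn-style fit() and predict() plus labels_ (and, optionally, cluster_centers_).

```python
# -*- coding: utf-8 -*-
# Hypothetical usage sketch (not part of the diff above): plugging a custom
# sklearn clusterer into CBLOF via the clustering_estimator parameter.
from sklearn.cluster import KMeans

from pyod.models.cblof import CBLOF
from pyod.utils.data import generate_data
from pyod.utils.data import evaluate_print

if __name__ == "__main__":
    # Generate a small 2-d toy set, mirroring examples/cblof_example.py
    X_train, y_train, X_test, y_test = generate_data(
        n_train=200, n_test=100, n_features=2,
        contamination=0.1, random_state=42)

    # KMeans satisfies the documented requirements for a base clusterer:
    # sklearn-style fit()/predict() plus labels_ and cluster_centers_.
    # n_clusters=10 is an illustrative choice; it is passed to both CBLOF
    # and the clusterer so the two stay consistent.
    clf = CBLOF(n_clusters=10,
                clustering_estimator=KMeans(n_clusters=10, random_state=42),
                contamination=0.1,
                random_state=42)
    clf.fit(X_train)

    # Outlier scores and binary labels follow the standard pyod detector API
    y_test_scores = clf.decision_function(X_test)
    y_test_pred = clf.predict(X_test)
    evaluate_print('CBLOF (KMeans)', y_test, y_test_scores)
```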