Skip to content

Commit

Permalink
Minor fixes to template estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDLT authored Sep 9, 2016
1 parent 8a8797c commit 7cfc247
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions skltemplate/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TemplateEstimator(BaseEstimator):
A parameter used for demonstation of how to pass and store paramters.
"""
def __init__(self, demo_param='demo_param'):
self.demo_param = 'demo_param'
self.demo_param = demo_param

def fit(self, X, y):
"""A reference implementation of a fitting function
Expand All @@ -29,12 +29,14 @@ def fit(self, X, y):
y : array-like, shape = [n_samples] or [n_samples, n_outputs]
The target values (class labels in classification, real numbers in
regression).
Returns
-------
self : object
Returns self.
"""
X, y = check_X_y(X, y)
# Return the estimator
return self

def predict(self, X):
Expand Down Expand Up @@ -81,6 +83,7 @@ def fit(self, X, y):
The training input samples.
y : array-like, shape = [n_samples]
The target values. An array of int.
Returns
-------
self : object
Expand Down Expand Up @@ -143,8 +146,10 @@ def fit(self, X, y=None):
----------
X : array-like or sparse matrix of shape = [n_samples, n_features]
The training input samples.
y : array-like, shape = [n_samples]
The target values. An array of int.
y : None
There is no need of a target in a transformer, yet the pipeline API
requires this parameter.
Returns
-------
self : object
Expand All @@ -154,7 +159,7 @@ def fit(self, X, y=None):

self.input_shape_ = X.shape

# Return the classifier
# Return the transformer
return self

def transform(self, X):
Expand Down

0 comments on commit 7cfc247

Please sign in to comment.