-
Notifications
You must be signed in to change notification settings - Fork 70
/
prtOutlierRemovalNStd.m
94 lines (77 loc) · 3.19 KB
/
prtOutlierRemovalNStd.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
classdef prtOutlierRemovalNStd < prtOutlierRemoval
    % prtOutlierRemovalNStd  Removes outliers from a prtDataSet
    %
    %   NSTDOUT = prtOutlierRemovalNStd creates a pre-processing
    %   object that flags as outliers data where any of the feature values
    %   is more than nStd standard deviations from the mean of that
    %   feature.
    %
    %   prtOutlierRemovalNStd has the following properties:
    %
    %   nStd  - The number of standard deviations beyond which an
    %           observation is flagged as an outlier (default = 3)
    %
    %   After training, the following read-only properties are set:
    %
    %   meanVector - Per-feature mean of the training observations
    %   stdVector  - Per-feature standard deviation of the training
    %                observations (normalized by N, not N-1)
    %
    %   A prtOutlierRemovalNStd object also inherits all properties and
    %   functions from the prtOutlierRemoval class.  For more information
    %   on how to control the behaviour of outlier removal objects, see
    %   the help for prtOutlierRemoval.
    %
    %   Example:
    %
    %   dataSet = prtDataGenUnimodal;               % Load a data Set
    %   outlier = prtDataSetClass([-10 -10],1);     % Create and insert
    %   dataSet = catObservations(dataSet,outlier); % an outlier
    %
    %   % Create the prtOutlierRemoval object
    %   nStdRemove = prtOutlierRemovalNStd('runMode','removeObservation');
    %
    %   nStdRemove = nStdRemove.train(dataSet);    % Train and run
    %   dataSetNew = nStdRemove.run(dataSet);
    %
    %   % Plot the results
    %   subplot(2,1,1); plot(dataSet);
    %   title('Original Data');
    %   subplot(2,1,2); plot(dataSetNew);
    %   title('NstdOutlierRemove Data');
    %
    %   See Also: prtOutlierRemoval,
    %   prtOutlierRemovalNonFinite,prtOutlierRemovalMissingData

    properties (SetAccess=private)
        name = 'Standard Deviation Based Outlier Removal'; % Standard Deviation Based Outlier Removal
        nameAbbreviation = 'nStd' % nStd
    end

    properties
        nStd = 3; % The number of standard deviations beyond which to remove data
    end

    % Statistics learned from the training data (set by trainAction)
    properties (SetAccess=private)
        stdVector = [];  % The standard deviation vector (one entry per feature)
        meanVector = []; % The mean vector (one entry per feature)
    end

    methods
        function Obj = prtOutlierRemovalNStd(varargin)
            % prtOutlierRemovalNStd  Constructor; accepts string/value pairs
            %
            %   Obj = prtOutlierRemovalNStd('paramName',paramValue,...)
            %   forwards all inputs to prtUtilAssignStringValuePairs.

            % Can't cross validate because nStd changes the size of datasets
            Obj.isCrossValidateValid = false;
            Obj = prtUtilAssignStringValuePairs(Obj,varargin{:});
        end
    end

    methods
        function Obj = set.nStd(Obj,value)
            % set.nStd  Validate and store the outlier threshold
            %
            %   Raises 'prt:prtOutlierRemovalNStd' when value fails the
            %   prtUtilIsPositiveScalarInteger check.
            if ~prtUtilIsPositiveScalarInteger(value)
                error('prt:prtOutlierRemovalNStd','value (%s) must be a positive scalar integer',mat2str(value));
            end
            Obj.nStd = value;
        end
    end

    methods (Access = protected, Hidden = true)

        function Obj = trainAction(Obj,DataSet)
            % trainAction  Learn the per-feature mean and standard deviation
            %
            %   Both statistics are computed down dimension 1 of the
            %   matrix returned by DataSet.getObservations().
            X = DataSet.getObservations();
            Obj.meanVector = mean(X,1);
            % std's second argument is the weighting flag, not the
            % dimension.  Keep the N normalization (w = 1) but give the
            % dimension explicitly; otherwise a data set with a single
            % observation is reduced across features (first non-singleton
            % dimension) and yields a scalar instead of one value per
            % feature.
            Obj.stdVector = std(X,1,1);
        end

        function outlierIndices = calculateOutlierIndices(Obj,DataSet)
            % calculateOutlierIndices  Flag values beyond nStd deviations
            %
            %   Returns a logical array the same size as the observation
            %   matrix; an element is true where the corresponding value
            %   lies more than nStd trained standard deviations from the
            %   trained mean of its column.
            z = DataSet.getObservations;
            z = bsxfun(@minus,z,Obj.meanVector);   % center each column
            z = bsxfun(@rdivide,z,Obj.stdVector);  % scale to standard deviations
            outlierIndices = abs(z) > Obj.nStd;
        end

    end
end