Skip to content

Commit c011a1a

Browse files
author
linyiqun
committed
装袋提升算法,通过组合分类器进行分类
装袋提升算法,通过组合分类器进行分类
1 parent 17567e3 commit c011a1a

File tree

4 files changed

+402
-0
lines changed

4 files changed

+402
-0
lines changed
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
package DataMining_AdaBoost;
2+
3+
import java.io.BufferedReader;
4+
import java.io.File;
5+
import java.io.FileReader;
6+
import java.io.IOException;
7+
import java.text.MessageFormat;
8+
import java.util.ArrayList;
9+
import java.util.HashMap;
10+
import java.util.Map;
11+
12+
/**
13+
* AdaBoost提升算法工具类
14+
*
15+
* @author lyq
16+
*
17+
*/
18+
public class AdaBoostTool {
19+
// 分类的类别,程序默认为正类1和负类-1
20+
public static final int CLASS_POSITIVE = 1;
21+
public static final int CLASS_NEGTIVE = -1;
22+
23+
// 事先假设的3个分类器(理论上应该重新对数据集进行训练得到)
24+
public static final String CLASSIFICATION1 = "X=2.5";
25+
public static final String CLASSIFICATION2 = "X=7.5";
26+
public static final String CLASSIFICATION3 = "Y=5.5";
27+
28+
// 分类器组
29+
public static final String[] ClASSIFICATION = new String[] {
30+
CLASSIFICATION1, CLASSIFICATION2, CLASSIFICATION3 };
31+
// 分类权重组
32+
private double[] CLASSIFICATION_WEIGHT;
33+
34+
// 测试数据文件地址
35+
private String filePath;
36+
// 误差率阈值
37+
private double errorValue;
38+
// 所有的数据点
39+
private ArrayList<Point> totalPoint;
40+
41+
public AdaBoostTool(String filePath, double errorValue) {
42+
this.filePath = filePath;
43+
this.errorValue = errorValue;
44+
readDataFile();
45+
}
46+
47+
/**
48+
* 从文件中读取数据
49+
*/
50+
private void readDataFile() {
51+
File file = new File(filePath);
52+
ArrayList<String[]> dataArray = new ArrayList<String[]>();
53+
54+
try {
55+
BufferedReader in = new BufferedReader(new FileReader(file));
56+
String str;
57+
String[] tempArray;
58+
while ((str = in.readLine()) != null) {
59+
tempArray = str.split(" ");
60+
dataArray.add(tempArray);
61+
}
62+
in.close();
63+
} catch (IOException e) {
64+
e.getStackTrace();
65+
}
66+
67+
Point temp;
68+
totalPoint = new ArrayList<>();
69+
for (String[] array : dataArray) {
70+
temp = new Point(array[0], array[1], array[2]);
71+
temp.setProbably(1.0 / dataArray.size());
72+
totalPoint.add(temp);
73+
}
74+
}
75+
76+
/**
77+
* 根据当前的误差值算出所得的权重
78+
*
79+
* @param errorValue
80+
* 当前划分的坐标点误差率
81+
* @return
82+
*/
83+
private double calculateWeight(double errorValue) {
84+
double alpha = 0;
85+
double temp = 0;
86+
87+
temp = (1 - errorValue) / errorValue;
88+
alpha = 0.5 * Math.log(temp);
89+
90+
return alpha;
91+
}
92+
93+
/**
94+
* 计算当前划分的误差率
95+
*
96+
* @param pointMap
97+
* 划分之后的点集
98+
* @param weight
99+
* 本次划分得到的分类器权重
100+
* @return
101+
*/
102+
private double calculateErrorValue(
103+
HashMap<Integer, ArrayList<Point>> pointMap) {
104+
double resultValue = 0;
105+
double temp = 0;
106+
double weight = 0;
107+
int tempClassType;
108+
ArrayList<Point> pList;
109+
for (Map.Entry entry : pointMap.entrySet()) {
110+
tempClassType = (int) entry.getKey();
111+
112+
pList = (ArrayList<Point>) entry.getValue();
113+
for (Point p : pList) {
114+
temp = p.getProbably();
115+
// 如果划分类型不相等,代表划错了
116+
if (tempClassType != p.getClassType()) {
117+
resultValue += temp;
118+
}
119+
}
120+
}
121+
122+
weight = calculateWeight(resultValue);
123+
for (Map.Entry entry : pointMap.entrySet()) {
124+
tempClassType = (int) entry.getKey();
125+
126+
pList = (ArrayList<Point>) entry.getValue();
127+
for (Point p : pList) {
128+
temp = p.getProbably();
129+
// 如果划分类型不相等,代表划错了
130+
if (tempClassType != p.getClassType()) {
131+
// 划错的点的权重比例变大
132+
temp *= Math.exp(weight);
133+
p.setProbably(temp);
134+
} else {
135+
// 划对的点的权重比减小
136+
temp *= Math.exp(-weight);
137+
p.setProbably(temp);
138+
}
139+
}
140+
}
141+
142+
// 如果误差率没有小于阈值,继续处理
143+
dataNormalized();
144+
145+
return resultValue;
146+
}
147+
148+
/**
149+
* 概率做归一化处理
150+
*/
151+
private void dataNormalized() {
152+
double sumProbably = 0;
153+
double temp = 0;
154+
155+
for (Point p : totalPoint) {
156+
sumProbably += p.getProbably();
157+
}
158+
159+
// 归一化处理
160+
for (Point p : totalPoint) {
161+
temp = p.getProbably();
162+
p.setProbably(temp / sumProbably);
163+
}
164+
}
165+
166+
/**
167+
* 用AdaBoost算法得到的组合分类器对数据进行分类
168+
*
169+
*/
170+
public void adaBoostClassify() {
171+
double value = 0;
172+
Point p;
173+
174+
calculateWeightArray();
175+
for (int i = 0; i < ClASSIFICATION.length; i++) {
176+
System.out.println(MessageFormat.format("分类器{0}权重为:{1}", (i+1), CLASSIFICATION_WEIGHT[i]));
177+
}
178+
179+
for (int j = 0; j < totalPoint.size(); j++) {
180+
p = totalPoint.get(j);
181+
value = 0;
182+
183+
for (int i = 0; i < ClASSIFICATION.length; i++) {
184+
value += 1.0 * classifyData(ClASSIFICATION[i], p)
185+
* CLASSIFICATION_WEIGHT[i];
186+
}
187+
188+
//进行符号判断
189+
if (value > 0) {
190+
System.out
191+
.println(MessageFormat.format(
192+
"点({0}, {1})的组合分类结果为:1,该点的实际分类为{2}", p.getX(), p.getY(),
193+
p.getClassType()));
194+
} else {
195+
System.out.println(MessageFormat.format(
196+
"点({0}, {1})的组合分类结果为:-1,该点的实际分类为{2}", p.getX(), p.getY(),
197+
p.getClassType()));
198+
}
199+
}
200+
}
201+
202+
/**
203+
* 计算分类器权重数组
204+
*/
205+
private void calculateWeightArray() {
206+
int tempClassType = 0;
207+
double errorValue = 0;
208+
ArrayList<Point> posPointList;
209+
ArrayList<Point> negPointList;
210+
HashMap<Integer, ArrayList<Point>> mapList;
211+
CLASSIFICATION_WEIGHT = new double[ClASSIFICATION.length];
212+
213+
for (int i = 0; i < CLASSIFICATION_WEIGHT.length; i++) {
214+
mapList = new HashMap<>();
215+
posPointList = new ArrayList<>();
216+
negPointList = new ArrayList<>();
217+
218+
for (Point p : totalPoint) {
219+
tempClassType = classifyData(ClASSIFICATION[i], p);
220+
221+
if (tempClassType == CLASS_POSITIVE) {
222+
posPointList.add(p);
223+
} else {
224+
negPointList.add(p);
225+
}
226+
}
227+
228+
mapList.put(CLASS_POSITIVE, posPointList);
229+
mapList.put(CLASS_NEGTIVE, negPointList);
230+
231+
if (i == 0) {
232+
// 最开始的各个点的权重一样,所以传入0,使得e的0次方等于1
233+
errorValue = calculateErrorValue(mapList);
234+
} else {
235+
// 每次把上次计算所得的权重代入,进行概率的扩大或缩小
236+
errorValue = calculateErrorValue(mapList);
237+
}
238+
239+
// 计算当前分类器的所得权重
240+
CLASSIFICATION_WEIGHT[i] = calculateWeight(errorValue);
241+
}
242+
}
243+
244+
/**
245+
* 用各个子分类器进行分类
246+
*
247+
* @param classification
248+
* 分类器名称
249+
* @param p
250+
* 待划分坐标点
251+
* @return
252+
*/
253+
private int classifyData(String classification, Point p) {
254+
// 分割线所属坐标轴
255+
String position;
256+
// 分割线的值
257+
double value = 0;
258+
double posProbably = 0;
259+
double negProbably = 0;
260+
// 划分是否是大于一边的划分
261+
boolean isLarger = false;
262+
String[] array;
263+
ArrayList<Point> pList = new ArrayList<>();
264+
265+
array = classification.split("=");
266+
position = array[0];
267+
value = Double.parseDouble(array[1]);
268+
269+
if (position.equals("X")) {
270+
if (p.getX() > value) {
271+
isLarger = true;
272+
}
273+
274+
// 将训练数据中所有属于这边的点加入
275+
for (Point point : totalPoint) {
276+
if (isLarger && point.getX() > value) {
277+
pList.add(point);
278+
} else if (!isLarger && point.getX() < value) {
279+
pList.add(point);
280+
}
281+
}
282+
} else if (position.equals("Y")) {
283+
if (p.getY() > value) {
284+
isLarger = true;
285+
}
286+
287+
// 将训练数据中所有属于这边的点加入
288+
for (Point point : totalPoint) {
289+
if (isLarger && point.getY() > value) {
290+
pList.add(point);
291+
} else if (!isLarger && point.getY() < value) {
292+
pList.add(point);
293+
}
294+
}
295+
}
296+
297+
for (Point p2 : pList) {
298+
if (p2.getClassType() == CLASS_POSITIVE) {
299+
posProbably++;
300+
} else {
301+
negProbably++;
302+
}
303+
}
304+
305+
//分类按正负类数量进行划分
306+
if (posProbably > negProbably) {
307+
return CLASS_POSITIVE;
308+
} else {
309+
return CLASS_NEGTIVE;
310+
}
311+
}
312+
313+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package DataMining_AdaBoost;
2+
3+
/**
4+
* AdaBoost提升算法调用类
5+
* @author lyq
6+
*
7+
*/
8+
public class Client {
9+
public static void main(String[] agrs){
10+
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
11+
//误差率阈值
12+
double errorValue = 0.2;
13+
14+
AdaBoostTool tool = new AdaBoostTool(filePath, errorValue);
15+
tool.adaBoostClassify();
16+
}
17+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package DataMining_AdaBoost;
2+
3+
/**
4+
* 坐标点类
5+
*
6+
* @author lyq
7+
*
8+
*/
9+
public class Point {
10+
// 坐标点x坐标
11+
private int x;
12+
// 坐标点y坐标
13+
private int y;
14+
// 坐标点的分类类别
15+
private int classType;
16+
//如果此节点被划错,他的误差率,不能用个数除以总数,因为不同坐标点的权重不一定相等
17+
private double probably;
18+
19+
public Point(int x, int y, int classType){
20+
this.x = x;
21+
this.y = y;
22+
this.classType = classType;
23+
}
24+
25+
public Point(String x, String y, String classType){
26+
this.x = Integer.parseInt(x);
27+
this.y = Integer.parseInt(y);
28+
this.classType = Integer.parseInt(classType);
29+
}
30+
31+
public int getX() {
32+
return x;
33+
}
34+
35+
public void setX(int x) {
36+
this.x = x;
37+
}
38+
39+
public int getY() {
40+
return y;
41+
}
42+
43+
public void setY(int y) {
44+
this.y = y;
45+
}
46+
47+
public int getClassType() {
48+
return classType;
49+
}
50+
51+
public void setClassType(int classType) {
52+
this.classType = classType;
53+
}
54+
55+
public double getProbably() {
56+
return probably;
57+
}
58+
59+
public void setProbably(double probably) {
60+
this.probably = probably;
61+
}
62+
}

0 commit comments

Comments
 (0)