|
| 1 | +package DataMining_EM; |
| 2 | + |
| 3 | +import java.io.BufferedReader; |
| 4 | +import java.io.File; |
| 5 | +import java.io.FileReader; |
| 6 | +import java.io.IOException; |
| 7 | +import java.text.MessageFormat; |
| 8 | +import java.util.ArrayList; |
| 9 | + |
| 10 | +/** |
| 11 | + * EM最大期望算法工具类 |
| 12 | + * |
| 13 | + * @author lyq |
| 14 | + * |
| 15 | + */ |
| 16 | +public class EMTool { |
| 17 | + // 测试数据文件地址 |
| 18 | + private String dataFilePath; |
| 19 | + // 测试坐标点数据 |
| 20 | + private String[][] data; |
| 21 | + // 测试坐标点数据列表 |
| 22 | + private ArrayList<Point> pointArray; |
| 23 | + // 目标C1点 |
| 24 | + private Point p1; |
| 25 | + // 目标C2点 |
| 26 | + private Point p2; |
| 27 | + |
| 28 | + public EMTool(String dataFilePath) { |
| 29 | + this.dataFilePath = dataFilePath; |
| 30 | + pointArray = new ArrayList<>(); |
| 31 | + } |
| 32 | + |
| 33 | + /** |
| 34 | + * 从文件中读取数据 |
| 35 | + */ |
| 36 | + public void readDataFile() { |
| 37 | + File file = new File(dataFilePath); |
| 38 | + ArrayList<String[]> dataArray = new ArrayList<String[]>(); |
| 39 | + |
| 40 | + try { |
| 41 | + BufferedReader in = new BufferedReader(new FileReader(file)); |
| 42 | + String str; |
| 43 | + String[] tempArray; |
| 44 | + while ((str = in.readLine()) != null) { |
| 45 | + tempArray = str.split(" "); |
| 46 | + dataArray.add(tempArray); |
| 47 | + } |
| 48 | + in.close(); |
| 49 | + } catch (IOException e) { |
| 50 | + e.getStackTrace(); |
| 51 | + } |
| 52 | + |
| 53 | + data = new String[dataArray.size()][]; |
| 54 | + dataArray.toArray(data); |
| 55 | + |
| 56 | + // 开始时默认取头2个点作为2个簇中心 |
| 57 | + p1 = new Point(Integer.parseInt(data[0][0]), |
| 58 | + Integer.parseInt(data[0][1])); |
| 59 | + p2 = new Point(Integer.parseInt(data[1][0]), |
| 60 | + Integer.parseInt(data[1][1])); |
| 61 | + |
| 62 | + Point p; |
| 63 | + for (String[] array : data) { |
| 64 | + // 将数据转换为对象加入列表方便计算 |
| 65 | + p = new Point(Integer.parseInt(array[0]), |
| 66 | + Integer.parseInt(array[1])); |
| 67 | + pointArray.add(p); |
| 68 | + } |
| 69 | + } |
| 70 | + |
| 71 | + /** |
| 72 | + * 计算坐标点对于2个簇中心点的隶属度 |
| 73 | + * |
| 74 | + * @param p |
| 75 | + * 待测试坐标点 |
| 76 | + */ |
| 77 | + private void computeMemberShip(Point p) { |
| 78 | + // p点距离第一个簇中心点的距离 |
| 79 | + double distance1 = 0; |
| 80 | + // p距离第二个中心点的距离 |
| 81 | + double distance2 = 0; |
| 82 | + |
| 83 | + // 用欧式距离计算 |
| 84 | + distance1 = Math.pow(p.getX() - p1.getX(), 2) |
| 85 | + + Math.pow(p.getY() - p1.getY(), 2); |
| 86 | + distance2 = Math.pow(p.getX() - p2.getX(), 2) |
| 87 | + + Math.pow(p.getY() - p2.getY(), 2); |
| 88 | + |
| 89 | + // 计算对于p1点的隶属度,与距离成反比关系,距离靠近越小,隶属度越大,所以要用大的distance2另外的距离来表示 |
| 90 | + p.setMemberShip1(distance2 / (distance1 + distance2)); |
| 91 | + // 计算对于p2点的隶属度 |
| 92 | + p.setMemberShip2(distance1 / (distance1 + distance2)); |
| 93 | + } |
| 94 | + |
| 95 | + /** |
| 96 | + * 执行期望最大化步骤 |
| 97 | + */ |
| 98 | + public void exceptMaxStep() { |
| 99 | + // 新的优化过的簇中心点 |
| 100 | + double p1X = 0; |
| 101 | + double p1Y = 0; |
| 102 | + double p2X = 0; |
| 103 | + double p2Y = 0; |
| 104 | + double temp1 = 0; |
| 105 | + double temp2 = 0; |
| 106 | + // 误差值 |
| 107 | + double errorValue1 = 0; |
| 108 | + double errorValue2 = 0; |
| 109 | + // 上次更新的簇点坐标 |
| 110 | + Point lastP1 = null; |
| 111 | + Point lastP2 = null; |
| 112 | + |
| 113 | + // 当开始计算的时候,或是中心点的误差值超过1的时候都需要再次迭代计算 |
| 114 | + while (lastP1 == null || errorValue1 > 1.0 || errorValue2 > 1.0) { |
| 115 | + for (Point p : pointArray) { |
| 116 | + computeMemberShip(p); |
| 117 | + p1X += p.getMemberShip1() * p.getMemberShip1() * p.getX(); |
| 118 | + p1Y += p.getMemberShip1() * p.getMemberShip1() * p.getY(); |
| 119 | + temp1 += p.getMemberShip1() * p.getMemberShip1(); |
| 120 | + |
| 121 | + p2X += p.getMemberShip2() * p.getMemberShip2() * p.getX(); |
| 122 | + p2Y += p.getMemberShip2() * p.getMemberShip2() * p.getY(); |
| 123 | + temp2 += p.getMemberShip2() * p.getMemberShip2(); |
| 124 | + } |
| 125 | + |
| 126 | + lastP1 = new Point(p1.getX(), p1.getY()); |
| 127 | + lastP2 = new Point(p2.getX(), p2.getY()); |
| 128 | + |
| 129 | + // 套公式计算新的簇中心点坐标,最最大化处理 |
| 130 | + p1.setX(p1X / temp1); |
| 131 | + p1.setY(p1Y / temp1); |
| 132 | + p2.setX(p2X / temp2); |
| 133 | + p2.setY(p2Y / temp2); |
| 134 | + |
| 135 | + errorValue1 = Math.abs(lastP1.getX() - p1.getX()) |
| 136 | + + Math.abs(lastP1.getY() - p1.getY()); |
| 137 | + errorValue2 = Math.abs(lastP2.getX() - p2.getX()) |
| 138 | + + Math.abs(lastP2.getY() - p2.getY()); |
| 139 | + } |
| 140 | + |
| 141 | + System.out.println(MessageFormat.format( |
| 142 | + "簇中心节点p1({0}, {1}), p2({2}, {3})", p1.getX(), p1.getY(), |
| 143 | + p2.getX(), p2.getY())); |
| 144 | + } |
| 145 | + |
| 146 | +} |
0 commit comments