Skip to content

Commit

Permalink
bayes&knn
Browse files Browse the repository at this point in the history
全部删除重新添加了一下,不知道这次会不会有冲突
  • Loading branch information
sywu committed Sep 25, 2014
1 parent f047365 commit aef1da6
Show file tree
Hide file tree
Showing 18 changed files with 1,884 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ public static Predict parse(TPredict res,
break;
case STRING:
pred = new Predict<String>(n);
for(int i=0;i<n;i++){
for(int i=0;i<n;i++){
if(res.getLabel(i)==null){
pred.set(i, "null", 0f);
continue;
}
int idx = (Integer) res.getLabel(i);
String l = labels.lookupString(idx);
pred.set(i, l, res.getScore(i));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class Predict<T> implements TPredict<T> {
/**
* 保存个数
*/
int n;
protected int n;
/**
* 缺省只保存一个最大值
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package org.fnlp.ml.classifier.bayes;

import gnu.trove.iterator.TIntFloatIterator;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.fnlp.ml.classifier.AbstractClassifier;
import org.fnlp.ml.classifier.LabelParser.Type;
import org.fnlp.ml.classifier.linear.Linear;
import org.fnlp.ml.classifier.LabelParser;
import org.fnlp.ml.classifier.Predict;
import org.fnlp.ml.classifier.TPredict;
import org.fnlp.ml.feature.FeatureSelect;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.sv.HashSparseVector;
import org.fnlp.nlp.pipe.Pipe;
import org.fnlp.util.exception.LoadModelException;
import org.junit.Ignore;
/**
* 朴素贝叶斯分类器
* @author sywu
*
*/
/**
 * Naive Bayes classifier.
 * <p>
 * Scores an instance's sparse feature vector against per-class item frequencies
 * ({@link ItemFrequency}) using smoothed log-likelihoods, optionally after a
 * feature-selection filter. Models are persisted via gzip-compressed Java
 * serialization ({@link #saveTo(String)} / {@link #loadFrom(String)}).
 * <p>
 * NOTE(review): class is Serializable but declares no serialVersionUID — saved
 * models may fail to load after any field change; consider pinning one.
 *
 * @author sywu
 */
public class BayesClassifier extends AbstractClassifier implements Serializable{
	/** Alphabet factory supplying the label alphabet used to map class indices to label strings. */
	protected AlphabetFactory factory;
	/** Per-(feature, class) frequency statistics gathered at training time. */
	protected ItemFrequency tf;
	/** Preprocessing pipe the classifier was trained with (kept so callers can re-apply it to raw input). */
	protected Pipe pipe;
	/** Optional feature-selection filter; {@code null} means all features are used. */
	protected FeatureSelect fs;

	/**
	 * Classifies one instance and returns the top-{@code n} class indices with scores.
	 * <p>
	 * Score for each class is the sum over the instance's non-zero features of
	 * {@code value * log((itemFreq + 0.1) / (classFreq + featureCount))}, i.e. a
	 * log-likelihood with additive smoothing (0.1 on the item count).
	 *
	 * @param instance instance whose data must be a {@link HashSparseVector}
	 * @param n        number of best-scoring classes to keep
	 * @return top-{@code n} predictions holding class <em>indices</em> (not label
	 *         strings), or {@code null} if the instance data has the wrong type
	 */
	@Override
	public Predict classify(Instance instance, int n) {

		int typeSize=tf.getTypeSize();
		float[] score=new float[typeSize];
		Arrays.fill(score, 0.0f);

		// Only sparse-vector input is supported; anything else is reported and rejected.
		Object obj=instance.getData();
		if(!(obj instanceof HashSparseVector)){
			System.out.println("error 输入类型非HashSparseVector!");
			return null;
		}
		HashSparseVector data = (HashSparseVector) obj;
		// Apply the trained feature-selection filter, if one is configured.
		if(fs!=null)
			data=fs.select(data);
		TIntFloatIterator it = data.data.iterator();
		float feaSize=tf.getFeatureSize();
		while (it.hasNext()) {
			it.advance();
			// Feature id 0 is skipped — presumably a reserved/padding slot in the alphabet; TODO confirm.
			if(it.key()==0)
				continue;
			int feature=it.key();
			// Accumulate each class's smoothed log-likelihood contribution for this feature.
			for(int type=0;type<typeSize;type++){
				float itemF=tf.getItemFrequency(feature, type);
				float typeF=tf.getTypeFrequency(type);
				score[type]+=it.value()*Math.log((itemF+0.1)/(typeF+feaSize));
			}
		}

		// Predict keeps only the n highest-scoring classes as they are added.
		Predict<Integer> res=new Predict<Integer>(n);
		for(int type=0;type<typeSize;type++)
			res.add(type, score[type]);

		return res;
	}

	/**
	 * Classifies one instance and converts the raw index predictions into the
	 * requested label representation via {@link LabelParser}.
	 *
	 * @param instance instance to classify (data must be a {@link HashSparseVector})
	 * @param type     desired label representation for the parsed result
	 * @param n        number of best-scoring classes to keep
	 * @return parsed top-{@code n} predictions
	 */
	@Override
	public Predict classify(Instance instance, Type type, int n) {
		Predict res = (Predict) classify(instance, n);
		return LabelParser.parse(res,factory.DefaultLabelAlphabet(),type);
	}
	/**
	 * Looks up the label string for a class index.
	 * @param idx index of the class label
	 * @return the label string registered in the default label alphabet
	 */
	public String getLabel(int idx) {
		return factory.DefaultLabelAlphabet().lookupString(idx);
	}

	/**
	 * Serializes this classifier to a gzip-compressed file, creating parent
	 * directories as needed.
	 * @param file target file path
	 * @throws IOException if the file cannot be written
	 */
	public void saveTo(String file) throws IOException {
		File f = new File(file);
		File path = f.getParentFile();
		if(!path.exists()){
			path.mkdirs();
		}

		ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream(
				new BufferedOutputStream(new FileOutputStream(file))));
		out.writeObject(this);
		out.close();
	}
	/**
	 * Deserializes a classifier previously written by {@link #saveTo(String)}.
	 * @param file source file path
	 * @return the loaded classifier
	 * @throws LoadModelException if reading or deserialization fails (wraps the cause)
	 */
	public static BayesClassifier loadFrom(String file) throws LoadModelException{
		BayesClassifier cl = null;
		try {
			ObjectInputStream in = new ObjectInputStream(new GZIPInputStream(
					new BufferedInputStream(new FileInputStream(file))));
			cl = (BayesClassifier) in.readObject();
			in.close();
		} catch (Exception e) {
			throw new LoadModelException(e,file);
		}
		return cl;
	}
	/** Shorthand for {@link #featureSelectionChiSquare(float)}. */
	public void fS_CS(float percent){featureSelectionChiSquare(percent);}
	/** Enables chi-square feature selection, keeping the given fraction of features. */
	public void featureSelectionChiSquare(float percent){
		fs=new FeatureSelect(tf.getFeatureSize());
		fs.fS_CS(tf, percent);
	}
	/** Shorthand for {@link #featureSelectionChiSquareMax(float)}. */
	public void fS_CS_Max(float percent){featureSelectionChiSquareMax(percent);}
	/** Enables chi-square (max variant) feature selection, keeping the given fraction of features. */
	public void featureSelectionChiSquareMax(float percent){
		fs=new FeatureSelect(tf.getFeatureSize());
		fs.fS_CS_Max(tf, percent);
	}
	/** Shorthand for {@link #featureSelectionInformationGain(float)}. */
	public void fS_IG(float percent){featureSelectionInformationGain(percent);}
	/** Enables information-gain feature selection, keeping the given fraction of features. */
	public void featureSelectionInformationGain(float percent){
		fs=new FeatureSelect(tf.getFeatureSize());
		fs.fS_IG(tf, percent);
	}
	/** Disables feature selection; all features are used in {@link #classify(Instance, int)}. */
	public void noFeatureSelection(){
		fs=null;
	}
	/** @return the training frequency statistics backing this classifier */
	public ItemFrequency getTf() {
		return tf;
	}

	/** Sets the training frequency statistics used for scoring. */
	public void setTf(ItemFrequency tf) {
		this.tf = tf;
	}
	/** @return the preprocessing pipe the classifier was trained with */
	public Pipe getPipe() {
		return pipe;
	}

	/** Sets the preprocessing pipe associated with this classifier. */
	public void setPipe(Pipe pipe) {
		this.pipe = pipe;
	}

	/** Sets the alphabet factory used to resolve label indices. */
	public void setFactory(AlphabetFactory factory){
		this.factory=factory;
	}
	/** @return the alphabet factory used to resolve label indices */
	public AlphabetFactory getFactory(){
		return factory;
	}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package org.fnlp.ml.classifier.bayes;

import gnu.trove.iterator.TIntFloatIterator;

import java.util.List;

import org.fnlp.ml.classifier.AbstractClassifier;
import org.fnlp.ml.classifier.linear.AbstractTrainer;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;
import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.sv.HashSparseVector;
import org.fnlp.nlp.pipe.Pipe;
import org.fnlp.nlp.pipe.SeriesPipes;
/**
* 贝叶斯文本分类模型训练器
* 输入训练数据为稀疏矩阵
* @author sywu
*
*/
/**
 * Trainer for the naive Bayes text classifier.
 * <p>
 * Builds an {@link ItemFrequency} table from a training set of sparse vectors
 * and wires it, together with the alphabet factory and preprocessing pipe,
 * into a ready-to-use {@link BayesClassifier}.
 *
 * @author sywu
 */
public class BayesTrainer{

	/**
	 * Trains a classifier using the alphabet factory and pipes already carried
	 * by the training set. The target-extraction pipe is removed first so the
	 * stored pipe can later be applied to unlabeled input.
	 *
	 * @param trainset training instances (sparse-vector data)
	 * @return the trained classifier
	 */
	public AbstractClassifier train(InstanceSet trainset) {
		SeriesPipes preprocess = (SeriesPipes) trainset.getPipes();
		preprocess.removeTargetPipe();
		AlphabetFactory alphabets = trainset.getAlphabetFactory();
		return train(trainset, alphabets, preprocess);
	}

	/**
	 * Trains a classifier from the frequency counts of {@code trainset}.
	 *
	 * @param trainset training instances (sparse-vector data)
	 * @param af       alphabet factory mapping labels/features to indices
	 * @param pp       preprocessing pipe to store on the classifier
	 * @return the trained classifier
	 */
	public AbstractClassifier train(InstanceSet trainset,AlphabetFactory af,Pipe pp) {
		BayesClassifier classifier = new BayesClassifier();
		classifier.setTf(new ItemFrequency(trainset, af));
		classifier.setPipe(pp);
		classifier.setFactory(af);
		return classifier;
	}
}
145 changes: 145 additions & 0 deletions fnlp-core/src/main/java/org/fnlp/ml/classifier/bayes/Heap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package org.fnlp.ml.classifier.bayes;

import java.util.ArrayList;
/**
* 堆
* @author sywu
*
* @param <T> 存储的数据类型
*/
/**
 * Bounded binary heap that retains the "best" {@code maxsize} entries seen so far.
 * <p>
 * With a min-root heap ({@code isMinRootHeap == true}) the root holds the smallest
 * retained score, so after overflow the heap keeps the {@code maxsize} <em>largest</em>
 * scores; with a max-root heap it keeps the {@code maxsize} smallest. Standard
 * 1-based array layout: index 0 is an unused sentinel, children of {@code i} are
 * {@code 2i} and {@code 2i+1}.
 * <p>
 * NOTE(review): {@link #getData()} exposes the internal list (including the null
 * sentinel at index 0); callers must not mutate it. Not changed here to keep the
 * public contract intact.
 *
 * @author sywu
 * @param <T> payload type stored alongside each score
 */
public class Heap<T>{
	/** true: smallest score at root (keeps largest scores); false: largest at root. */
	private boolean isMinRootHeap;
	/** Payloads, parallel to {@code scores}; index 0 is a null sentinel. */
	private ArrayList<T> datas;
	/** Scores, 1-based; index 0 unused. */
	private double[] scores;
	/** Maximum number of retained entries. */
	private int maxsize;
	/** Current number of retained entries. */
	private int size;

	/**
	 * Creates a bounded heap.
	 *
	 * @param max           capacity (number of entries retained); must be {@code >= 0}
	 * @param isMinRootHeap true to keep the {@code max} largest scores, false for the smallest
	 * @throws IllegalArgumentException if {@code max} is negative
	 */
	public Heap(int max,boolean isMinRootHeap) {
		if (max < 0) {
			// Fail fast with a clear message instead of a NegativeArraySizeException below.
			throw new IllegalArgumentException("heap capacity must be >= 0: " + max);
		}
		this.isMinRootHeap=isMinRootHeap;
		maxsize = max;
		scores = new double[maxsize+1];
		datas= new ArrayList<T>();
		size = 0;
		datas.add(null);   // sentinel so real entries start at index 1
		scores[0]=0;
	}

	/** Creates a min-root heap (keeps the {@code max} largest scores). */
	public Heap(int max) {
		this(max,true);
	}

	/** @return index of the left child of {@code pos} (1-based layout). */
	private int leftchild(int pos) {
		return 2 * pos;
	}

	/** @return index of the right child of {@code pos} (1-based layout). */
	private int rightchild(int pos) {
		return 2 * pos + 1;
	}

	/** @return index of the parent of {@code pos} (1-based layout). */
	private int parent(int pos) {
		return pos / 2;
	}

	/** @return true if {@code pos} has no children within the current size. */
	private boolean isleaf(int pos) {
		return ((pos > size / 2) && (pos <= size));
	}

	/** @return true if the entry at {@code pos} violates heap order w.r.t. its parent. */
	private boolean needSwapWithParent(int pos){
		return isMinRootHeap?
				scores[pos] < scores[parent(pos)]:
				scores[pos] > scores[parent(pos)];
	}

	/** Swaps both score and payload between two positions. */
	private void swap(int pos1, int pos2) {
		double tmp;
		tmp = scores[pos1];
		scores[pos1] = scores[pos2];
		scores[pos2] = tmp;
		T t1,t2;
		t1=datas.get(pos1);
		t2=datas.get(pos2);
		datas.set(pos1, t2);
		datas.set(pos2, t1);
	}

	/**
	 * Offers an entry to the heap.
	 * <p>
	 * While under capacity the entry is appended and bubbled up. At capacity the
	 * entry replaces the root (and is pushed down) only if it beats the current
	 * root — i.e. is larger for a min-root heap, smaller for a max-root heap;
	 * otherwise it is discarded.
	 *
	 * @param score ranking score of the entry
	 * @param data  payload to store with the score
	 */
	public void insert(double score,T data) {
		if (maxsize == 0) {
			// Zero-capacity heap retains nothing; previously this read scores[1]
			// out of bounds and threw ArrayIndexOutOfBoundsException.
			return;
		}
		if(size<maxsize){
			size++;
			scores[size] = score;
			datas.add(data);
			// Bubble the new entry up until heap order is restored.
			int current = size;
			while (current!=1&&needSwapWithParent(current)) {
				swap(current, parent(current));
				current = parent(current);
			}
		}
		else {
			// Full: keep the entry only if it beats the current root.
			if(isMinRootHeap?
					score>scores[1]:
					score<scores[1]){
				scores[1]=score;
				datas.set(1, data);
				pushdown(1);
			}
		}
	}

	/** Prints the retained (score, payload) pairs in array order, for debugging. */
	public void print() {
		int i;
		for (i = 1; i <= size; i++)
			System.out.println(scores[i] + " " +datas.get(i).toString());
		System.out.println();
	}

	/**
	 * Picks which child of {@code pos} to compare against during push-down:
	 * the smaller child for a min-root heap, the larger for a max-root heap.
	 * Caller guarantees {@code pos} is not a leaf, so the left child exists.
	 */
	private int findcheckchild(int pos){
		int rlt;
		rlt = leftchild(pos);
		if(rlt==size)
			return rlt;   // only one child
		if (isMinRootHeap?(scores[rlt] > scores[rlt + 1]):(scores[rlt] < scores[rlt + 1]))
			rlt = rlt + 1;
		return rlt;
	}

	/** Pushes the entry at {@code pos} down until heap order is restored. */
	private void pushdown(int pos) {
		int checkchild;
		while (!isleaf(pos)) {
			checkchild = findcheckchild(pos);
			if(needSwapWithParent(checkchild))
				swap(pos, checkchild);
			else
				return;
			pos = checkchild;
		}
	}

	/**
	 * @return the internal payload list; index 0 is a null sentinel, retained
	 *         entries occupy indices 1..size in heap (not sorted) order
	 */
	public ArrayList<T> getData(){
		return datas;
	}

	/** Small manual demo: fills a 6-entry min-root heap and prints the survivors. */
	public static void main(String args[])
	{
		Heap<String> hm = new Heap<String>(6,true);
		hm.insert(1,"11");
		hm.insert(4,"44");
		hm.insert(2,"22");
		hm.insert(6,"66");
		hm.insert(3,"33");
		hm.insert(5,"55");
		hm.insert(9,"99");
		hm.insert(7,"77");
		hm.print();

	}
}
Loading

0 comments on commit aef1da6

Please sign in to comment.