forked from FudanNLP/fnlp
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
全部删除重新添加了一下,不知道这次会不会有冲突
- Loading branch information
Showing
18 changed files
with
1,884 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
169 changes: 169 additions & 0 deletions
169
fnlp-core/src/main/java/org/fnlp/ml/classifier/bayes/BayesClassifier.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
package org.fnlp.ml.classifier.bayes; | ||
|
||
import gnu.trove.iterator.TIntFloatIterator; | ||
|
||
import java.io.BufferedInputStream; | ||
import java.io.BufferedOutputStream; | ||
import java.io.File; | ||
import java.io.FileInputStream; | ||
import java.io.FileOutputStream; | ||
import java.io.IOException; | ||
import java.io.ObjectInputStream; | ||
import java.io.ObjectOutputStream; | ||
import java.io.Serializable; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.zip.GZIPInputStream; | ||
import java.util.zip.GZIPOutputStream; | ||
|
||
import org.fnlp.ml.classifier.AbstractClassifier; | ||
import org.fnlp.ml.classifier.LabelParser.Type; | ||
import org.fnlp.ml.classifier.linear.Linear; | ||
import org.fnlp.ml.classifier.LabelParser; | ||
import org.fnlp.ml.classifier.Predict; | ||
import org.fnlp.ml.classifier.TPredict; | ||
import org.fnlp.ml.feature.FeatureSelect; | ||
import org.fnlp.ml.types.Instance; | ||
import org.fnlp.ml.types.alphabet.AlphabetFactory; | ||
import org.fnlp.ml.types.sv.HashSparseVector; | ||
import org.fnlp.nlp.pipe.Pipe; | ||
import org.fnlp.util.exception.LoadModelException; | ||
import org.junit.Ignore; | ||
/** | ||
* 朴素贝叶斯分类器 | ||
* @author sywu | ||
* | ||
*/ | ||
public class BayesClassifier extends AbstractClassifier implements Serializable{ | ||
protected AlphabetFactory factory; | ||
protected ItemFrequency tf; | ||
protected Pipe pipe; | ||
protected FeatureSelect fs; | ||
|
||
@Override | ||
public Predict classify(Instance instance, int n) { | ||
// TODO Auto-generated method stub | ||
|
||
int typeSize=tf.getTypeSize(); | ||
float[] score=new float[typeSize]; | ||
Arrays.fill(score, 0.0f); | ||
|
||
Object obj=instance.getData(); | ||
if(!(obj instanceof HashSparseVector)){ | ||
System.out.println("error 输入类型非HashSparseVector!"); | ||
return null; | ||
} | ||
HashSparseVector data = (HashSparseVector) obj; | ||
if(fs!=null) | ||
data=fs.select(data); | ||
TIntFloatIterator it = data.data.iterator(); | ||
float feaSize=tf.getFeatureSize(); | ||
while (it.hasNext()) { | ||
it.advance(); | ||
if(it.key()==0) | ||
continue; | ||
int feature=it.key(); | ||
for(int type=0;type<typeSize;type++){ | ||
float itemF=tf.getItemFrequency(feature, type); | ||
float typeF=tf.getTypeFrequency(type); | ||
score[type]+=it.value()*Math.log((itemF+0.1)/(typeF+feaSize)); | ||
} | ||
} | ||
|
||
Predict<Integer> res=new Predict<Integer>(n); | ||
for(int type=0;type<typeSize;type++) | ||
res.add(type, score[type]); | ||
|
||
return res; | ||
} | ||
|
||
@Override | ||
public Predict classify(Instance instance, Type type, int n) { | ||
// TODO Auto-generated method stub | ||
Predict res = (Predict) classify(instance, n); | ||
return LabelParser.parse(res,factory.DefaultLabelAlphabet(),type); | ||
} | ||
/** | ||
* 得到类标签 | ||
* @param idx 类标签对应的索引 | ||
* @return | ||
*/ | ||
public String getLabel(int idx) { | ||
return factory.DefaultLabelAlphabet().lookupString(idx); | ||
} | ||
|
||
/** | ||
* 将分类器保存到文件 | ||
* @param file | ||
* @throws IOException | ||
*/ | ||
public void saveTo(String file) throws IOException { | ||
File f = new File(file); | ||
File path = f.getParentFile(); | ||
if(!path.exists()){ | ||
path.mkdirs(); | ||
} | ||
|
||
ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream( | ||
new BufferedOutputStream(new FileOutputStream(file)))); | ||
out.writeObject(this); | ||
out.close(); | ||
} | ||
/** | ||
* 从文件读入分类器 | ||
* @param file | ||
* @return | ||
* @throws LoadModelException | ||
*/ | ||
public static BayesClassifier loadFrom(String file) throws LoadModelException{ | ||
BayesClassifier cl = null; | ||
try { | ||
ObjectInputStream in = new ObjectInputStream(new GZIPInputStream( | ||
new BufferedInputStream(new FileInputStream(file)))); | ||
cl = (BayesClassifier) in.readObject(); | ||
in.close(); | ||
} catch (Exception e) { | ||
throw new LoadModelException(e,file); | ||
} | ||
return cl; | ||
} | ||
public void fS_CS(float percent){featureSelectionChiSquare(percent);} | ||
public void featureSelectionChiSquare(float percent){ | ||
fs=new FeatureSelect(tf.getFeatureSize()); | ||
fs.fS_CS(tf, percent); | ||
} | ||
public void fS_CS_Max(float percent){featureSelectionChiSquareMax(percent);} | ||
public void featureSelectionChiSquareMax(float percent){ | ||
fs=new FeatureSelect(tf.getFeatureSize()); | ||
fs.fS_CS_Max(tf, percent); | ||
} | ||
public void fS_IG(float percent){featureSelectionInformationGain(percent);} | ||
public void featureSelectionInformationGain(float percent){ | ||
fs=new FeatureSelect(tf.getFeatureSize()); | ||
fs.fS_IG(tf, percent); | ||
} | ||
public void noFeatureSelection(){ | ||
fs=null; | ||
} | ||
public ItemFrequency getTf() { | ||
return tf; | ||
} | ||
|
||
public void setTf(ItemFrequency tf) { | ||
this.tf = tf; | ||
} | ||
public Pipe getPipe() { | ||
return pipe; | ||
} | ||
|
||
public void setPipe(Pipe pipe) { | ||
this.pipe = pipe; | ||
} | ||
|
||
public void setFactory(AlphabetFactory factory){ | ||
this.factory=factory; | ||
} | ||
public AlphabetFactory getFactory(){ | ||
return factory; | ||
} | ||
} |
37 changes: 37 additions & 0 deletions
37
fnlp-core/src/main/java/org/fnlp/ml/classifier/bayes/BayesTrainer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
package org.fnlp.ml.classifier.bayes; | ||
|
||
import gnu.trove.iterator.TIntFloatIterator; | ||
|
||
import java.util.List; | ||
|
||
import org.fnlp.ml.classifier.AbstractClassifier; | ||
import org.fnlp.ml.classifier.linear.AbstractTrainer; | ||
import org.fnlp.ml.types.Instance; | ||
import org.fnlp.ml.types.InstanceSet; | ||
import org.fnlp.ml.types.alphabet.AlphabetFactory; | ||
import org.fnlp.ml.types.sv.HashSparseVector; | ||
import org.fnlp.nlp.pipe.Pipe; | ||
import org.fnlp.nlp.pipe.SeriesPipes; | ||
/** | ||
* 贝叶斯文本分类模型训练器 | ||
* 输入训练数据为稀疏矩阵 | ||
* @author sywu | ||
* | ||
*/ | ||
public class BayesTrainer{ | ||
|
||
public AbstractClassifier train(InstanceSet trainset) { | ||
AlphabetFactory af=trainset.getAlphabetFactory(); | ||
SeriesPipes pp=(SeriesPipes) trainset.getPipes(); | ||
pp.removeTargetPipe(); | ||
return train(trainset,af,pp); | ||
} | ||
public AbstractClassifier train(InstanceSet trainset,AlphabetFactory af,Pipe pp) { | ||
ItemFrequency tf=new ItemFrequency(trainset,af); | ||
BayesClassifier classifier=new BayesClassifier(); | ||
classifier.setFactory(af); | ||
classifier.setPipe(pp); | ||
classifier.setTf(tf); | ||
return classifier; | ||
} | ||
} |
145 changes: 145 additions & 0 deletions
145
fnlp-core/src/main/java/org/fnlp/ml/classifier/bayes/Heap.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
package org.fnlp.ml.classifier.bayes; | ||
|
||
import java.util.ArrayList; | ||
/**
 * Fixed-capacity binary heap of scored items.
 *
 * <p>Scores are kept in a 1-based array and the payloads in a parallel list whose slot 0 is a
 * {@code null} placeholder. Once the heap is full, a min-root heap only admits an item scoring
 * higher than the root (so it retains the highest-scoring items), and a max-root heap only
 * admits an item scoring lower than the root (so it retains the lowest-scoring items).
 *
 * @author sywu
 *
 * @param <T> type of the stored payload
 */
public class Heap<T>{
    private final boolean isMinRootHeap;
    private final ArrayList<T> datas;
    private final double[] scores;
    private final int maxsize;
    private int size;

    /**
     * Creates an empty heap.
     *
     * @param max maximum number of items retained; must not be negative
     * @param isMinRootHeap {@code true} for a min-root heap, {@code false} for a max-root heap
     * @throws IllegalArgumentException if {@code max} is negative
     */
    public Heap(int max, boolean isMinRootHeap) {
        if (max < 0) {
            throw new IllegalArgumentException("max must not be negative: " + max);
        }
        this.isMinRootHeap = isMinRootHeap;
        maxsize = max;
        scores = new double[maxsize + 1];
        datas = new ArrayList<T>();
        size = 0;
        // Slot 0 is a placeholder so that the root lives at index 1.
        datas.add(null);
        scores[0] = 0;
    }

    /**
     * Creates an empty min-root heap.
     *
     * @param max maximum number of items retained; must not be negative
     */
    public Heap(int max) {
        this(max, true);
    }

    private int leftchild(int pos) {
        return 2 * pos;
    }

    private int rightchild(int pos) {
        return 2 * pos + 1;
    }

    private int parent(int pos) {
        return pos / 2;
    }

    private boolean isleaf(int pos) {
        return (pos > size / 2) && (pos <= size);
    }

    /** Tells whether the node at {@code pos} violates the heap order against its parent. */
    private boolean needSwapWithParent(int pos) {
        return isMinRootHeap
                ? scores[pos] < scores[parent(pos)]
                : scores[pos] > scores[parent(pos)];
    }

    /** Exchanges both the scores and the payloads of two positions. */
    private void swap(int pos1, int pos2) {
        double tmp = scores[pos1];
        scores[pos1] = scores[pos2];
        scores[pos2] = tmp;
        T t1 = datas.get(pos1);
        T t2 = datas.get(pos2);
        datas.set(pos1, t2);
        datas.set(pos2, t1);
    }

    /**
     * Offers an item to the heap.
     *
     * <p>While the heap is not full the item is always added. Once it is full, the item
     * replaces the root only if it scores higher than the root (min-root heap) or lower than
     * the root (max-root heap); otherwise it is dropped. A heap of capacity 0 drops every item.
     *
     * @param score score of the item
     * @param data payload of the item
     */
    public void insert(double score, T data) {
        if (maxsize == 0) {
            // There is no root to compare against; nothing can be retained.
            return;
        }
        if (size < maxsize) {
            size++;
            scores[size] = score;
            datas.add(data);
            // Sift the new item up until the heap order holds again.
            int current = size;
            while (current != 1 && needSwapWithParent(current)) {
                swap(current, parent(current));
                current = parent(current);
            }
        } else {
            boolean beatsRoot = isMinRootHeap ? score > scores[1] : score < scores[1];
            if (beatsRoot) {
                scores[1] = score;
                datas.set(1, data);
                pushdown(1);
            }
        }
    }

    /** Prints every retained score and payload, in heap order, followed by an empty line. */
    public void print() {
        for (int i = 1; i <= size; i++) {
            // String concatenation renders a null payload as "null" rather than throwing.
            System.out.println(scores[i] + " " + datas.get(i));
        }
        System.out.println();
    }

    /**
     * Picks the child of a non-leaf node that has to be compared with it: the smaller child
     * in a min-root heap, the larger one in a max-root heap.
     */
    private int findcheckchild(int pos) {
        int left = leftchild(pos);
        if (left == size) {
            // The left child is the last element, so there is no right child.
            return left;
        }
        int right = rightchild(pos);
        boolean rightWins = isMinRootHeap
                ? scores[left] > scores[right]
                : scores[left] < scores[right];
        return rightWins ? right : left;
    }

    /** Sifts the node at {@code pos} down until the heap order holds again. */
    private void pushdown(int pos) {
        while (!isleaf(pos)) {
            int checkchild = findcheckchild(pos);
            if (!needSwapWithParent(checkchild)) {
                return;
            }
            swap(pos, checkchild);
            pos = checkchild;
        }
    }

    /**
     * Returns the internal payload list.
     *
     * @return the backing list, in heap order, with the {@code null} placeholder at index 0
     */
    public ArrayList<T> getData() {
        return datas;
    }

    /** Small demonstration: keeps the six highest-scoring of eight items and prints them. */
    public static void main(String args[])
    {
        Heap<String> hm = new Heap<String>(6, true);
        hm.insert(1, "11");
        hm.insert(4, "44");
        hm.insert(2, "22");
        hm.insert(6, "66");
        hm.insert(3, "33");
        hm.insert(5, "55");
        hm.insert(9, "99");
        hm.insert(7, "77");
        hm.print();
    }
}
Oops, something went wrong.