Skip to content

Commit

Permalink
fixed small bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
ptnplanet committed Mar 11, 2012
1 parent 654df58 commit aa92e66
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 25 deletions.
40 changes: 27 additions & 13 deletions BayesClassifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ private float categoryProbability(Collection<T> features, K category) {
* @param features The set of features to use.
* @return A sorted <code>Set</code> of category-probability-entries.
*/
private SortedSet<Entry<K, Float>> categoryProbabilities(
private SortedSet<Classification<T, K>> categoryProbabilities(
Collection<T> features) {

/*
Expand All @@ -68,22 +68,26 @@ private SortedSet<Entry<K, Float>> categoryProbabilities(
* achieve the desired functionality. A custom comparator is therefore
* needed.
*/
SortedSet<Entry<K, Float>> probabilities =
new TreeSet<Entry<K, Float>>(new Comparator<Entry<K, Float>>() {
SortedSet<Classification<T, K>> probabilities =
new TreeSet<Classification<T, K>>(
new Comparator<Classification<T, K>>() {

@Override
public int compare(Entry<K, Float> o1, Entry<K, Float> o2) {
int toReturn = o1.getValue().compareTo(o2.getValue());
if ((toReturn == 0) && (o1.getKey() != o2.getKey())) {
public int compare(Classification<T, K> o1,
Classification<T, K> o2) {
int toReturn = Float.compare(
o1.getProbability(), o2.getProbability());
if ((toReturn == 0)
&& !o1.getCategory().equals(o2.getCategory()))
toReturn = -1;
}
return toReturn;
}
});

for (K category : this.getCategories())
probabilities.add(
new SimpleEntry<K, Float>(category,
this.categoryProbability(features, category)));
probabilities.add(new Classification<T, K>(
features, category,
this.categoryProbability(features, category)));
return probabilities;
}

Expand All @@ -94,10 +98,20 @@ public int compare(Entry<K, Float> o1, Entry<K, Float> o2) {
*/
@Override
public K classify(Collection<T> features) {
SortedSet<Entry<K, Float>> probabilites =
SortedSet<Classification<T, K>> probabilites =
this.categoryProbabilities(features);
if (probabilites.size() > 0)
return probabilites.last().getKey();

System.out.println("Results:\t");
for (Classification<T, K> prob : probabilites)
System.out.println(prob);

if (probabilites.size() > 0) {
System.out.println("Classified as " +
probabilites.last().getCategory());
return probabilites.last().getCategory();
} else {
System.out.println("No results");
}
return null;
}

Expand Down
8 changes: 8 additions & 0 deletions Classification.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ public Collection<T> getFeatureset() {
return featureset;
}

/**
* Retrieves the classification's probability.
* @return
*/
public float getProbability() {
return this.probability;
}

/**
* Retrieves the category the featureset was classified as.
*
Expand Down
40 changes: 31 additions & 9 deletions Classifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public abstract class Classifier<T, K> implements IFeatureProbability<T, K> {
/**
* The initial memory capacity or how many classifications are memorized.
*/
private static final int MEMORY_CAPACITY = 1000;
private static final int MEMORY_CAPACITY = 5;

/**
* A dictionary mapping features to their number of occurrences in each
Expand All @@ -53,6 +53,10 @@ public abstract class Classifier<T, K> implements IFeatureProbability<T, K> {
*/
private Dictionary<K, Integer> totalCategoryCount;

/**
* The classifier's memory. It will forget old classifications as soon as
* they become too old.
*/
private Queue<Classification<T, K>> memoryQueue;

/**
Expand Down Expand Up @@ -129,17 +133,17 @@ public void incrementFeature(T feature, K category) {
}
Integer count = features.get(feature);
if (count == null) {
features.put(feature, 1);
features.put(feature, 0);
count = features.get(feature);
}
count++;
features.put(feature, ++count);

Integer totalCount = this.totalFeatureCount.get(feature);
if (totalCount == null) {
this.totalFeatureCount.put(feature, 1);
totalCount = this.totalFeatureCount.get(feature);
}
totalCount++;
this.totalFeatureCount.put(feature, ++totalCount);
}

/**
Expand All @@ -154,7 +158,7 @@ public void incrementCategory(K category) {
this.totalCategoryCount.put(category, 1);
count = this.totalCategoryCount.get(category);
}
count++;
this.totalCategoryCount.put(category, ++count);
}

/**
Expand Down Expand Up @@ -345,13 +349,31 @@ public float featureWeighedAverage(T feature, K category,
* @param features The features that resulted in the given category.
*/
public void learn(K category, Collection<T> features) {
for (T feature : features)
this.incrementFeature(feature, category);
this.incrementCategory(category);
this.learn(new Classification<T, K>(features, category));
}

this.memoryQueue.offer(new Classification<T, K>(features, category));
/**
* Train the classifier by telling it that the given features resulted in
* the given category.
*
* @param classification The classification to learn.
*/
public void learn(Classification<T, K> classification) {

System.out.println("Learning new classification:\t"
+ classification);

for (T feature : classification.getFeatureset())
this.incrementFeature(feature, classification.getCategory());
this.incrementCategory(classification.getCategory());

this.memoryQueue.offer(classification);
if (this.memoryQueue.size() > Classifier.MEMORY_CAPACITY) {
Classification<T, K> toForget = this.memoryQueue.remove();

System.out.println("Memory over capacity. Forgetting about\t"
+ toForget);

for (T feature : toForget.getFeatureset())
this.decrementFeature(feature, toForget.getCategory());
this.decrementCategory(toForget.getCategory());
Expand Down
4 changes: 1 addition & 3 deletions ClassifierTester.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@ public static void main(String[] args) throws IOException {
Arrays.asList(
Arrays.copyOfRange(tokens, 2, tokens.length));
classifier.learn(tokens[1], context);
System.out.println("trained. number of categories:\t"
+ classifier.getCategoriesTotal());
} else if (tokens[0].startsWith("c")) {
Collection<String> context =
Arrays.asList(
Arrays.copyOfRange(tokens, 1, tokens.length));
System.out.println(classifier.classify(context));
classifier.classify(context);
}
System.out.print("> ");
}
Expand Down

0 comments on commit aa92e66

Please sign in to comment.