Merge pull request #12 from jrbarron/tokenizer
Add support for specifying a custom tokenizer on the NaiveBayes model
cdipaolo authored Sep 17, 2016
2 parents 86c1fda + 732a627 commit 4b2f5a3
Showing 3 changed files with 84 additions and 8 deletions.
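To make the change concrete before the diff: a minimal sketch of how the new hook is meant to be used from calling code. The import paths assume the usual github.com/cdipaolo/goml layout; the comma tokenizer and training data mirror the new test below, and everything else (channel size, class count) is illustrative only.

```go
package main

import (
	"fmt"
	"strings"

	"github.com/cdipaolo/goml/base"
	"github.com/cdipaolo/goml/text"
)

func main() {
	stream := make(chan base.TextDatapoint, 100)
	errs := make(chan error)

	// Keep every rune in the sanitizer; the custom tokenizer
	// below is responsible for all of the splitting.
	model := text.NewNaiveBayes(stream, 2, func(rune) bool { return false })

	// Swap the default lower-case-and-split-on-spaces behavior
	// for a comma tokenizer.
	model.UpdateTokenizer(func(input string) []string {
		return strings.Split(strings.ToLower(input), ",")
	})

	go model.OnlineLearn(errs)

	stream <- base.TextDatapoint{X: "I,love,the,city", Y: 1}
	stream <- base.TextDatapoint{X: "I,hate,Los,Angeles", Y: 0}
	close(stream)

	// OnlineLearn closes the error channel once training finishes.
	for err := range errs {
		fmt.Printf("training error: %v\n", err)
	}

	fmt.Println(model.Predict("love,the,city")) // expect class 1
}
```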
text/bayes.go: 23 additions & 5 deletions
@@ -156,11 +156,23 @@ type NaiveBayes struct {
     // stream holds the datastream
     stream <-chan base.TextDatapoint
 
+    // tokenizer is used by a model
+    // to split the input into tokens
+    tokenize Tokenizer
+
     // Output is the io.Writer used for logging
     // and printing. Defaults to os.Stdout.
     Output io.Writer `json:"-"`
 }
 
+// Tokenizer accepts a sentence as input and breaks
+// it down into a slice of tokens
+type Tokenizer func(string) []string
+
+func spaceTokenizer(input string) []string {
+    return strings.Split(strings.ToLower(input), " ")
+}
+
 // concurrentMap allows concurrency-friendly map
 // access via its exported Get and Set methods
 type concurrentMap struct {
@@ -236,6 +248,7 @@ func NewNaiveBayes(stream <-chan base.TextDatapoint, classes uint8, sanitize func(rune) bool) {
 
         sanitize: transform.RemoveFunc(sanitize),
         stream:   stream,
+        tokenize: spaceTokenizer,
 
         Output: os.Stdout,
     }
@@ -249,7 +262,7 @@ func (b *NaiveBayes) Predict(sentence string) uint8 {
     sums := make([]float64, len(b.Count))
 
     sentence, _, _ = transform.String(b.sanitize, sentence)
-    words := strings.Split(strings.ToLower(sentence), " ")
+    words := b.tokenize(sentence)
     for _, word := range words {
         w, ok := b.Words.Get(word)
         if !ok {
@@ -300,7 +313,7 @@ func (b *NaiveBayes) Probability(sentence string) (uint8, float64) {
     }
 
     sentence, _, _ = transform.String(b.sanitize, sentence)
-    words := strings.Split(strings.ToLower(sentence), " ")
+    words := b.tokenize(sentence)
     for _, word := range words {
         w, ok := b.Words.Get(word)
         if !ok {
@@ -353,9 +366,7 @@ func (b *NaiveBayes) OnlineLearn(errors chan<- error) {
         if more {
             // sanitize and break up document
             sanitized, _, _ := transform.String(b.sanitize, point.X)
-            sanitized = strings.ToLower(sanitized)
-
-            words := strings.Split(sanitized, " ")
+            words := b.tokenize(sanitized)
 
             C := int(point.Y)
 
@@ -425,6 +436,13 @@ func (b *NaiveBayes) UpdateSanitize(sanitize func(rune) bool) {
     b.sanitize = transform.RemoveFunc(sanitize)
 }
 
+// UpdateTokenizer updates the NaiveBayes model's tokenizer function.
+// The default implementation converts the input to lower
+// case and splits on the space character.
+func (b *NaiveBayes) UpdateTokenizer(tokenizer Tokenizer) {
+    b.tokenize = tokenizer
+}
+
 // String implements the fmt interface for clean printing. Here
 // we're using it to print the model as the equation h(θ)=...
 // where h is the perceptron hypothesis model.
text/bayes_test.go: 59 additions & 0 deletions
@@ -356,3 +356,62 @@ func TestPersistNaiveBayesShouldPass1(t *testing.T) {
     assert.EqualValues(t, 1, class, "Class should be 1")
     assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is positive - Given %v", p)
 }
+
+func TestTokenizer(t *testing.T) {
+    stream := make(chan base.TextDatapoint, 100)
+    errors := make(chan error)
+
+    // This is a somewhat contrived test case, since splitting on commas
+    // is probably not very useful, but it is designed purely to exercise
+    // the tokenizer. A more useful but more complicated test would use a
+    // tokenizer that does something like Porter stemming.
+    model := NewNaiveBayes(stream, 3, func(rune) bool {
+        // do not filter out commas
+        return false
+    })
+    model.UpdateTokenizer(func(input string) []string {
+        return strings.Split(strings.ToLower(input), ",")
+    })
+
+    go model.OnlineLearn(errors)
+
+    stream <- base.TextDatapoint{
+        X: "I,love,the,city",
+        Y: 1,
+    }
+
+    stream <- base.TextDatapoint{
+        X: "I,hate,Los,Angeles",
+        Y: 0,
+    }
+
+    stream <- base.TextDatapoint{
+        X: "My,mother,is,not,a,nice,lady",
+        Y: 0,
+    }
+
+    close(stream)
+
+    for {
+        err, more := <-errors
+        if more {
+            fmt.Printf("Error passed: %v", err)
+        } else {
+            // training is done!
+            break
+        }
+    }
+
+    // now you can predict like normal
+    class := model.Predict("My,mo~~~ther,is,in,Los,Angeles") // 0
+    assert.EqualValues(t, 0, class, "Class should be 0")
+
+    // test small document classification
+    class, p := model.Probability("Mother,Los,Angeles")
+    assert.EqualValues(t, 0, class, "Class should be 0")
+    assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is negative - Given %v", p)
+
+    class, p = model.Probability("Love,the,CiTy")
+    assert.EqualValues(t, 1, class, "Class should be 1")
+    assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is positive - Given %v", p)
+}
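The comment in TestTokenizer points at stemming as the realistic use case. A hedged sketch of what such a tokenizer could look like; the suffix trimming here is a crude stand-in for a real Porter stemmer and is not part of this change:

```go
package main

import (
	"fmt"
	"strings"
)

// stemTokenizer lower-cases, splits on whitespace, then strips a few
// common English suffixes. The trimming is purely illustrative; a real
// tokenizer would call an actual Porter stemming library instead.
func stemTokenizer(input string) []string {
	words := strings.Fields(strings.ToLower(input))
	for i, w := range words {
		for _, suffix := range []string{"ing", "ed", "es", "s"} {
			if len(w) > len(suffix)+2 && strings.HasSuffix(w, suffix) {
				words[i] = strings.TrimSuffix(w, suffix)
				break
			}
		}
	}
	return words
}

func main() {
	fmt.Println(stemTokenizer("Loving the cities")) // [lov the citi]
}
```

Since stemTokenizer already has the shape func(string) []string, installing it is just model.UpdateTokenizer(stemTokenizer).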
text/tfidf.go: 2 additions & 3 deletions
@@ -3,7 +3,6 @@ package text
 import (
     "math"
     "sort"
-    "strings"
 
     "golang.org/x/text/transform"
 )
@@ -81,7 +80,7 @@ func (f Frequencies) Swap(i, j int) {
 // this is calculated
 func (t *TFIDF) TFIDF(word string, sentence string) float64 {
     sentence, _, _ = transform.String(t.sanitize, sentence)
-    document := strings.Split(strings.ToLower(sentence), " ")
+    document := t.tokenize(sentence)
 
     return t.TermFrequency(word, document) * t.InverseDocumentFrequency(word)
 }
@@ -96,7 +95,7 @@ func (t *TFIDF) TFIDF(word string, sentence string) float64 {
 // by importance
 func (t *TFIDF) MostImportantWords(sentence string, n int) Frequencies {
     sentence, _, _ = transform.String(t.sanitize, sentence)
-    document := strings.Split(strings.ToLower(sentence), " ")
+    document := t.tokenize(sentence)
 
     freq := TermFrequencies(document)
     for i := range freq {
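Finally, a sketch of why the two one-line changes in tfidf.go matter: converting a trained NaiveBayes into a TFIDF model carries the custom tokenizer along, which is what makes t.tokenize valid there. This assumes, as the field access implies, that TFIDF is declared over the same struct (type TFIDF NaiveBayes in goml); the training setup simply repeats the earlier sketch.

```go
package main

import (
	"fmt"
	"strings"

	"github.com/cdipaolo/goml/base"
	"github.com/cdipaolo/goml/text"
)

func main() {
	stream := make(chan base.TextDatapoint, 10)
	errs := make(chan error)

	model := text.NewNaiveBayes(stream, 2, func(rune) bool { return false })
	model.UpdateTokenizer(func(s string) []string {
		return strings.Split(strings.ToLower(s), ",")
	})

	go model.OnlineLearn(errs)
	stream <- base.TextDatapoint{X: "I,love,the,city", Y: 1}
	stream <- base.TextDatapoint{X: "I,hate,Los,Angeles", Y: 0}
	close(stream)
	for range errs {
		// drain until OnlineLearn closes the channel
	}

	// The converted model reuses the NaiveBayes fields, so TFIDF and
	// MostImportantWords now split on commas rather than spaces.
	tfidf := text.TFIDF(*model)
	fmt.Println(tfidf.MostImportantWords("love,the,city", 2))
}
```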
