Skip to content

Commit

Permalink
Merge pull request #1 from cdipaolo/master
Browse files Browse the repository at this point in the history
Update piazzamp/goml
  • Loading branch information
piazzamp authored Aug 27, 2016
2 parents 0b98886 + e2563be commit 5ac6002
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 15 deletions.
33 changes: 23 additions & 10 deletions text/bayes.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ type NaiveBayes struct {

// Output is the io.Writer used for logging
// and printing. Defaults to os.Stdout.
Output io.Writer
Output io.Writer `json:"-"`
}

// concurrentMap allows concurrency-friendly map
Expand All @@ -168,21 +168,34 @@ type concurrentMap struct {
words map[string]Word
}

func (m *concurrentMap) MarshalJSON() ([]byte, error) {
return json.Marshal(m.words)
}

func (m *concurrentMap) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, &m.words)
if err != nil {
return err
}

return nil
}

// Get looks up a word from h's Word map and should be used
// in place of a direct map lookup. The only caveat is that
// it will always return the 'success' boolean
func (h *concurrentMap) Get(w string) (Word, bool) {
h.RLock()
result, ok := h.words[w]
h.RUnlock()
func (m *concurrentMap) Get(w string) (Word, bool) {
m.RLock()
result, ok := m.words[w]
m.RUnlock()
return result, ok
}

// Set sets word k's value to v in h's Word map
func (h *concurrentMap) Set(k string, v Word) {
h.Lock()
h.words[k] = v
h.Unlock()
func (m *concurrentMap) Set(k string, v Word) {
m.Lock()
m.words[k] = v
m.Unlock()
}

// Word holds the structural
Expand All @@ -207,7 +220,7 @@ type Word struct {
// DocsSeen is the same as Seen but
// a word is only counted once even
// if it's in a document multiple times
DocsSeen uint64
DocsSeen uint64 `json:"-"`
}

// NewNaiveBayes returns a NaiveBayes model the
Expand Down
86 changes: 81 additions & 5 deletions text/bayes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestExampleClassificationShouldPass1(t *testing.T) {
assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is negative - Given %v", p)

class, p = model.Probability("Love the CiTy")
assert.EqualValues(t, 1, class, "Class should be 0")
assert.EqualValues(t, 1, class, "Class should be 1")
assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is positive - Given %v", p)
}

Expand Down Expand Up @@ -253,19 +253,15 @@ func TestConcurrentPredictionAndLearningShouldNotFail(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
// fmt.Println("beginning predicting")
for i := 0; i < 500; i++ {
model.Predict(strings.Repeat("some stuff that might be in the training data like iterate", 25))
}
// fmt.Println("done predicting")
}()

wg.Add(1)
go func() {
defer wg.Done()
// fmt.Println("beginning learning")
model.OnlineLearn(errors)
// fmt.Println("done learning")
}()

go func() {
Expand All @@ -280,3 +276,83 @@ func TestConcurrentPredictionAndLearningShouldNotFail(t *testing.T) {
close(c)
wg.Wait()
}

//* Test Persitance To File *//
func TestPersistNaiveBayesShouldPass1(t *testing.T) {
var err error

// create the channel of data and errors
stream := make(chan base.TextDatapoint, 100)
errors := make(chan error)

// make a new NaiveBayes model with
// 2 classes expected (classes in
// datapoints will now expect {0,1}.
// in general, given n as the classes
// variable, the model will expect
// datapoint classes in {0,...,n-1})
model := NewNaiveBayes(stream, 3, base.OnlyWordsAndNumbers)

go model.OnlineLearn(errors)

for i := 1; i < 10; i++ {
stream <- base.TextDatapoint{
X: "I love the city",
Y: 1,
}

stream <- base.TextDatapoint{
X: "I hate Los Angeles",
Y: 0,
}
}

close(stream)

for {
err, more := <-errors
if more {
fmt.Printf("Error passed: %v", err)
} else {
// training is done!
break
}
}

// now you can predict like normal
class := model.Predict("My mo~~~ther is in Los Angeles") // 0
assert.EqualValues(t, 0, class, "Class should be 0")

// test small document classification
class, p := model.Probability("Mother Los Angeles")
assert.EqualValues(t, 0, class, "Class should be 0")
assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is negative - Given %v", p)

class, p = model.Probability("Love the CiTy")
assert.EqualValues(t, 1, class, "Class should be 1")
assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is positive - Given %v", p)

// persist to file
err = model.PersistToFile("/tmp/.goml/Bayes.json")
assert.Nil(t, err, "Persistance error should be nil")

model.Words = concurrentMap{}

class, p = model.Probability("Mother Los Angeles")
assert.Equal(t, p, 0.5, "With a blank model the prediction should be 0.5 for both classes", p)

// restore from file
err = model.RestoreFromFile("/tmp/.goml/Bayes.json")
assert.Nil(t, err, "Persistance error should be nil")

class = model.Predict("My mo~~~ther is in Los Angeles") // 0
assert.EqualValues(t, 0, class, "Class should be 0")

class, p = model.Probability("Mother Los Angeles")
assert.EqualValues(t, 0, class, "Class should be 0")
assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is negative - Given %v", p)

class, p = model.Probability("Love the CiTy")
assert.EqualValues(t, 1, class, "Class should be 1")
assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is positive - Given %v", p)
}

0 comments on commit 5ac6002

Please sign in to comment.