Skip to content

Commit

Permalink
firmata: remove race conditions identified in Firmata client
Browse files Browse the repository at this point in the history
Signed-off-by: deadprogram <[email protected]>
  • Loading branch information
deadprogram committed Feb 8, 2017
1 parent fc3db1c commit b0a8bda
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 26 deletions.
4 changes: 2 additions & 2 deletions examples/firmata_button.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
func main() {
firmataAdaptor := firmata.NewAdaptor("/dev/ttyACM0")

button := gpio.NewButtonDriver(firmataAdaptor, "5")
led := gpio.NewLedDriver(firmataAdaptor, "13")
button := gpio.NewButtonDriver(firmataAdaptor, "2")
led := gpio.NewLedDriver(firmataAdaptor, "3")

work := func() {
button.On(gpio.ButtonPush, func(data interface{}) {
Expand Down
51 changes: 37 additions & 14 deletions platforms/firmata/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"fmt"
"io"
"math"
"sync"
"sync/atomic"
"time"

"gobot.io/x/gobot"
Expand Down Expand Up @@ -64,10 +66,12 @@ type Client struct {
pins []Pin
FirmwareName string
ProtocolVersion string
connected bool
connected atomic.Value
connection io.ReadWriteCloser
analogPins []int
initTimeInterval time.Duration
initFunc func() error
initMutex sync.Mutex
gobot.Eventer
}

Expand Down Expand Up @@ -95,10 +99,11 @@ func New() *Client {
connection: nil,
pins: []Pin{},
analogPins: []int{},
connected: false,
Eventer: gobot.NewEventer(),
}

c.connected.Store(false)

for _, s := range []string{
"FirmwareQuery",
"CapabilityQuery",
Expand All @@ -114,15 +119,32 @@ func New() *Client {
return c
}

func (b *Client) setInitFunc(f func() error) {
b.initMutex.Lock()
defer b.initMutex.Unlock()
b.initFunc = f
}

func (b *Client) getInitFunc() func() error {
b.initMutex.Lock()
defer b.initMutex.Unlock()
f := b.initFunc
return f
}

func (b *Client) setConnected(c bool) {
b.connected.Store(c)
}

// Disconnect disconnects the Client
func (b *Client) Disconnect() (err error) {
b.connected = false
b.setConnected(false)
return b.connection.Close()
}

// Connected returns the current connection state of the Client
func (b *Client) Connected() bool {
return b.connected
return b.connected.Load().(bool)
}

// Pins returns all available pins
Expand All @@ -134,45 +156,46 @@ func (b *Client) Pins() []Pin {
// then continuously polls the firmata board for new information when it's
// available.
func (b *Client) Connect(conn io.ReadWriteCloser) (err error) {
if b.connected {
if b.Connected() {
return ErrConnected
}

b.connection = conn
b.Reset()

initFunc := b.ProtocolVersionQuery
b.setInitFunc(b.ProtocolVersionQuery)

b.Once(b.Event("ProtocolVersion"), func(data interface{}) {
initFunc = b.FirmwareQuery
b.setInitFunc(b.FirmwareQuery)
})

b.Once(b.Event("FirmwareQuery"), func(data interface{}) {
initFunc = b.CapabilitiesQuery
b.setInitFunc(b.CapabilitiesQuery)
})

b.Once(b.Event("CapabilityQuery"), func(data interface{}) {
initFunc = b.AnalogMappingQuery
b.setInitFunc(b.AnalogMappingQuery)
})

b.Once(b.Event("AnalogMappingQuery"), func(data interface{}) {
initFunc = func() error { return nil }
b.setInitFunc(func() error { return nil })
b.ReportDigital(0, 1)
b.ReportDigital(1, 1)
b.connected = true
b.setConnected(true)
})

for {
if err := initFunc(); err != nil {
f := b.getInitFunc()
if err := f(); err != nil {
return err
}
if err := b.process(); err != nil {
return err
}
if b.connected {
if b.Connected() {
go func() {
for {
if !b.connected {
if !b.Connected() {
break
}

Expand Down
40 changes: 30 additions & 10 deletions platforms/firmata/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@ import (
type readWriteCloser struct{}

func (readWriteCloser) Write(p []byte) (int, error) {
writeDataMutex.Lock()
defer writeDataMutex.Unlock()
return testWriteData.Write(p)
}

var clientMutex sync.Mutex
var writeDataMutex sync.Mutex
var testReadData = []byte{}
var testWriteData = bytes.Buffer{}

func (readWriteCloser) Read(b []byte) (int, error) {
clientMutex.Lock()
defer clientMutex.Unlock()
writeDataMutex.Lock()
defer writeDataMutex.Unlock()

size := len(b)
if len(testReadData) < size {
Expand Down Expand Up @@ -80,7 +83,7 @@ func initTestFirmata() *Client {
b.process()
}

b.connected = true
b.setConnected(true)
b.Connect(readWriteCloser{})

return b
Expand Down Expand Up @@ -286,31 +289,44 @@ func TestProcessStringData(t *testing.T) {
func TestConnect(t *testing.T) {
b := New()

var responseMutex sync.Mutex
responseMutex.Lock()
response := testProtocolResponse()

go func() {
for {
testReadData = append(testReadData, response...)
time.Sleep(100 * time.Millisecond)
}
}()
responseMutex.Unlock()

b.Once(b.Event("ProtocolVersion"), func(data interface{}) {
responseMutex.Lock()
response = testFirmwareResponse()
responseMutex.Unlock()
})

b.Once(b.Event("FirmwareQuery"), func(data interface{}) {
responseMutex.Lock()
response = testCapabilitiesResponse()
responseMutex.Unlock()
})

b.Once(b.Event("CapabilityQuery"), func(data interface{}) {
responseMutex.Lock()
response = testAnalogMappingResponse()
responseMutex.Unlock()
})

b.Once(b.Event("AnalogMappingQuery"), func(data interface{}) {
responseMutex.Lock()
response = testProtocolResponse()
responseMutex.Unlock()
})

go func() {
for {
responseMutex.Lock()
testReadData = append(testReadData, response...)
responseMutex.Unlock()
time.Sleep(100 * time.Millisecond)
}
}()

gobottest.Assert(t, b.Connect(readWriteCloser{}), nil)
}

Expand Down Expand Up @@ -342,9 +358,13 @@ func TestServoConfig(t *testing.T) {
}

for _, test := range tests {
writeDataMutex.Lock()
testWriteData.Reset()
writeDataMutex.Unlock()
err := b.ServoConfig(test.arguments[0], test.arguments[1], test.arguments[2])
writeDataMutex.Lock()
gobottest.Assert(t, testWriteData.Bytes(), test.expected)
gobottest.Assert(t, err, test.result)
writeDataMutex.Unlock()
}
}

0 comments on commit b0a8bda

Please sign in to comment.