diff --git a/server/mqtt.go b/server/mqtt.go index 33a06f2fc5..273d71f83e 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -24,18 +24,21 @@ var deprecatedTopics = []string{ // MQTT is the MQTT server. It uses the MQTT client for publishing. type MQTT struct { - log *util.Logger - Handler *mqtt.Client - root string + log *util.Logger + Handler *mqtt.Client + root string + publisher func(topic string, retained bool, payload string) } // NewMQTT creates MQTT server func NewMQTT(root string) *MQTT { - return &MQTT{ + m := &MQTT{ log: util.NewLogger("mqtt"), Handler: mqtt.Instance, root: root, } + m.publisher = m.publishString + return m } func (m *MQTT) encode(v interface{}) string { @@ -65,7 +68,7 @@ func (m *MQTT) encode(v interface{}) string { } func (m *MQTT) publishComplex(topic string, retained bool, payload interface{}) { - if payload == nil { + if _, ok := payload.(fmt.Stringer); ok || payload == nil { m.publishSingleValue(topic, retained, payload) return } @@ -105,11 +108,15 @@ func (m *MQTT) publishComplex(topic string, retained bool, payload interface{}) } } -func (m *MQTT) publishSingleValue(topic string, retained bool, payload interface{}) { +func (m *MQTT) publishString(topic string, retained bool, payload string) { token := m.Handler.Client.Publish(topic, m.Handler.Qos, retained, m.encode(payload)) go m.Handler.WaitForToken("send", topic, token) } +func (m *MQTT) publishSingleValue(topic string, retained bool, payload interface{}) { + m.publisher(topic, retained, m.encode(payload)) +} + func (m *MQTT) publish(topic string, retained bool, payload interface{}) { // publish phase values if slice, ok := payload.([]float64); ok && len(slice) == 3 { diff --git a/server/mqtt_test.go b/server/mqtt_test.go index e181a5419e..1ebe08b774 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -2,9 +2,12 @@ package server import ( "math" + "strconv" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMqttNaNInf(t *testing.T) { @@ -12,3 +15,42 @@ func TestMqttNaNInf(t *testing.T) { assert.Equal(t, "NaN", m.encode(math.NaN()), "NaN not encoded as string") assert.Equal(t, "+Inf", m.encode(math.Inf(0)), "Inf not encoded as string") } + +func TestPublishTypes(t *testing.T) { + var topics, payloads []string + + reset := func() { + topics = topics[:0] + payloads = payloads[:0] + } + + m := &MQTT{ + publisher: func(topic string, retained bool, payload string) { + topics = append(topics, topic) + payloads = append(payloads, payload) + }, + } + + now := time.Now() + m.publish("test", false, now) + require.Len(t, topics, 1) + assert.Equal(t, strconv.FormatInt(now.Unix(), 10), payloads[0], "time not encoded as unix timestamp") + reset() + + m.publish("test", false, struct { + Foo string + }{ + Foo: "bar", + }) + require.Len(t, topics, 1) + assert.Equal(t, `test/foo`, topics[0], "struct mismatch") + assert.Equal(t, `bar`, payloads[0], "struct mismatch") + reset() + + slice := []int{10, 20} + m.publish("test", false, slice) + require.Len(t, topics, 3) + assert.Equal(t, []string{`test`, `test/1`, `test/2`}, topics, "slice mismatch") + assert.Equal(t, []string{`2`, `10`, `20`}, payloads, "slice mismatch") + reset() +}