Skip to content

Commit

Permalink
Enable pure Protobuf binary messaging over WebSocket. (heroiclabs#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyro committed Dec 12, 2018
1 parent d37914a commit 85467ff
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 30 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr

## [Unreleased]
### Added
- WebSocket connections can now use pure Protobuf binary messaging.
- Lua runtime tournament listings now return duration, end active, and end time fields.
- Lua runtime tournament end hooks now contain duration, end active, and end time fields.
- Lua runtime tournament reset hooks now contain duration, end active, and end time fields.
Expand Down
42 changes: 35 additions & 7 deletions server/message_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
package server

import (
"bytes"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/heroiclabs/nakama/rtapi"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -45,19 +47,45 @@ func (r *LocalMessageRouter) SendToPresenceIDs(logger *zap.Logger, presenceIDs [
return
}

payload, err := r.jsonpbMarshaler.MarshalToString(envelope)
if err != nil {
logger.Error("Could not marshall message to json", zap.Error(err))
return
}
payloadBytes := []byte(payload)
// Prepare payload variables but do not initialize until we hit a session that needs them to avoid unnecessary work.
var payloadProtobuf []byte
var payloadJson []byte

for _, presenceID := range presenceIDs {
session := r.sessionRegistry.Get(presenceID.SessionID)
if session == nil {
logger.Debug("No session to route to", zap.String("sid", presenceID.SessionID.String()))
continue
}
if err := session.SendBytes(isStream, mode, payloadBytes); err != nil {

var err error
switch session.Format() {
case SessionFormatProtobuf:
if payloadProtobuf == nil {
// Marshal the payload now that we know this format is needed.
payloadProtobuf, err = proto.Marshal(envelope)
if err != nil {
logger.Error("Could not marshal message", zap.Error(err))
return
}
}
err = session.SendBytes(isStream, mode, payloadProtobuf)
case SessionFormatJson:
fallthrough
default:
if payloadJson == nil {
// Marshal the payload now that we know this format is needed.
var buf bytes.Buffer
if err = r.jsonpbMarshaler.Marshal(&buf, envelope); err == nil {
payloadJson = buf.Bytes()
} else {
logger.Error("Could not marshal message", zap.Error(err))
return
}
}
err = session.SendBytes(isStream, mode, payloadJson)
}
if err != nil {
logger.Error("Failed to route to", zap.String("sid", presenceID.SessionID.String()), zap.Error(err))
}
}
Expand Down
80 changes: 58 additions & 22 deletions server/session_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"context"
"errors"
"fmt"
"github.com/golang/protobuf/proto"
"sync"
"time"

Expand All @@ -39,6 +40,7 @@ type sessionWS struct {
logger *zap.Logger
config Config
id uuid.UUID
format SessionFormat
userID uuid.UUID
username *atomic.String
expiry int64
Expand All @@ -50,6 +52,7 @@ type sessionWS struct {

jsonpbMarshaler *jsonpb.Marshaler
jsonpbUnmarshaler *jsonpb.Unmarshaler
wsMessageType int
queuePriorityThreshold int
pingPeriodDuration time.Duration
pongWaitDuration time.Duration
Expand All @@ -68,18 +71,24 @@ type sessionWS struct {
outgoingStopCh chan struct{}
}

func NewSessionWS(logger *zap.Logger, config Config, userID uuid.UUID, username string, expiry int64, clientIP string, clientPort string, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, conn *websocket.Conn, sessionRegistry *SessionRegistry, matchmaker Matchmaker, tracker Tracker) Session {
func NewSessionWS(logger *zap.Logger, config Config, format SessionFormat, userID uuid.UUID, username string, expiry int64, clientIP string, clientPort string, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, conn *websocket.Conn, sessionRegistry *SessionRegistry, matchmaker Matchmaker, tracker Tracker) Session {
sessionID := uuid.Must(uuid.NewV4())
sessionLogger := logger.With(zap.String("uid", userID.String()), zap.String("sid", sessionID.String()))

sessionLogger.Info("New WebSocket session connected")
sessionLogger.Info("New WebSocket session connected", zap.Uint8("format", uint8(format)))

ctx, ctxCancelFn := context.WithCancel(context.Background())

wsMessageType := websocket.TextMessage
if format == SessionFormatProtobuf {
wsMessageType = websocket.BinaryMessage
}

return &sessionWS{
logger: sessionLogger,
config: config,
id: sessionID,
format: format,
userID: userID,
username: atomic.NewString(username),
expiry: expiry,
Expand All @@ -91,6 +100,7 @@ func NewSessionWS(logger *zap.Logger, config Config, userID uuid.UUID, username

jsonpbMarshaler: jsonpbMarshaler,
jsonpbUnmarshaler: jsonpbUnmarshaler,
wsMessageType: wsMessageType,
queuePriorityThreshold: (config.GetSocket().OutgoingQueueSize / 3) * 2,
pingPeriodDuration: time.Duration(config.GetSocket().PingPeriodMs) * time.Millisecond,
pongWaitDuration: time.Duration(config.GetSocket().PongWaitMs) * time.Millisecond,
Expand Down Expand Up @@ -159,7 +169,7 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess
go s.processOutgoing()

for {
_, data, err := s.conn.ReadMessage()
messageType, data, err := s.conn.ReadMessage()
if err != nil {
// Ignore "normal" WebSocket errors.
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
Expand All @@ -170,6 +180,12 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess
}
break
}
if messageType != s.wsMessageType {
// Expected text but received binary, or expected binary but received text.
// Disconnect client if it attempts to use this kind of mixed protocol mode.
s.logger.Debug("Received unexpected WebSocket message type", zap.Int("expected", s.wsMessageType), zap.Int("actual", messageType))
break
}

s.receivedMessageCounter--
if s.receivedMessageCounter <= 0 {
Expand All @@ -178,21 +194,29 @@ func (s *sessionWS) Consume(processRequest func(logger *zap.Logger, session Sess
}

request := &rtapi.Envelope{}
if err = s.jsonpbUnmarshaler.Unmarshal(bytes.NewReader(data), request); err != nil {
switch s.format {
case SessionFormatProtobuf:
err = proto.Unmarshal(data, request)
case SessionFormatJson:
fallthrough
default:
err = s.jsonpbUnmarshaler.Unmarshal(bytes.NewReader(data), request)
}
if err != nil {
// If the payload is malformed the client is incompatible or misbehaving, either way disconnect it now.
s.logger.Warn("Received malformed payload", zap.String("data", string(data)))
s.logger.Warn("Received malformed payload", zap.Binary("data", data))
break
} else {
switch request.Cid {
case "":
if !processRequest(s.logger, s, request) {
break
}
default:
requestLogger := s.logger.With(zap.String("cid", request.Cid))
if !processRequest(requestLogger, s, request) {
break
}
}

switch request.Cid {
case "":
if !processRequest(s.logger, s, request) {
break
}
default:
requestLogger := s.logger.With(zap.String("cid", request.Cid))
if !processRequest(requestLogger, s, request) {
break
}
}
}
Expand Down Expand Up @@ -238,7 +262,7 @@ func (s *sessionWS) processOutgoing() {
}
// Process the outgoing message queue.
s.conn.SetWriteDeadline(time.Now().Add(s.writeWaitDuration))
if err := s.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
if err := s.conn.WriteMessage(s.wsMessageType, payload); err != nil {
s.Unlock()
s.logger.Warn("Could not write message", zap.Error(err))
return
Expand Down Expand Up @@ -268,22 +292,34 @@ func (s *sessionWS) pingNow() bool {
}

func (s *sessionWS) Format() SessionFormat {
return SessionFormatJson
return s.format
}

func (s *sessionWS) Send(isStream bool, mode uint8, envelope *rtapi.Envelope) error {
payload, err := s.jsonpbMarshaler.MarshalToString(envelope)
var payload []byte
var err error
switch s.format {
case SessionFormatProtobuf:
payload, err = proto.Marshal(envelope)
case SessionFormatJson:
fallthrough
default:
var buf bytes.Buffer
if err = s.jsonpbMarshaler.Marshal(&buf, envelope); err == nil {
payload = buf.Bytes()
}
}
if err != nil {
s.logger.Warn("Could not marshal to json", zap.Error(err))
s.logger.Warn("Could not marshal envelope", zap.Error(err))
return err
}

if s.logger.Core().Enabled(zap.DebugLevel) {
switch envelope.Message.(type) {
case *rtapi.Envelope_Error:
s.logger.Debug("Sending error message", zap.String("payload", payload))
s.logger.Debug("Sending error message", zap.Binary("payload", payload))
default:
s.logger.Debug(fmt.Sprintf("Sending %T message", envelope.Message), zap.String("payload", payload))
s.logger.Debug(fmt.Sprintf("Sending %T message", envelope.Message), zap.Any("envelope", envelope))
}
}

Expand Down
17 changes: 16 additions & 1 deletion server/socket_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,21 @@ func NewSocketWsAcceptor(logger *zap.Logger, config Config, sessionRegistry *Ses

// This handler will be attached to the API Gateway server.
return func(w http.ResponseWriter, r *http.Request) {
// Check format.
var format SessionFormat
switch r.URL.Query().Get("format") {
case "protobuf":
format = SessionFormatProtobuf
case "json":
fallthrough
case "":
format = SessionFormatJson
default:
// Invalid values are rejected.
http.Error(w, "Invalid format parameter", 400)
return
}

// Check authentication.
token := r.URL.Query().Get("token")
if token == "" {
Expand Down Expand Up @@ -90,7 +105,7 @@ func NewSocketWsAcceptor(logger *zap.Logger, config Config, sessionRegistry *Ses
span := trace.NewSpan("nakama.session.ws", nil, trace.StartOptions{})

// Wrap the connection for application handling.
s := NewSessionWS(logger, config, userID, username, expiry, clientIP, clientPort, jsonpbMarshaler, jsonpbUnmarshaler, conn, sessionRegistry, matchmaker, tracker)
s := NewSessionWS(logger, config, format, userID, username, expiry, clientIP, clientPort, jsonpbMarshaler, jsonpbUnmarshaler, conn, sessionRegistry, matchmaker, tracker)

// Add to the session registry.
sessionRegistry.add(s)
Expand Down

0 comments on commit 85467ff

Please sign in to comment.