Skip to content

Commit

Permalink
wsutil: simplify send and recv extension interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
gobwas committed Feb 5, 2021
1 parent bd85c6a commit d7978b7
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 68 deletions.
39 changes: 27 additions & 12 deletions example/autobahn/autobahn.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,20 +214,25 @@ func wsflateHandler(w http.ResponseWriter, r *http.Request) {
fr := wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
return flate.NewReader(r)
})

// MessageState implements wsutil.Extension and is used to check whether
// received WebSocket message is compressed. That is, it's generally
// possible to receive uncompressed messaged even if compression extension
// was negotiated.
var msg wsflate.MessageState

// Note that control frames are all written without compression.
controlHandler := wsutil.ControlFrameHandler(conn, ws.StateServerSide)
rd := wsutil.Reader{
Source: conn,
State: ws.StateServerSide | ws.StateExtended,
CheckUTF8: false,
OnIntermediate: controlHandler,
Extensions: []wsutil.RecvExtension{
wsutil.RecvExtensionFunc(wsflate.BitsRecv),
},
Extensions: []wsutil.RecvExtension{&msg},
}

wr := wsutil.NewWriter(conn, ws.StateServerSide|ws.StateExtended, 0)
wr.SetExtensions(wsutil.SendExtensionFunc(wsflate.BitsSend))
wr.SetExtensions(&msg)

for {
h, err := rd.NextFrame()
Expand All @@ -243,19 +248,29 @@ func wsflateHandler(w http.ResponseWriter, r *http.Request) {
continue
}

fr.Reset(&rd)
fw.Reset(wr)

wr.ResetOp(h.OpCode)

// Copy incoming bytes right into writer through decompressor and compressor.
if _, err = io.Copy(fw, fr); err != nil {
log.Fatal(err)
var (
src io.Reader = &rd
dst io.Writer = wr
)
if msg.IsCompressed() {
fr.Reset(src)
fw.Reset(dst)
src = fr
dst = fw
}
// Flush the flate writer.
if err = fw.Close(); err != nil {
// Copy incoming bytes right into writer, probably through decompressor
// and compressor.
if _, err = io.Copy(dst, src); err != nil {
log.Fatal(err)
}
if msg.IsCompressed() {
// Flush the flate writer.
if err = fw.Close(); err != nil {
log.Fatal(err)
}
}
// Flush WebSocket fragment writer. We could send multiple fragments
// for large messages.
if err = wr.Flush(); err != nil {
Expand Down
103 changes: 82 additions & 21 deletions wsflate/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package wsflate

import (
"bytes"
"fmt"

"github.com/gobwas/httphead"
"github.com/gobwas/ws"
Expand Down Expand Up @@ -89,40 +88,102 @@ func (n *Extension) Reset() {
n.params = Parameters{}
}

var errNonFirstFragmentEnabledBit = ws.ProtocolError(
"non-first fragment contains compression bit enabled",
var ErrUnexpectedCompressionBit = ws.ProtocolError(
"control frame or non-first fragment of data contains compression bit set",
)

// BitsRecv changes RSV bits of the received frame header as if compression
// extension was negotiated.
func BitsRecv(fseq int, rsv byte) (byte, error) {
r1, r2, r3 := ws.RsvBits(rsv)
if fseq > 0 && r1 {
// UnsetBit clears the Per-Message Compression bit in header h and returns its
// modified copy. It reports whether compression bit was set in header h.
// It returns non-nil error if compression bit has unexpected value.
func UnsetBit(h ws.Header) (_ ws.Header, wasSet bool, err error) {
var s MessageState
h, err := s.UnsetBits(h)
return h, s.IsCompressed(), err
}

// SetBit sets the Per-Message Compression bit in header h and returns its
// modified copy.
// It returns non-nil error if compression bit has unexpected value.
func SetBit(h ws.Header) (_ ws.Header, err error) {
var s MessageState
s.SetCompressed(true)
return s.SetBits(h)
}

// MessageState holds message compression state.
//
// It is consulted during SetBits(h) call to make a decision whether we must
// set the Per-Message Compression bit for given header h argument.
// It is updated during UnsetBits(h) to reflect compression state of a message
// represented by header h argument.
// It can also be consulted/updated directly by calling
// IsCompressed()/SetCompressed().
//
// In general MessageState should be used when there is no direct access to
// connection to read frame from, but it is still needed to know if message
// being read is compressed. For other cases SetBit() and UnsetBit() should be
// used instead.
//
// NOTE: the compression state is updated during UnsetBits(h) only when header
// h argunent represents data (text or binary) frame.
type MessageState struct {
compressed bool
}

// SetCompressed marks message as "compressed" or "uncompressed".
// See https://tools.ietf.org/html/rfc7692#section-6
func (s *MessageState) SetCompressed(v bool) {
s.compressed = v
}

// IsCompressed reports whether message is "compressed".
// See https://tools.ietf.org/html/rfc7692#section-6
func (s *MessageState) IsCompressed() bool {
return s.compressed
}

// UnsetBits changes RSV bits of the given frame header h as if compression
// extension was negotiated. It returns modified copy of h and error if header
// is malformed from the RFC perspective.
func (s *MessageState) UnsetBits(h ws.Header) (ws.Header, error) {
r1, r2, r3 := ws.RsvBits(h.Rsv)
switch {
case h.OpCode.IsData() && h.OpCode != ws.OpContinuation:
h.Rsv = ws.Rsv(false, r2, r3)
s.SetCompressed(r1)
return h, nil

case r1:
// An endpoint MUST NOT set the "Per-Message Compressed"
// bit of control frames and non-first fragments of a data
// message. An endpoint receiving such a frame MUST _Fail
// the WebSocket Connection_.
return rsv, errNonFirstFragmentEnabledBit
}
if fseq > 0 {
return rsv, nil
return h, ErrUnexpectedCompressionBit

default:
// NOTE: do not change the state of s.compressed since UnsetBits()
// might also be called for (intermediate) control frames.
return h, nil
}
return ws.Rsv(false, r2, r3), nil
}

// BitsSend changes RSV bits of the frame header which is being send as if
// compression extension was negotiated.
func BitsSend(fseq int, rsv byte) (byte, error) {
r1, r2, r3 := ws.RsvBits(rsv)
// SetBits changes RSV bits of the frame header h which is being send as if
// compression extension was negotiated. It returns modified copy of h and
// error if header is malformed from the RFC perspective.
func (s *MessageState) SetBits(h ws.Header) (ws.Header, error) {
r1, r2, r3 := ws.RsvBits(h.Rsv)
if r1 {
return rsv, fmt.Errorf("wsflate: compression bit is already set")
return h, ErrUnexpectedCompressionBit
}
if fseq > 0 {
if !h.OpCode.IsData() || h.OpCode == ws.OpContinuation {
// An endpoint MUST NOT set the "Per-Message Compressed"
// bit of control frames and non-first fragments of a data
// message. An endpoint receiving such a frame MUST _Fail
// the WebSocket Connection_.
return rsv, nil
return h, nil
}
if s.IsCompressed() {
h.Rsv = ws.Rsv(true, r2, r3)
}
return ws.Rsv(true, r2, r3), nil
return h, nil
}
40 changes: 24 additions & 16 deletions wsflate/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,17 @@ func (h *Helper) DecompressFrame(in ws.Frame) (f ws.Frame, err error) {

// CompressFrameBuffer compresses a frame using given buffer.
// Returned frame's payload holds bytes returned by buf.Bytes().
func (h *Helper) CompressFrameBuffer(buf Buffer, in ws.Frame) (f ws.Frame, err error) {
if !in.Header.Fin {
func (h *Helper) CompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) {
if !f.Header.Fin {
return f, fmt.Errorf("wsflate: fragmented messages are not allowed")
}
if err := h.CompressTo(buf, in.Payload); err != nil {
if err := h.CompressTo(buf, f.Payload); err != nil {
return f, err
}
// Copy initial frame.
f = in
var err error
f.Payload = buf.Bytes()
f.Header.Length = int64(len(f.Payload))
f.Header.Rsv, err = BitsSend(0, f.Header.Rsv)
f.Header, err = SetBit(f.Header)
if err != nil {
return f, err
}
Expand All @@ -118,21 +117,30 @@ func (h *Helper) CompressFrameBuffer(buf Buffer, in ws.Frame) (f ws.Frame, err e

// DecompressFrameBuffer decompresses a frame using given buffer.
// Returned frame's payload holds bytes returned by buf.Bytes().
func (h *Helper) DecompressFrameBuffer(buf Buffer, in ws.Frame) (f ws.Frame, err error) {
if !in.Header.Fin {
return f, fmt.Errorf("wsflate: fragmented messages are not allowed")
func (h *Helper) DecompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) {
if !f.Header.Fin {
return f, fmt.Errorf(
"wsflate: fragmented messages are not supported by helper",
)
}
if err := h.DecompressTo(buf, in.Payload); err != nil {
var (
compressed bool
err error
)
f.Header, compressed, err = UnsetBit(f.Header)
if err != nil {
return f, err
}
// Copy initial frame.
f = in
f.Payload = buf.Bytes()
f.Header.Length = int64(len(f.Payload))
f.Header.Rsv, err = BitsRecv(0, f.Header.Rsv)
if err != nil {
if !compressed {
return f, nil
}
if err := h.DecompressTo(buf, f.Payload); err != nil {
return f, err
}

f.Payload = buf.Bytes()
f.Header.Length = int64(len(f.Payload))

return f, nil
}

Expand Down
18 changes: 10 additions & 8 deletions wsutil/extenstion.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
package wsutil

import "github.com/gobwas/ws"

// RecvExtension is an interface for clearing fragment header RSV bits.
type RecvExtension interface {
BitsRecv(seq int, rsv byte) (byte, error)
UnsetBits(ws.Header) (ws.Header, error)
}

// RecvExtensionFunc is an adapter to allow the use of ordinary functions as
// RecvExtension.
type RecvExtensionFunc func(int, byte) (byte, error)
type RecvExtensionFunc func(ws.Header) (ws.Header, error)

// BitsRecv implements RecvExtension.
func (fn RecvExtensionFunc) BitsRecv(seq int, rsv byte) (byte, error) {
return fn(seq, rsv)
func (fn RecvExtensionFunc) UnsetBits(h ws.Header) (ws.Header, error) {
return fn(h)
}

// SendExtension is an interface for setting fragment header RSV bits.
type SendExtension interface {
BitsSend(seq int, rsv byte) (byte, error)
SetBits(ws.Header) (ws.Header, error)
}

// SendExtensionFunc is an adapter to allow the use of ordinary functions as
// SendExtension.
type SendExtensionFunc func(int, byte) (byte, error)
type SendExtensionFunc func(ws.Header) (ws.Header, error)

// BitsSend implements SendExtension.
func (fn SendExtensionFunc) BitsSend(seq int, rsv byte) (byte, error) {
return fn(seq, rsv)
func (fn SendExtensionFunc) SetBits(h ws.Header) (ws.Header, error) {
return fn(h)
}
8 changes: 2 additions & 6 deletions wsutil/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ type Reader struct {
frame io.Reader // Used to as frame reader.
raw io.LimitedReader // Used to discard frames without cipher.
utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true.
fseq int // Fragment sequence in message counter.
}

// NewReader creates new frame reader that reads from r keeping given state to
Expand Down Expand Up @@ -200,8 +199,8 @@ func (r *Reader) NextFrame() (hdr ws.Header, err error) {
frame = NewCipherReader(frame, hdr.Mask)
}

for _, ext := range r.Extensions {
hdr.Rsv, err = ext.BitsRecv(r.fseq, hdr.Rsv)
for _, x := range r.Extensions {
hdr, err = x.UnsetBits(hdr)
if err != nil {
return hdr, err
}
Expand Down Expand Up @@ -237,10 +236,8 @@ func (r *Reader) NextFrame() (hdr ws.Header, err error) {

if hdr.Fin {
r.State = r.State.Clear(ws.StateFragmented)
r.fseq = 0
} else {
r.State = r.State.Set(ws.StateFragmented)
r.fseq++
}

return
Expand All @@ -261,7 +258,6 @@ func (r *Reader) reset() {
r.raw = io.LimitedReader{}
r.frame = nil
r.utf8 = UTF8Reader{}
r.fseq = 0
r.opCode = 0
}

Expand Down
6 changes: 3 additions & 3 deletions wsutil/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,8 @@ func (w *Writer) WriteThrough(p []byte) (n int, err error) {
Fin: false,
Length: int64(len(p)),
}
for _, ext := range w.extensions {
frame.Header.Rsv, err = ext.BitsSend(w.fseq, frame.Header.Rsv)
for _, x := range w.extensions {
frame.Header, err = x.SetBits(frame.Header)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -479,7 +479,7 @@ func (w *Writer) flushFragment(fin bool) (err error) {
}
)
for _, ext := range w.extensions {
header.Rsv, err = ext.BitsSend(w.fseq, header.Rsv)
header, err = ext.SetBits(header)
if err != nil {
return err
}
Expand Down
5 changes: 3 additions & 2 deletions wsutil/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,9 @@ func TestWriterLargeWrite(t *testing.T) {

// Test that event for big writes extensions set their bits.
var rsv = [3]bool{true, true, false}
w.SetExtensions(SendExtensionFunc(func(fseq int, bits byte) (byte, error) {
return ws.Rsv(rsv[0], rsv[1], rsv[2]), nil
w.SetExtensions(SendExtensionFunc(func(h ws.Header) (ws.Header, error) {
h.Rsv = ws.Rsv(rsv[0], rsv[1], rsv[2])
return h, nil
}))

// Write message with size twice bigger than writer's internal buffer.
Expand Down

0 comments on commit d7978b7

Please sign in to comment.