Skip to content

Commit

Permalink
wsutil: add MaxFrameSize to Reader (gobwas#127)
Browse files Browse the repository at this point in the history
Fixes gobwas#26
  • Loading branch information
agnivade authored Jan 25, 2021
1 parent e5bc048 commit 7637d01
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
15 changes: 14 additions & 1 deletion wsutil/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ import (
// preceding NextFrame() call.
var ErrNoFrameAdvance = errors.New("no frame advance")

// ErrFrameTooLarge indicates that a message of length higher than
// MaxFrameSize was being read.
var ErrFrameTooLarge = errors.New("frame too large")

// FrameHandlerFunc handles parsed frame header and its body represented by
// io.Reader.
//
Expand Down Expand Up @@ -42,7 +46,12 @@ type Reader struct {
// header RSV segment.
Extensions []RecvExtension

// TODO(gobwas): add max frame size limit here.
// MaxFrameSize controls the maximum frame size in bytes
// that can be read. A message exceeding that size will return
// a ErrFrameTooLarge to the application.
//
// Not setting this field means there is no limit.
MaxFrameSize int64

OnContinuation FrameHandlerFunc
OnIntermediate FrameHandlerFunc
Expand Down Expand Up @@ -175,6 +184,10 @@ func (r *Reader) NextFrame() (hdr ws.Header, err error) {
return hdr, err
}

if n := r.MaxFrameSize; n > 0 && hdr.Length > n {
return hdr, ErrFrameTooLarge
}

// Save raw reader to use it on discarding frame without ciphering and
// other streaming checks.
r.raw = io.LimitedReader{
Expand Down
27 changes: 27 additions & 0 deletions wsutil/reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,33 @@ func TestReaderNextFrameAndReadEOF(t *testing.T) {

}

func TestMaxFrameSize(t *testing.T) {
var buf bytes.Buffer
msg := []byte("small frame")
f := ws.NewTextFrame(msg)
if err := ws.WriteFrame(&buf, f); err != nil {
t.Fatal(err)
}
r := Reader{
Source: &buf,
MaxFrameSize: int64(len(msg)) - 1,
}

_, err := r.NextFrame()
if got, want := err, ErrFrameTooLarge; got != want {
t.Errorf("NextFrame() error = %v; want %v", got, want)
}

p := make([]byte, 100)
n, err := r.Read(p)
if got, want := err, ErrNoFrameAdvance; got != want {
t.Errorf("Read() error = %v; want %v", got, want)
}
if got, want := n, 0; got != want {
t.Errorf("Read() bytes returned = %v; want %v", got, want)
}
}

func TestReaderUTF8(t *testing.T) {
yo := []byte("Ё")
if !utf8.ValidString(string(yo)) {
Expand Down

0 comments on commit 7637d01

Please sign in to comment.