Skip to content

Commit

Permalink
Don't depend on readerIndex() of 0 when handling websocket frames (ne…
Browse files Browse the repository at this point in the history
…tty#12022)


Motivation:

We had some code iin the websocket implementation which depended on the fact that the buffers have a rreaderIndex() of 0. This is not needed at all and may produce bugs in the future.

Modifications:

- Refactor the code to not depend on the readerIndex of 0.
- Add unit test

Result:

Make implemenatation more robust
  • Loading branch information
normanmaurer authored Jan 19, 2022
1 parent f6ea528 commit 3cf83d5
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ private static ByteBuf newBinaryData(int statusCode, String reasonText) {
if (!reasonText.isEmpty()) {
binaryData.writeCharSequence(reasonText, CharsetUtil.UTF_8);
}

binaryData.readerIndex(0);
return binaryData;
}

Expand All @@ -133,12 +131,11 @@ public CloseWebSocketFrame(boolean finalFragment, int rsv, ByteBuf binaryData) {
*/
public int statusCode() {
ByteBuf binaryData = content();
if (binaryData == null || binaryData.capacity() == 0) {
if (binaryData == null || binaryData.readableBytes() < 2) {
return -1;
}

binaryData.readerIndex(0);
return binaryData.getShort(0);
return binaryData.getShort(binaryData.readerIndex());
}

/**
Expand All @@ -147,15 +144,11 @@ public int statusCode() {
*/
public String reasonText() {
ByteBuf binaryData = content();
if (binaryData == null || binaryData.capacity() <= 2) {
if (binaryData == null || binaryData.readableBytes() <= 2) {
return "";
}

binaryData.readerIndex(2);
String reasonText = binaryData.toString(CharsetUtil.UTF_8);
binaryData.readerIndex(0);

return reasonText;
return binaryData.toString(binaryData.readerIndex() + 2, binaryData.readableBytes() - 2, CharsetUtil.UTF_8);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ public void check(ByteBuf buffer) {
buffer.forEachByte(this);
}

void check(ByteBuf buffer, int index, int length) {
checking = true;
buffer.forEachByte(index, length, this);
}

public void finish() {
checking = false;
codep = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,30 +463,23 @@ protected void checkCloseFrameBody(
if (buffer == null || !buffer.isReadable()) {
return;
}
if (buffer.readableBytes() == 1) {
if (buffer.readableBytes() < 2) {
protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
}

// Save reader index
int idx = buffer.readerIndex();
buffer.readerIndex(0);

// Must have 2 byte integer within the valid range
int statusCode = buffer.readShort();
int statusCode = buffer.getShort(buffer.readerIndex());
if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode);
}

// May have UTF-8 message
if (buffer.isReadable()) {
if (buffer.readableBytes() > 2) {
try {
new Utf8Validator().check(buffer);
new Utf8Validator().check(buffer, buffer.readerIndex() + 2, buffer.readableBytes() - 2);
} catch (CorruptedWebSocketFrameException ex) {
protocolViolation(ctx, buffer, ex);
}
}

// Restore reader index
buffer.readerIndex(idx);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
*/
package io.netty.handler.codec.http.websocketx;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.CharsetUtil;
import org.assertj.core.api.ThrowableAssert;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -68,6 +71,15 @@ void testValidCode() {
doTestValidCode(new CloseWebSocketFrame(true, 0, 1000, "valid code"), 1000, "valid code");
}

@Test
void testNonZeroReaderIndex() {
ByteBuf buffer = Unpooled.buffer().writeZero(1);
buffer.writeShort(WebSocketCloseStatus.NORMAL_CLOSURE.code())
.writeCharSequence(WebSocketCloseStatus.NORMAL_CLOSURE.reasonText(), CharsetUtil.US_ASCII);
doTestValidCode(new CloseWebSocketFrame(true, 0, buffer.skipBytes(1)),
WebSocketCloseStatus.NORMAL_CLOSURE.code(), WebSocketCloseStatus.NORMAL_CLOSURE.reasonText());
}

private static void doTestInvalidCode(ThrowableAssert.ThrowingCallable callable) {
assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(callable);
}
Expand Down

0 comments on commit 3cf83d5

Please sign in to comment.