Skip to content

Commit

Permalink
Merge pull request square#5867 from square/jwilson.0314.more_compress…
Browse files Browse the repository at this point in the history
…ion_

Hook up compression in WebSocketReader and WebSocketWriter
  • Loading branch information
swankjesse authored Mar 15, 2020
2 parents 97a5e7a + 8703126 commit 84a7d0d
Show file tree
Hide file tree
Showing 10 changed files with 336 additions and 70 deletions.
3 changes: 2 additions & 1 deletion okhttp/src/main/java/okhttp3/WebSocket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ interface WebSocket {

/**
* Returns the size in bytes of all messages enqueued to be transmitted to the server. This
* doesn't include framing overhead. It also doesn't include any bytes buffered by the operating
* doesn't include framing overhead. If compression is enabled, uncompressed messages size
* is used to calculate this value. It also doesn't include any bytes buffered by the operating
* system or network intermediaries. This method returns 0 if no messages are waiting in the
* queue. If may return a nonzero value after the web socket has been canceled; this indicates
* that enqueued messages were not transmitted.
Expand Down
39 changes: 37 additions & 2 deletions okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,13 @@ class RealWebSocket(
synchronized(this) {
this.name = name
this.streams = streams
this.writer = WebSocketWriter(streams.client, streams.sink, random)
this.writer = WebSocketWriter(
isClient = streams.client,
sink = streams.sink,
random = random,
messageDeflater = null,
minimumDeflateSize = Long.MAX_VALUE
)
this.writerTask = WriterTask()
if (pingIntervalMillis != 0L) {
val pingIntervalNanos = MILLISECONDS.toNanos(pingIntervalMillis)
Expand All @@ -229,7 +235,12 @@ class RealWebSocket(
}
}

reader = WebSocketReader(streams.client, streams.source, this)
reader = WebSocketReader(
isClient = streams.client,
source = streams.source,
frameCallback = this,
messageInflater = null
)
}

/** Receive frames until there are no more. Invoked only by the reader thread. */
Expand Down Expand Up @@ -304,13 +315,19 @@ class RealWebSocket(
require(code != -1)

var toClose: Streams? = null
var readerToClose: WebSocketReader? = null
var writerToClose: WebSocketWriter? = null
synchronized(this) {
check(receivedCloseCode == -1) { "already closed" }
receivedCloseCode = code
receivedCloseReason = reason
if (enqueuedClose && messageAndCloseQueue.isEmpty()) {
toClose = this.streams
this.streams = null
readerToClose = this.reader
this.reader = null
writerToClose = this.writer
this.writer = null
this.taskQueue.shutdown()
}
}
Expand All @@ -323,6 +340,8 @@ class RealWebSocket(
}
} finally {
toClose?.closeQuietly()
readerToClose?.closeQuietly()
writerToClose?.closeQuietly()
}
}

Expand Down Expand Up @@ -422,6 +441,8 @@ class RealWebSocket(
var receivedCloseCode = -1
var receivedCloseReason: String? = null
var streamsToClose: Streams? = null
var readerToClose: WebSocketReader? = null
var writerToClose: WebSocketWriter? = null

synchronized(this@RealWebSocket) {
if (failed) {
Expand All @@ -438,6 +459,10 @@ class RealWebSocket(
if (receivedCloseCode != -1) {
streamsToClose = this.streams
this.streams = null
readerToClose = this.reader
this.reader = null
writerToClose = this.writer
this.writer = null
this.taskQueue.shutdown()
} else {
// When we request a graceful close also schedule a cancel of the web socket.
Expand Down Expand Up @@ -476,6 +501,8 @@ class RealWebSocket(
return true
} finally {
streamsToClose?.closeQuietly()
readerToClose?.closeQuietly()
writerToClose?.closeQuietly()
}
}

Expand Down Expand Up @@ -505,18 +532,26 @@ class RealWebSocket(

fun failWebSocket(e: Exception, response: Response?) {
val streamsToClose: Streams?
val readerToClose: WebSocketReader?
val writerToClose: WebSocketWriter?
synchronized(this) {
if (failed) return // Already failed.
failed = true
streamsToClose = this.streams
this.streams = null
readerToClose = this.reader
this.reader = null
writerToClose = this.writer
this.writer = null
taskQueue.shutdown()
}

try {
listener.onFailure(this, e, response)
} finally {
streamsToClose?.closeQuietly()
readerToClose?.closeQuietly()
writerToClose?.closeQuietly()
}
}

Expand Down
10 changes: 5 additions & 5 deletions okhttp/src/main/java/okhttp3/internal/ws/WebSocketExtensions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
package okhttp3.internal.ws

import java.io.IOException
import okhttp3.Response
import okhttp3.Headers
import okhttp3.internal.delimiterOffset
import okhttp3.internal.trimSubstring

Expand Down Expand Up @@ -82,7 +82,7 @@ data class WebSocketExtensions(
private const val HEADER_WEB_SOCKET_EXTENSION = "Sec-WebSocket-Extensions"

@Throws(IOException::class)
fun parse(response: Response): WebSocketExtensions {
fun parse(responseHeaders: Headers): WebSocketExtensions {
// Note that this code does case-insensitive comparisons, even though the spec doesn't specify
// whether extension tokens and parameters are case-insensitive or not.

Expand All @@ -94,11 +94,11 @@ data class WebSocketExtensions(
var unexpectedValues = false

// Parse each header.
for (i in 0 until response.headers.size) {
if (!response.headers.name(i).equals(HEADER_WEB_SOCKET_EXTENSION, ignoreCase = true)) {
for (i in 0 until responseHeaders.size) {
if (!responseHeaders.name(i).equals(HEADER_WEB_SOCKET_EXTENSION, ignoreCase = true)) {
continue // Not a header we're interested in.
}
val header = response.headers.value(i)
val header = responseHeaders.value(i)

// Parse each extension.
var pos = 0
Expand Down
42 changes: 33 additions & 9 deletions okhttp/src/main/java/okhttp3/internal/ws/WebSocketReader.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package okhttp3.internal.ws

import java.io.Closeable
import java.io.IOException
import java.net.ProtocolException
import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -50,19 +51,20 @@ import okio.ByteString
*
* [rfc_6455]: http://tools.ietf.org/html/rfc6455
*/
internal class WebSocketReader(
class WebSocketReader(
private val isClient: Boolean,
val source: BufferedSource,
private val frameCallback: FrameCallback
) {

var closed = false
private val frameCallback: FrameCallback,
private val messageInflater: MessageInflater?
) : Closeable {
private var closed = false

// Stateful data about the current frame.
private var opcode = 0
private var frameLength = 0L
private var isFinalFrame = false
private var isControlFrame = false
private var readingCompressedMessage = false

private val controlFrameBuffer = Buffer()
private val messageFrameBuffer = Buffer()
Expand Down Expand Up @@ -125,12 +127,25 @@ internal class WebSocketReader(
}

val reservedFlag1 = b0 and B0_FLAG_RSV1 != 0
when (opcode) {
OPCODE_TEXT, OPCODE_BINARY -> {
if (reservedFlag1) {
if (messageInflater == null) throw ProtocolException("Unexpected rsv1 flag")
readingCompressedMessage = true
} else {
readingCompressedMessage = false
}
}
else -> {
if (reservedFlag1) throw ProtocolException("Unexpected rsv1 flag")
}
}

val reservedFlag2 = b0 and B0_FLAG_RSV2 != 0
if (reservedFlag2) throw ProtocolException("Unexpected rsv2 flag")

val reservedFlag3 = b0 and B0_FLAG_RSV3 != 0
if (reservedFlag1 || reservedFlag2 || reservedFlag3) {
// Reserved flags are for extensions which we currently do not support.
throw ProtocolException("Reserved flags are unsupported.")
}
if (reservedFlag3) throw ProtocolException("Unexpected rsv3 flag")

val b1 = source.readByte() and 0xff

Expand Down Expand Up @@ -216,6 +231,10 @@ internal class WebSocketReader(

readMessage()

if (readingCompressedMessage) {
messageInflater!!.inflate(messageFrameBuffer)
}

if (opcode == OPCODE_TEXT) {
frameCallback.onReadMessage(messageFrameBuffer.readUtf8())
} else {
Expand Down Expand Up @@ -264,4 +283,9 @@ internal class WebSocketReader(
}
}
}

@Throws(IOException::class)
override fun close() {
messageInflater?.close()
}
}
48 changes: 31 additions & 17 deletions okhttp/src/main/java/okhttp3/internal/ws/WebSocketWriter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
*/
package okhttp3.internal.ws

import java.io.Closeable
import java.io.IOException
import java.util.Random
import okhttp3.internal.ws.WebSocketProtocol.B0_FLAG_FIN
import okhttp3.internal.ws.WebSocketProtocol.B0_FLAG_RSV1
import okhttp3.internal.ws.WebSocketProtocol.B1_FLAG_MASK
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_CONTROL_CLOSE
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PING
Expand All @@ -39,11 +41,15 @@ import okio.ByteString
*
* [rfc_6455]: http://tools.ietf.org/html/rfc6455
*/
internal class WebSocketWriter(
class WebSocketWriter(
private val isClient: Boolean,
val sink: BufferedSink,
val random: Random
) {
val random: Random,
private val messageDeflater: MessageDeflater?,
private val minimumDeflateSize: Long
) : Closeable {
/** This buffer is holds outbound data for compression and masking. */
private val messageBuffer = Buffer()

/** The [Buffer] of [sink]. Write to this and then flush/emit [sink]. */
private val sinkBuffer: Buffer = sink.buffer
Expand Down Expand Up @@ -136,47 +142,55 @@ internal class WebSocketWriter(
fun writeMessageFrame(formatOpcode: Int, data: ByteString) {
if (writerClosed) throw IOException("closed")

val b0 = formatOpcode or B0_FLAG_FIN
messageBuffer.write(data)

var b0 = formatOpcode or B0_FLAG_FIN
val messageDeflater = this.messageDeflater
if (messageDeflater != null && data.size >= minimumDeflateSize) {
messageDeflater.deflate(messageBuffer)
b0 = b0 or B0_FLAG_RSV1
}
val dataSize = messageBuffer.size
sinkBuffer.writeByte(b0)

var b1 = 0
if (isClient) {
b1 = b1 or B1_FLAG_MASK
}
when {
data.size <= PAYLOAD_BYTE_MAX -> {
b1 = b1 or data.size
dataSize <= PAYLOAD_BYTE_MAX -> {
b1 = b1 or dataSize.toInt()
sinkBuffer.writeByte(b1)
}
data.size <= PAYLOAD_SHORT_MAX -> {
dataSize <= PAYLOAD_SHORT_MAX -> {
b1 = b1 or PAYLOAD_SHORT
sinkBuffer.writeByte(b1)
sinkBuffer.writeShort(data.size)
sinkBuffer.writeShort(dataSize.toInt())
}
else -> {
b1 = b1 or PAYLOAD_LONG
sinkBuffer.writeByte(b1)
sinkBuffer.writeLong(data.size.toLong())
sinkBuffer.writeLong(dataSize)
}
}

if (isClient) {
random.nextBytes(maskKey!!)
sinkBuffer.write(maskKey)

if (data.size > 0L) {
val bufferStart = sinkBuffer.size
sinkBuffer.write(data)

sinkBuffer.readAndWriteUnsafe(maskCursor!!)
maskCursor.seek(bufferStart)
if (dataSize > 0L) {
messageBuffer.readAndWriteUnsafe(maskCursor!!)
maskCursor.seek(0L)
toggleMask(maskCursor, maskKey)
maskCursor.close()
}
} else {
sinkBuffer.write(data)
}

sinkBuffer.write(messageBuffer, dataSize)
sink.emit()
}

override fun close() {
messageDeflater?.close()
}
}
Loading

0 comments on commit 84a7d0d

Please sign in to comment.