Skip to content

Commit

Permalink
Refactor: Use builder pattern for SocketDecoder and SocketProtocol (z…
Browse files Browse the repository at this point in the history
…io#1367)

* use builder pattern for SockerDecoder and SocketProtocol

* fixed naming of allow extensions flag

* fixed example

* migrated to builder pattern

* updated SocketDecoder implementation to reduce the surface areas of the API.

* formatted

* updated SocketDecoder and SockerProtocol API

* refactored SocketProtocol and SocketDecoder according to review

* refactor: reorder methods

Co-authored-by: Tushar Mathur <[email protected]>
  • Loading branch information
gciuloaica and tusharmath authored Aug 10, 2022
1 parent bf8d7a0 commit 4743454
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 136 deletions.
6 changes: 3 additions & 3 deletions example/src/main/scala/example/WebSocketAdvanced.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ object WebSocketAdvanced extends ZIOAppDefault {
val httpSocket: Http[Any, Throwable, WebSocketChannelEvent, Unit] =
messageSocket ++ channelSocket

val protocol = SocketProtocol.subProtocol("json") // Setup protocol settings
val protocol = SocketProtocol.default.withSubProtocol(Some("json")) // Setup protocol settings

val decoder = SocketDecoder.allowExtensions // Setup decoder settings
val decoder = SocketDecoder.default.withExtensions(allowed = true) // Setup decoder settings

val socketApp: SocketApp[Any] = // Combine all channel handlers together
httpSocket.toSocketApp
Expand All @@ -63,5 +63,5 @@ object WebSocketAdvanced extends ZIOAppDefault {
case Method.GET -> !! / "subscriptions" => socketApp.toResponse
}

override val run = Server.start(8090, app)
override val run = Server.start(8091, app)
}
4 changes: 2 additions & 2 deletions zio-http/src/main/scala/zhttp/socket/SocketApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ final case class SocketApp[-R](
* Frame decoder configuration
*/
def withDecoder(decoder: SocketDecoder): SocketApp[R] =
copy(decoder = self.decoder ++ decoder)
copy(decoder = decoder, protocol = protocol.withDecoderConfig(decoder))

/**
* Server side websocket configuration
*/
def withProtocol(protocol: SocketProtocol): SocketApp[R] =
copy(protocol = self.protocol ++ protocol)
copy(protocol = protocol)
}

object SocketApp {
Expand Down
80 changes: 32 additions & 48 deletions zio-http/src/main/scala/zhttp/socket/SocketDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,77 +5,61 @@ import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig
/**
* Frame decoder configuration
*/
sealed trait SocketDecoder { self =>
import SocketDecoder._
def ++(other: SocketDecoder): SocketDecoder = SocketDecoder.Concat(self, other)
def javaConfig: WebSocketDecoderConfig = {
val b = WebSocketDecoderConfig.newBuilder()
def loop(decoder: SocketDecoder): Unit = {
decoder match {
case Default => ()
case MaxFramePayloadLength(length) => b.maxFramePayloadLength(length)
case RejectMaskedFrames => b.expectMaskedFrames(false)
case AllowMaskMismatch => b.allowMaskMismatch(true)
case AllowExtensions => b.allowExtensions(true)
case AllowProtocolViolation => b.closeOnProtocolViolation(false)
case SkipUTF8Validation => b.withUTF8Validator(false)
case Concat(a, b) =>
loop(a)
loop(b)
}
()
}
loop(self)
b.build()
}
}
final case class SocketDecoder(
maxFramePayloadLength: Int = 65536,
expectMaskedFrames: Boolean = true,
allowMaskMismatch: Boolean = false,
allowExtensions: Boolean = false,
closeOnProtocolViolation: Boolean = true,
withUTF8Validator: Boolean = true,
) { self =>

object SocketDecoder {
private final case class MaxFramePayloadLength(length: Int) extends SocketDecoder
private case object RejectMaskedFrames extends SocketDecoder
private case object AllowMaskMismatch extends SocketDecoder
private case object AllowExtensions extends SocketDecoder
private case object AllowProtocolViolation extends SocketDecoder
private case object SkipUTF8Validation extends SocketDecoder
private final case class Concat(a: SocketDecoder, b: SocketDecoder) extends SocketDecoder
private case object Default extends SocketDecoder
def javaConfig[zhttp]: WebSocketDecoderConfig = WebSocketDecoderConfig
.newBuilder()
.maxFramePayloadLength(maxFramePayloadLength)
.expectMaskedFrames(expectMaskedFrames)
.allowMaskMismatch(allowMaskMismatch)
.allowExtensions(allowExtensions)
.closeOnProtocolViolation(closeOnProtocolViolation)
.withUTF8Validator(withUTF8Validator)
.build()

/**
* Sets Maximum length of a frame's payload. Setting this to an appropriate
* value for you application helps check for denial of services attacks.
*/
def maxFramePayloadLength(length: Int): SocketDecoder = MaxFramePayloadLength(length)
def withExtensions(allowed: Boolean): SocketDecoder = self.copy(allowExtensions = allowed)

/**
* Web socket servers must set this to true to reject incoming masked payload.
* When set to true, frames which are not masked properly according to the
* standard will still be accepted.
*/
def rejectMaskedFrames: SocketDecoder = RejectMaskedFrames
def withMaskMismatch(allowed: Boolean): SocketDecoder = self.copy(allowMaskMismatch = allowed)

/**
* When set to true, frames which are not masked properly according to the
* standard will still be accepted.
* Web socket servers must set this to true to reject incoming masked payload.
*/
def allowMaskMismatch: SocketDecoder = AllowMaskMismatch
def withMaskedFrames(allowed: Boolean): SocketDecoder = self.copy(expectMaskedFrames = allowed)

/**
* Allow extensions to be used in the reserved bits of the web socket frame
* Sets Maximum length of a frame's payload. Setting this to an appropriate
* value for you application helps check for denial of services attacks.
*/
def allowExtensions: SocketDecoder = AllowExtensions
def withMaxFramePayloadLength(length: Int): SocketDecoder = self.copy(maxFramePayloadLength = length)

/**
* Flag to not send close frame immediately on any protocol violation.ion.
*/
def allowProtocolViolation: SocketDecoder = AllowProtocolViolation
def withProtocolViolation(allowed: Boolean): SocketDecoder = self.copy(closeOnProtocolViolation = allowed)

/**
* Allows you to avoid adding of Utf8FrameValidator to the pipeline on the
* WebSocketServerProtocolHandler creation. This is useful (less overhead)
* when you use only BinaryWebSocketFrame within your web socket connection.
*/
def skipUTF8Validation: SocketDecoder = SkipUTF8Validation
def withUTF8Validation(enable: Boolean): SocketDecoder = self.copy(withUTF8Validator = enable)
}

object SocketDecoder {

/**
* Creates an default decoder configuration.
*/
def default: SocketDecoder = Default
def default: SocketDecoder = SocketDecoder()
}
129 changes: 46 additions & 83 deletions zio-http/src/main/scala/zhttp/socket/SocketProtocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,118 +10,81 @@ import zio.Duration
/**
* Server side websocket configuration
*/
sealed trait SocketProtocol { self =>
import SocketProtocol._

def ++(other: SocketProtocol): SocketProtocol = SocketProtocol.Concat(self, other)

def clientBuilder: WebSocketClientProtocolConfig.Builder = {
val b = WebSocketClientProtocolConfig.newBuilder()
def loop(protocol: SocketProtocol): Unit = {
protocol match {
case Default => ()
case SubProtocol(name) => b.subprotocol(name)
case HandshakeTimeoutMillis(duration) => b.handshakeTimeoutMillis(duration.toMillis)
case ForceCloseTimeoutMillis(duration) => b.forceCloseTimeoutMillis(duration.toMillis)
case ForwardCloseFrames => b.handleCloseFrames(false)
case SendCloseFrame(status) => b.sendCloseFrame(status.asJava)
case SendCloseFrameCode(code, reason) => b.sendCloseFrame(new WebSocketCloseStatus(code, reason))
case ForwardPongFrames => b.dropPongFrames(false)
case Concat(a, b) =>
loop(a)
loop(b)
}
()
}
loop(self)
b
}

def serverBuilder: WebSocketServerProtocolConfig.Builder = {
val b = WebSocketServerProtocolConfig.newBuilder().checkStartsWith(true).websocketPath("")
def loop(protocol: SocketProtocol): Unit = {
protocol match {
case Default => ()
case SubProtocol(name) => b.subprotocols(name)
case HandshakeTimeoutMillis(duration) => b.handshakeTimeoutMillis(duration.toMillis)
case ForceCloseTimeoutMillis(duration) => b.forceCloseTimeoutMillis(duration.toMillis)
case ForwardCloseFrames => b.handleCloseFrames(false)
case SendCloseFrame(status) => b.sendCloseFrame(status.asJava)
case SendCloseFrameCode(code, reason) => b.sendCloseFrame(new WebSocketCloseStatus(code, reason))
case ForwardPongFrames => b.dropPongFrames(false)
case Concat(a, b) =>
loop(a)
loop(b)
}
()
}
loop(self)
b
}

}

object SocketProtocol {
final case class SocketProtocol(
subprotocols: Option[String] = None,
handshakeTimeoutMillis: Long = 10000L,
forceCloseTimeoutMillis: Long = -1L,
handleCloseFrames: Boolean = true,
sendCloseFrame: WebSocketCloseStatus = WebSocketCloseStatus.NORMAL_CLOSURE,
dropPongFrames: Boolean = true,
decoderConfig: SocketDecoder = SocketDecoder.default,
) { self =>

def clientBuilder: WebSocketClientProtocolConfig.Builder = WebSocketClientProtocolConfig
.newBuilder()
.subprotocol(subprotocols.orNull)
.handshakeTimeoutMillis(handshakeTimeoutMillis)
.forceCloseTimeoutMillis(forceCloseTimeoutMillis)
.handleCloseFrames(handleCloseFrames)
.sendCloseFrame(sendCloseFrame)
.dropPongFrames(dropPongFrames)

def serverBuilder: WebSocketServerProtocolConfig.Builder = WebSocketServerProtocolConfig
.newBuilder()
.checkStartsWith(true)
.websocketPath("")
.subprotocols(subprotocols.orNull)
.handshakeTimeoutMillis(handshakeTimeoutMillis)
.forceCloseTimeoutMillis(forceCloseTimeoutMillis)
.handleCloseFrames(handleCloseFrames)
.sendCloseFrame(sendCloseFrame)
.dropPongFrames(dropPongFrames)
.decoderConfig(decoderConfig.javaConfig)

/**
* Close frame to send, when close frame was not send manually.
*/
def closeFrame(status: CloseStatus): SocketProtocol = SendCloseFrame(status)
def withCloseFrame(code: Int, reason: String): SocketProtocol =
self.copy(sendCloseFrame = new WebSocketCloseStatus(code, reason))

/**
* Close frame to send, when close frame was not send manually.
*/
def closeFrame(code: Int, reason: String): SocketProtocol =
SendCloseFrameCode(code, reason)
def withCloseStatus(status: CloseStatus): SocketProtocol = self.copy(sendCloseFrame = status.asJava)

/**
* Creates an default decoder configuration.
*/
def default: SocketProtocol = Default
def withDecoderConfig(socketDecoder: SocketDecoder): SocketProtocol = self.copy(decoderConfig = socketDecoder)

/**
* Close the connection if it was not closed by the client after timeout
* specified
*/
def forceCloseTimeout(duration: Duration): SocketProtocol =
ForceCloseTimeoutMillis(duration)
def withForceCloseTimeout(duration: Duration): SocketProtocol = self.copy(forceCloseTimeoutMillis = duration.toMillis)

/**
* Close frames should be forwarded
*/
def forwardCloseFrames: SocketProtocol = ForwardCloseFrames
def withForwardCloseFrames(forward: Boolean): SocketProtocol = self.copy(handleCloseFrames = forward)

/**
* If pong frames should be forwarded
* Pong frames should be forwarded
*/
def forwardPongFrames: SocketProtocol = ForwardPongFrames
def withForwardPongFrames(forward: Boolean): SocketProtocol = self.copy(dropPongFrames = !forward)

/**
* Handshake timeout in mills
*/
def handshakeTimeout(duration: Duration): SocketProtocol =
HandshakeTimeoutMillis(duration)
def withHandshakeTimeout(duration: Duration): SocketProtocol = self.copy(handshakeTimeoutMillis = duration.toMillis)

/**
* Used to specify the websocket sub-protocol
*/
def subProtocol(name: String): SocketProtocol = SubProtocol(name)

private final case class SubProtocol(name: String) extends SocketProtocol

private final case class SendCloseFrame(status: CloseStatus) extends SocketProtocol

private final case class HandshakeTimeoutMillis(duration: Duration) extends SocketProtocol

private final case class ForceCloseTimeoutMillis(duration: Duration) extends SocketProtocol

private final case class SendCloseFrameCode(code: Int, reason: String) extends SocketProtocol

private final case class Concat(a: SocketProtocol, b: SocketProtocol) extends SocketProtocol

private case object Default extends SocketProtocol
def withSubProtocol(name: Option[String]): SocketProtocol = self.copy(subprotocols = name)
}

private case object ForwardPongFrames extends SocketProtocol
object SocketProtocol {

private case object ForwardCloseFrames extends SocketProtocol
/**
* Creates an default decoder configuration.
*/
def default: SocketProtocol = SocketProtocol()
}

0 comments on commit 4743454

Please sign in to comment.