Skip to content

Commit

Permalink
WebSocket Improvements (zio#2358)
Browse files Browse the repository at this point in the history
* cleanup

* propagate traces
  • Loading branch information
adamgfraser authored Jul 30, 2023
1 parent 1b09b78 commit 538761f
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 27 deletions.
14 changes: 8 additions & 6 deletions zio-http-testkit/src/main/scala/zio/http/TestChannel.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package zio.http

import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.http.ChannelEvent.{Unregistered, UserEvent, UserEventTriggered}

Expand All @@ -8,15 +10,15 @@ case class TestChannel(
out: Queue[WebSocketChannelEvent],
promise: Promise[Nothing, Unit],
) extends WebSocketChannel {
def awaitShutdown: UIO[Unit] =
def awaitShutdown(implicit trace: Trace): UIO[Unit] =
promise.await
def receive: Task[WebSocketChannelEvent] =
def receive(implicit trace: Trace): Task[WebSocketChannelEvent] =
in.take
def send(in: WebSocketChannelEvent): Task[Unit] =
def send(in: WebSocketChannelEvent)(implicit trace: Trace): Task[Unit] =
out.offer(in).unit
def sendAll(in: Iterable[WebSocketChannelEvent]): Task[Unit] =
def sendAll(in: Iterable[WebSocketChannelEvent])(implicit trace: Trace): Task[Unit] =
out.offerAll(in).unit
def shutdown: UIO[Unit] =
def shutdown(implicit trace: Trace): UIO[Unit] =
in.offer(ChannelEvent.Unregistered) *>
out.offer(ChannelEvent.Unregistered) *>
promise.succeed(()).unit
Expand All @@ -27,7 +29,7 @@ object TestChannel {
in: Queue[WebSocketChannelEvent],
out: Queue[WebSocketChannelEvent],
promise: Promise[Nothing, Unit],
): ZIO[Any, Nothing, TestChannel] =
)(implicit trace: Trace): ZIO[Any, Nothing, TestChannel] =
for {
_ <- out.offer(UserEventTriggered(UserEvent.HandshakeComplete))
} yield TestChannel(in, out, promise)
Expand Down
30 changes: 15 additions & 15 deletions zio-http/src/main/scala/zio/http/Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,44 +28,44 @@ trait Channel[-In, +Out] { self =>
/**
* Await shutdown of the channel.
*/
def awaitShutdown: UIO[Unit]
def awaitShutdown(implicit trace: Trace): UIO[Unit]

/**
* Read a message from the channel, suspending until the next message is
* available.
*/
def receive: Task[Out]
def receive(implicit trace: Trace): Task[Out]

/**
* Send a message to the channel.
*/
def send(in: In): Task[Unit]
def send(in: In)(implicit trace: Trace): Task[Unit]

/**
* Send all messages to the channel.
*/
def sendAll(in: Iterable[In]): Task[Unit]
def sendAll(in: Iterable[In])(implicit trace: Trace): Task[Unit]

/**
* Shut down the channel.
*/
def shutdown: UIO[Unit]
def shutdown(implicit trace: Trace): UIO[Unit]

/**
* Constructs a new channel that automatically transforms messages sent to
* this channel using the specified function.
*/
final def contramap[In2](f: In2 => In): Channel[In2, Out] =
new Channel[In2, Out] {
def awaitShutdown: UIO[Unit] =
def awaitShutdown(implicit trace: Trace): UIO[Unit] =
self.awaitShutdown
def receive: Task[Out] =
def receive(implicit trace: Trace): Task[Out] =
self.receive
def send(in: In2): Task[Unit] =
def send(in: In2)(implicit trace: Trace): Task[Unit] =
self.send(f(in))
def sendAll(in: Iterable[In2]): Task[Unit] =
def sendAll(in: Iterable[In2])(implicit trace: Trace): Task[Unit] =
self.sendAll(in.map(f))
def shutdown: UIO[Unit] =
def shutdown(implicit trace: Trace): UIO[Unit] =
self.shutdown
}

Expand All @@ -75,15 +75,15 @@ trait Channel[-In, +Out] { self =>
*/
final def map[Out2](f: Out => Out2)(implicit trace: Trace): Channel[In, Out2] =
new Channel[In, Out2] {
def awaitShutdown: UIO[Unit] =
def awaitShutdown(implicit trace: Trace): UIO[Unit] =
self.awaitShutdown
def receive: Task[Out2] =
def receive(implicit trace: Trace): Task[Out2] =
self.receive.map(f)
def send(in: In): Task[Unit] =
def send(in: In)(implicit trace: Trace): Task[Unit] =
self.send(in)
def sendAll(in: Iterable[In]): Task[Unit] =
def sendAll(in: Iterable[In])(implicit trace: Trace): Task[Unit] =
self.sendAll(in)
def shutdown: UIO[Unit] =
def shutdown(implicit trace: Trace): UIO[Unit] =
self.shutdown
}

Expand Down
12 changes: 6 additions & 6 deletions zio-http/src/main/scala/zio/http/WebSocketChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package zio.http

import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.http.ChannelEvent.{ExceptionCaught, Read, Registered, Unregistered, UserEventTriggered}
import zio.http.netty.NettyChannel
Expand All @@ -31,31 +32,30 @@ private[http] object WebSocketChannel {
queue: Queue[WebSocketChannelEvent],
): WebSocketChannel =
new WebSocketChannel {
def awaitShutdown: UIO[Unit] =
def awaitShutdown(implicit trace: Trace): UIO[Unit] =
nettyChannel.awaitClose

def receive: Task[WebSocketChannelEvent] =
def receive(implicit trace: Trace): Task[WebSocketChannelEvent] =
queue.take

def send(in: WebSocketChannelEvent): Task[Unit] = {
def send(in: WebSocketChannelEvent)(implicit trace: Trace): Task[Unit] = {
in match {
case Read(message) => nettyChannel.writeAndFlush(frameToNetty(message))
case _ => ZIO.unit
}
}

def sendAll(in: Iterable[WebSocketChannelEvent]): Task[Unit] =
def sendAll(in: Iterable[WebSocketChannelEvent])(implicit trace: Trace): Task[Unit] =
ZIO.suspendSucceed {
val iterator = in.iterator.collect { case Read(message) => message }

println(s"sendAll")
ZIO.whileLoop(iterator.hasNext) {
val message = iterator.next()
if (iterator.hasNext) nettyChannel.write(frameToNetty(message))
else nettyChannel.writeAndFlush(frameToNetty(message))
}(_ => ())
}
def shutdown: UIO[Unit] =
def shutdown(implicit trace: Trace): UIO[Unit] =
nettyChannel.close(false).orDie
}

Expand Down
1 change: 1 addition & 0 deletions zio-http/src/main/scala/zio/http/WebSocketConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package zio.http

import zio.Duration
import zio.stacktracer.TracingImplicits.disableAutoTrace

/**
* Server side websocket configuration
Expand Down
1 change: 1 addition & 0 deletions zio-http/src/main/scala/zio/http/WebSocketFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package zio.http

import zio.Chunk
import zio.stacktracer.TracingImplicits.disableAutoTrace

sealed trait WebSocketFrame extends Product with Serializable { self =>
def isFinal: Boolean = true
Expand Down

0 comments on commit 538761f

Please sign in to comment.