Skip to content

Commit

Permalink
Fix: CORS Headers implementation (zio#1409)
Browse files Browse the repository at this point in the history
* it looks like the CombinedHttpHeaders implementation from netty is  broken. Switching back to DefaultHttpHeaders

* fixed CORS headers implementation

* use DefaultHttpHeaders instead as CombineHeaders implementation from Netty is creating more issues than it solves

* merge of same headers into one, according to RFC

* fixed for scala 2.12

* cleanup after merge

* formatted
  • Loading branch information
gciuloaica authored Sep 12, 2022
1 parent e51ded1 commit d9d18bb
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 22 deletions.
17 changes: 12 additions & 5 deletions zio-http/src/main/scala/zio/http/Headers.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package zio.http

import io.netty.handler.codec.http.{CombinedHttpHeaders, HttpHeaders}
import io.netty.handler.codec.http.{DefaultHttpHeaders, HttpHeaders}
import zio.Chunk
import zio.http.headers.{HeaderConstructors, HeaderExtension}

Expand Down Expand Up @@ -37,11 +37,18 @@ final case class Headers(toChunk: Chunk[Header]) extends HeaderExtension[Headers
/**
* Converts a Headers to [io.netty.handler.codec.http.HttpHeaders]
*/
private[zio] def encode: HttpHeaders =
self.toList
.foldLeft[HttpHeaders](new CombinedHttpHeaders(true)) { case (headers, entry) =>
private[http] def encode: HttpHeaders = {
val (exceptions, regularHeaders) = self.toList.span(h => h._1.contains(HeaderNames.setCookie))
val combinedHeaders = regularHeaders
.groupBy(_._1)
.map { case (key, tuples) =>
key -> tuples.map(_._2).map(value => if (value.contains(",")) s"""\"$value\"""" else value).mkString(",")
}
(exceptions ++ combinedHeaders)
.foldLeft[HttpHeaders](new DefaultHttpHeaders(true)) { case (headers, entry) =>
headers.add(entry._1, entry._2)
}
}

}

Expand Down Expand Up @@ -69,6 +76,6 @@ object Headers extends HeaderConstructors {

def when(cond: Boolean)(headers: => Headers): Headers = if (cond) headers else Headers.empty

private[zio] def decode(headers: HttpHeaders): Headers =
private[http] def decode(headers: HttpHeaders): Headers =
Headers(headers.entries().asScala.toList.map(entry => (entry.getKey, entry.getValue)))
}
22 changes: 13 additions & 9 deletions zio-http/src/main/scala/zio/http/middleware/Cors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package zio.http.middleware

import io.netty.handler.codec.http.HttpHeaderNames
import zio.http._
import zio.http.middleware.Cors.CorsConfig
import zio.http.middleware.Cors.{CorsConfig, buildHeaders}

private[zio] trait Cors {

Expand All @@ -24,17 +24,13 @@ private[zio] trait Cors {
}
def corsHeaders(origin: Header, method: Method, isPreflight: Boolean): Headers = {
Headers.ifThenElse(isPreflight)(
onTrue = config.allowedHeaders.fold(Headers.empty) { h =>
Headers(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS.toString(), h.mkString(","))
},
onFalse = config.exposedHeaders.fold(Headers.empty) { h =>
Headers(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS.toString(), h.mkString(","))
},
onTrue = buildHeaders(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS.toString(), config.allowedHeaders),
onFalse = buildHeaders(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS.toString(), config.exposedHeaders),
) ++
Headers(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), origin._2) ++
Headers(
buildHeaders(
HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS.toString(),
config.allowedMethods.fold(method.toString())(m => m.map(m => m.toString()).mkString(",")),
config.allowedMethods.map(_.map(_.toJava.name())),
) ++
Headers.when(config.allowCredentials) {
Headers(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, config.allowCredentials.toString)
Expand Down Expand Up @@ -73,4 +69,12 @@ object Cors {
),
exposedHeaders: Option[Set[String]] = Some(Set("*")),
)

private def buildHeaders(headerName: String, values: Option[Set[String]]): Headers = {
values match {
case Some(headerValues) =>
Headers(headerValues.toList.map(value => headerName -> value))
case None => Headers.empty
}
}
}
18 changes: 18 additions & 0 deletions zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,22 @@ abstract class HttpRunnableSpec extends ZIOSpecDefault { self =>
.map(_.status)
} yield status
}

def headers(
method: Method = Method.GET,
path: Path,
headers: Headers = Headers.empty,
): ZIO[EventLoopGroup with ChannelFactory with DynamicServer, Throwable, Headers] = {
for {
port <- DynamicServer.port
headers <- Client
.request(
"http://localhost:%d/%s".format(port, path),
method,
ssl = ClientSSLOptions.DefaultSSL,
headers = headers,
)
.map(_.headers)
} yield headers
}
}
17 changes: 11 additions & 6 deletions zio-http/src/test/scala/zio/http/middleware/CorsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import zio.test.Assertion.hasSubset
import zio.test._

object CorsSpec extends ZIOSpecDefault with HttpAppTestExtensions {
val app = Http.ok @@ cors()
val app = Http.ok @@ cors(CorsConfig(allowedMethods = Some(Set(Method.GET))))

override def spec = suite("CorsMiddlewares")(
test("OPTIONS request") {
val request = Request(
Expand All @@ -17,19 +18,23 @@ object CorsSpec extends ZIOSpecDefault with HttpAppTestExtensions {
headers = Headers.accessControlRequestMethod(Method.GET) ++ Headers.origin("test-env"),
)

val expected = Headers
val initialHeaders = Headers
.accessControlAllowCredentials(true)
.withAccessControlAllowMethods(Method.GET)
.withAccessControlAllowOrigin("test-env")
.withAccessControlAllowHeaders(
CorsConfig().allowedHeaders.getOrElse(Set.empty).mkString(","),
)
.toList

val expected = CorsConfig().allowedHeaders
.fold(Headers.empty) { h =>
h
.map(value => Headers.empty.withAccessControlAllowHeaders(value))
.fold(initialHeaders)(_ ++ _)
}
.toList
for {
res <- app(request)
} yield assert(res.headersAsList)(hasSubset(expected)) &&
assertTrue(res.status == Status.NoContent)

},
test("GET request") {
val request =
Expand Down
7 changes: 7 additions & 0 deletions zio-http/src/test/scala/zio/http/service/ServerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,13 @@ object ServerSpec extends HttpRunnableSpec {
actual <- Http.response(res).withServer(server).deploy.headerValue(HeaderNames.server).run()
} yield assert(actual)(isSome(equalTo(server)))
},
test("multiple headers of same type with different values") {
val expectedValue = "test1,test2"
for {
res <- Response.text("abc").withVary("test1").withVary("test2").freeze
actual <- Http.response(res).deploy.headerValue(HeaderNames.vary).run()
} yield assert(actual)(isSome(equalTo(expectedValue)))
},
),
)

Expand Down
29 changes: 27 additions & 2 deletions zio-http/src/test/scala/zio/http/service/StaticServerSpec.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package zio.http.service

import zio.http.Middleware.cors
import zio.http._
import zio.http.internal.{DynamicServer, HttpGen, HttpRunnableSpec}
import zio.http.middleware.Cors.CorsConfig
import zio.test.Assertion.{equalTo, not}
import zio.test.TestAspect.timeout
import zio.test.{Gen, TestEnvironment, assertTrue, assertZIO, checkAll}
Expand All @@ -26,7 +28,11 @@ object StaticServerSpec extends HttpRunnableSpec {
case _ -> !! / "throwable" => throw new Exception("Throw inside Handler")
}

private val app = serve { nonZIO ++ staticApp }
private val staticAppWithCors = Http.collectZIO[Request] { case Method.GET -> !! / "success-cors" =>
ZIO.succeed(Response.ok.withVary("test1").withVary("test2"))
} @@ cors(CorsConfig(allowedMethods = Some(Set(Method.GET, Method.POST))))

private val app = serve { nonZIO ++ staticApp ++ staticAppWithCors }

private val methodGenWithoutHEAD: Gen[Any, Method] = Gen.fromIterable(
List(
Expand Down Expand Up @@ -84,7 +90,7 @@ object StaticServerSpec extends HttpRunnableSpec {
suite("Server") {
app
.as(
List(serverStartSpec, staticAppSpec, nonZIOSpec, throwableAppSpec),
List(serverStartSpec, staticAppSpec, nonZIOSpec, throwableAppSpec, multiHeadersSpec),
)
}.provideSomeLayerShared[TestEnvironment](env) @@ timeout(30 seconds)

Expand Down Expand Up @@ -122,4 +128,23 @@ object StaticServerSpec extends HttpRunnableSpec {
} yield assertTrue(status == Status.InternalServerError)
}
}

def multiHeadersSpec = suite("Multi headers spec")(
test("Multiple headers should have the value combined in a single header") {
for {
result <- headers(Method.GET, !! / "success-cors")
} yield {
assertTrue(result.hasHeader(HeaderNames.vary)) &&
assertTrue(result.vary.contains("test1,test2"))
}
},
test("CORS headers should be properly encoded") {
for {
result <- headers(Method.GET, !! / "success-cors", Headers.origin("example.com"))
} yield {
assertTrue(result.hasHeader(HeaderNames.accessControlAllowMethods)) &&
assertTrue(result.accessControlAllowMethods.contains("GET,POST"))
}
},
)
}

0 comments on commit d9d18bb

Please sign in to comment.