Skip to content

Commit

Permalink
Add CorsWebFilter
Browse files Browse the repository at this point in the history
This new WebFilter implementation is designed to allow initial
CORS support when using WebFlux functional API. More high-level
API may be introduced later.

Issue: SPR-15567
  • Loading branch information
sdeleuze committed Jun 21, 2017
1 parent 59e9094 commit 1e04cdf
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.springframework.web.cors.reactive;

import reactor.core.publisher.Mono;

import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.util.Assert;
import org.springframework.web.cors.*;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;


/**
* {@link WebFilter} that handles CORS preflight requests and intercepts
* CORS simple and actual requests thanks to a {@link CorsProcessor} implementation
* ({@link DefaultCorsProcessor} by default) in order to add the relevant CORS
* response headers (like {@code Access-Control-Allow-Origin}) using the provided
* {@link CorsConfigurationSource} (for example an {@link UrlBasedCorsConfigurationSource}
* instance.
*
* <p>This is an alternative to Spring WebFlux Java config CORS configuration,
* mostly useful for applications using the functional API.
*
* @author Sebastien Deleuze
* @since 5.0
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommendation</a>
*/
public class CorsWebFilter implements WebFilter {

private final CorsConfigurationSource configSource;

private final CorsProcessor processor;


/**
* Constructor accepting a {@link CorsConfigurationSource} used by the filter
* to find the {@link CorsConfiguration} to use for each incoming request.
* @see UrlBasedCorsConfigurationSource
*/
public CorsWebFilter(CorsConfigurationSource configSource) {
this(configSource, new DefaultCorsProcessor());
}

/**
* Constructor accepting a {@link CorsConfigurationSource} used by the filter
* to find the {@link CorsConfiguration} to use for each incoming request and a
* custom {@link CorsProcessor} to use to apply the matched
* {@link CorsConfiguration} for a request.
* @see UrlBasedCorsConfigurationSource
*/
public CorsWebFilter(CorsConfigurationSource configSource, CorsProcessor processor) {
Assert.notNull(configSource, "CorsConfigurationSource must not be null");
Assert.notNull(processor, "CorsProcessor must not be null");
this.configSource = configSource;
this.processor = processor;
}


@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
ServerHttpRequest request = exchange.getRequest();
if (CorsUtils.isCorsRequest(request)) {
CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(exchange);
if (corsConfiguration != null) {
boolean isValid = this.processor.process(corsConfiguration, exchange);
if (!isValid || CorsUtils.isPreFlightRequest(request)) {
return Mono.empty();
}
}
}
return chain.filter(exchange);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package org.springframework.web.cors.reactive;


import java.io.IOException;
import java.util.Arrays;

import javax.servlet.ServletException;

import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Mono;

import org.springframework.http.HttpMethod;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerWebExchange;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.server.WebFilterChain;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.springframework.http.HttpHeaders.*;

/**
* Unit tests for {@link CorsWebFilter}.
* @author Sebastien Deleuze
*/
public class CorsWebFilterTests {

private CorsWebFilter filter;

private final CorsConfiguration config = new CorsConfiguration();

@Before
public void setup() throws Exception {
config.setAllowedOrigins(Arrays.asList("http://domain1.com", "http://domain2.com"));
config.setAllowedMethods(Arrays.asList("GET", "POST"));
config.setAllowedHeaders(Arrays.asList("header1", "header2"));
config.setExposedHeaders(Arrays.asList("header3", "header4"));
config.setMaxAge(123L);
config.setAllowCredentials(false);
filter = new CorsWebFilter(r -> config);
}

@Test
public void validActualRequest() {

MockServerHttpRequest request = MockServerHttpRequest
.get("http://domain1.com/test.html")
.header(HOST, "domain1.com")
.header(ORIGIN, "http://domain2.com")
.header("header2", "foo")
.build();
MockServerWebExchange exchange = new MockServerWebExchange(request);

WebFilterChain filterChain = (filterExchange) -> {
try {
assertEquals("http://domain2.com", filterExchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("header3, header4", filterExchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_EXPOSE_HEADERS));
} catch (AssertionError ex) {
return Mono.error(ex);
}
return Mono.empty();

};
filter.filter(exchange, filterChain);
}

@Test
public void invalidActualRequest() throws ServletException, IOException {

MockServerHttpRequest request = MockServerHttpRequest
.delete("http://domain1.com/test.html")
.header(HOST, "domain1.com")
.header(ORIGIN, "http://domain2.com")
.header("header2", "foo")
.build();
MockServerWebExchange exchange = new MockServerWebExchange(request);

WebFilterChain filterChain = (filterExchange) -> Mono.error(new AssertionError("Invalid requests must not be forwarded to the filter chain"));
filter.filter(exchange, filterChain);

assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
}

@Test
public void validPreFlightRequest() throws ServletException, IOException {

MockServerHttpRequest request = MockServerHttpRequest
.options("http://domain1.com/test.html")
.header(HOST, "domain1.com")
.header(ORIGIN, "http://domain2.com")
.header(ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.GET.name())
.header(ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2")
.build();
MockServerWebExchange exchange = new MockServerWebExchange(request);

WebFilterChain filterChain = (filterExchange) -> Mono.error(new AssertionError("Preflight requests must not be forwarded to the filter chain"));
filter.filter(exchange, filterChain);

assertEquals("http://domain2.com", exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("header1, header2", exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS));
assertEquals("header3, header4", exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_EXPOSE_HEADERS));
assertEquals(123L, Long.parseLong(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_MAX_AGE)));
}

@Test
public void invalidPreFlightRequest() throws ServletException, IOException {

MockServerHttpRequest request = MockServerHttpRequest
.options("http://domain1.com/test.html")
.header(HOST, "domain1.com")
.header(ORIGIN, "http://domain2.com")
.header(ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.DELETE.name())
.header(ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2")
.build();
MockServerWebExchange exchange = new MockServerWebExchange(request);

WebFilterChain filterChain = (filterExchange) -> Mono.error(new AssertionError("Preflight requests must not be forwarded to the filter chain"));
filter.filter(exchange, filterChain);

assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN));
}

}

0 comments on commit 1e04cdf

Please sign in to comment.