Skip to content

Commit

Permalink
Verify ssl key alias on server startup
Browse files Browse the repository at this point in the history
  • Loading branch information
cbo-indeed authored and mbhave committed Feb 12, 2020
1 parent 747eab0 commit e351605
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.http2.HTTP2Cipher;
import org.eclipse.jetty.http2.server.HTTP2ServerConnectionFactory;
import org.eclipse.jetty.server.ConnectionFactory;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.HttpConfiguration;
import org.eclipse.jetty.server.HttpConnectionFactory;
Expand All @@ -37,6 +38,7 @@
import org.springframework.boot.web.server.Http2;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;
import org.springframework.boot.web.server.SslUtils;
import org.springframework.boot.web.server.WebServerException;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
Expand Down Expand Up @@ -105,7 +107,8 @@ private ServerConnector createHttp11ServerConnector(Server server, HttpConfigura
HttpConnectionFactory connectionFactory = new HttpConnectionFactory(config);
SslConnectionFactory sslConnectionFactory = new SslConnectionFactory(sslContextFactory,
HttpVersion.HTTP_1_1.asString());
return new ServerConnector(server, sslConnectionFactory, connectionFactory);
return new SslValidatingServerConnector(server, sslContextFactory, this.ssl.getKeyAlias(), sslConnectionFactory,
connectionFactory);
}

private boolean isAlpnPresent() {
Expand All @@ -123,7 +126,8 @@ private ServerConnector createHttp2ServerConnector(Server server, HttpConfigurat
sslContextFactory.setCipherComparator(HTTP2Cipher.COMPARATOR);
sslContextFactory.setProvider("Conscrypt");
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, alpn.getProtocol());
return new ServerConnector(server, ssl, alpn, h2, new HttpConnectionFactory(config));
return new SslValidatingServerConnector(server, sslContextFactory, this.ssl.getKeyAlias(), ssl, alpn, h2,
new HttpConnectionFactory(config));
}

/**
Expand Down Expand Up @@ -215,4 +219,35 @@ private void configureSslTrustStore(SslContextFactory factory, Ssl ssl) {
}
}

/**
* A {@link ServerConnector} that validates the ssl key alias on server startup.
*/
static class SslValidatingServerConnector extends ServerConnector {

private SslContextFactory sslContextFactory;

private String keyAlias;

SslValidatingServerConnector(Server server, SslContextFactory sslContextFactory, String keyAlias,
SslConnectionFactory sslConnectionFactory, HttpConnectionFactory connectionFactory) {
super(server, sslConnectionFactory, connectionFactory);
this.sslContextFactory = sslContextFactory;
this.keyAlias = keyAlias;
}

SslValidatingServerConnector(Server server, SslContextFactory sslContextFactory, String keyAlias,
ConnectionFactory... factories) {
super(server, factories);
this.sslContextFactory = sslContextFactory;
this.keyAlias = keyAlias;
}

@Override
protected void doStart() throws Exception {
super.doStart();
SslUtils.assertStoreContainsAlias(this.sslContextFactory.getKeyStore(), this.keyAlias);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.springframework.boot.web.server.Http2;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;
import org.springframework.boot.web.server.SslUtils;
import org.springframework.boot.web.server.WebServerException;
import org.springframework.util.ResourceUtils;

Expand Down Expand Up @@ -106,6 +107,8 @@ else if (this.ssl.getClientAuth() == Ssl.ClientAuth.WANT) {
protected KeyManagerFactory getKeyManagerFactory(Ssl ssl, SslStoreProvider sslStoreProvider) {
try {
KeyStore keyStore = getKeyStore(ssl, sslStoreProvider);
SslUtils.assertStoreContainsAlias(keyStore, ssl.getKeyAlias());

KeyManagerFactory keyManagerFactory = (ssl.getKeyAlias() == null)
? KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
: new ConfigurableAliasKeyManagerFactory(ssl.getKeyAlias(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;
import org.springframework.boot.web.server.SslUtils;
import org.springframework.boot.web.server.WebServerException;
import org.springframework.util.ResourceUtils;

Expand Down Expand Up @@ -107,6 +108,8 @@ private SslClientAuthMode getSslClientAuthMode(Ssl ssl) {
private KeyManager[] getKeyManagers(Ssl ssl, SslStoreProvider sslStoreProvider) {
try {
KeyStore keyStore = getKeyStore(ssl, sslStoreProvider);
SslUtils.assertStoreContainsAlias(keyStore, ssl.getKeyAlias());

KeyManagerFactory keyManagerFactory = KeyManagerFactory
.getInstance(KeyManagerFactory.getDefaultAlgorithm());
char[] keyPassword = (ssl.getKeyPassword() != null) ? ssl.getKeyPassword().toCharArray() : null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.boot.web.server;

import java.security.KeyStore;
import java.security.KeyStoreException;

import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* Provides utilities around SSL.
*
* @author Chris Bono
* @since 2.1.x
*/
public final class SslUtils {

private SslUtils() {
}

public static void assertStoreContainsAlias(KeyStore keyStore, String keyAlias) {
if (!StringUtils.isEmpty(keyAlias)) {
try {
Assert.state(keyStore.containsAlias(keyAlias),
() -> String.format("Keystore does not contain specified alias '%s'", keyAlias));
}
catch (KeyStoreException ex) {
throw new IllegalStateException(
String.format("Could not determine if keystore contains alias '%s'", keyAlias), ex);
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import java.time.Duration;
import java.util.Arrays;

import javax.net.ssl.SSLHandshakeException;

import org.junit.Test;
import org.mockito.InOrder;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -101,14 +99,6 @@ public void whenSslIsConfiguredWithAValidAliasARequestSucceeds() {
StepVerifier.create(result).expectNext("Hello World").verifyComplete();
}

@Test
public void whenSslIsConfiguredWithAnInvalidAliasTheSslHandshakeFails() {
Mono<String> result = testSslWithAlias("test-alias-bad");
StepVerifier.setDefaultTimeout(Duration.ofSeconds(30));
StepVerifier.create(result).expectErrorMatches((throwable) -> throwable instanceof SSLHandshakeException
&& throwable.getMessage().contains("HANDSHAKE_FAILURE")).verify();
}

protected Mono<String> testSslWithAlias(String alias) {
String keyStore = "classpath:test.jks";
String keyPassword = "password";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;

import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactory;
import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactoryTests;
import org.springframework.boot.web.server.Ssl;
import org.springframework.http.server.reactive.HttpHandler;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -166,4 +169,19 @@ public void compressionOfResponseToGetRequest() {
public void compressionOfResponseToPostRequest() {
}

@Test
@Override
public void sslWithInvalidAliasFailsDuringStartup() {
String keyStore = "classpath:test.jks";
String keyPassword = "password";
AbstractReactiveWebServerFactory factory = getFactory();
Ssl ssl = new Ssl();
ssl.setKeyStore(keyStore);
ssl.setKeyPassword(keyPassword);
ssl.setKeyAlias("test-alias-404");
factory.setSsl(ssl);
assertThatThrownBy(() -> factory.getWebServer(new EchoHandler()).start())
.isInstanceOf(ConnectorStartFailedException.class);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@
import org.mockito.InOrder;

import org.springframework.boot.testsupport.rule.OutputCapture;
import org.springframework.boot.testsupport.web.servlet.ExampleServlet;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.WebServerException;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.boot.web.servlet.server.AbstractServletWebServerFactory;
import org.springframework.boot.web.servlet.server.AbstractServletWebServerFactoryTests;
import org.springframework.core.io.ByteArrayResource;
Expand All @@ -81,6 +84,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -529,6 +533,18 @@ public void onStartup(ServletContext servletContext) throws ServletException {
this.webServer.start();
}

@Test
@Override
public void sslWithInvalidAliasFailsDuringStartup() {
AbstractServletWebServerFactory factory = getFactory();
Ssl ssl = getSsl(null, "password", "test-alias-404", "src/test/resources/test.jks");
factory.setSsl(ssl);
ServletRegistrationBean<ExampleServlet> registration = new ServletRegistrationBean<>(
new ExampleServlet(true, false), "/hello");
assertThatThrownBy(() -> factory.getWebServer(registration).start())
.isInstanceOf(ConnectorStartFailedException.class);
}

@Override
protected JspServlet getJspServlet() throws ServletException {
Tomcat tomcat = ((TomcatWebServer) this.webServer).getTomcat();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,44 @@ protected final void testBasicSslWithKeyStore(String keyStore, String keyPasswor
assertThat(result.block(Duration.ofSeconds(30))).isEqualTo("Hello World");
}

@Test
public void sslWithValidAlias() {
String keyStore = "classpath:test.jks";
String keyPassword = "password";
AbstractReactiveWebServerFactory factory = getFactory();
Ssl ssl = new Ssl();
ssl.setKeyStore(keyStore);
ssl.setKeyPassword(keyPassword);
ssl.setKeyAlias("test-alias");
factory.setSsl(ssl);
this.webServer = factory.getWebServer(new EchoHandler());
this.webServer.start();
ReactorClientHttpConnector connector = buildTrustAllSslConnector();
WebClient client = WebClient.builder().baseUrl("https://localhost:" + this.webServer.getPort())
.clientConnector(connector).build();

Mono<String> result = client.post().uri("/test").contentType(MediaType.TEXT_PLAIN)
.body(BodyInserters.fromObject("Hello World")).exchange()
.flatMap((response) -> response.bodyToMono(String.class));

StepVerifier.setDefaultTimeout(Duration.ofSeconds(30));
StepVerifier.create(result).expectNext("Hello World").verifyComplete();
}

@Test
public void sslWithInvalidAliasFailsDuringStartup() {
String keyStore = "classpath:test.jks";
String keyPassword = "password";
AbstractReactiveWebServerFactory factory = getFactory();
Ssl ssl = new Ssl();
ssl.setKeyStore(keyStore);
ssl.setKeyPassword(keyPassword);
ssl.setKeyAlias("test-alias-404");
factory.setSsl(ssl);
assertThatThrownBy(() -> factory.getWebServer(new EchoHandler()).start())
.hasStackTraceContaining("Keystore does not contain specified alias 'test-alias-404'");
}

protected ReactorClientHttpConnector buildTrustAllSslConnector() {
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
.trustManager(InsecureTrustManagerFactory.INSTANCE);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.boot.web.server;

import java.io.File;
import java.io.FileInputStream;
import java.security.KeyStore;
import java.security.KeyStoreException;

import org.junit.Before;
import org.junit.Test;

import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* Tests for {@link SslUtils}.
*
* @author Chris Bono
*/

public class SslUtilsTest {

private static final String VALID_ALIAS = "test-alias";

private static final String INVALID_ALIAS = "test-alias-5150";

private KeyStore keyStore;

@Before
public void loadKeystore() throws Exception {
this.keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
this.keyStore.load(new FileInputStream(new File("src/test/resources/test.jks")), "secret".toCharArray());
}

@Test
public void assertStoreContainsAliasPassesWhenAliasFound() throws KeyStoreException {
SslUtils.assertStoreContainsAlias(this.keyStore, VALID_ALIAS);
}

@Test
public void assertStoreContainsAliasPassesWhenNullAlias() throws KeyStoreException {
SslUtils.assertStoreContainsAlias(this.keyStore, null);
}

@Test
public void assertStoreContainsAliasPassesWhenEmptyAlias() throws KeyStoreException {
SslUtils.assertStoreContainsAlias(this.keyStore, "");
}

@Test
public void assertStoreContainsAliasFailsWhenAliasNotFound() throws KeyStoreException {
assertThatThrownBy(() -> SslUtils.assertStoreContainsAlias(this.keyStore, INVALID_ALIAS))
.isInstanceOf(IllegalStateException.class)
.hasMessage("Keystore does not contain specified alias '" + INVALID_ALIAS + "'");
}

@Test
public void assertStoreContainsAliasFailsWhenKeyStoreThrowsExceptionOnContains() throws KeyStoreException {
KeyStore uninitializedKeyStore = KeyStore.getInstance(KeyStore.getDefaultType());
assertThatThrownBy(() -> SslUtils.assertStoreContainsAlias(uninitializedKeyStore, "alias"))
.isInstanceOf(IllegalStateException.class)
.hasMessage("Could not determine if keystore contains alias 'alias'");
}

}
Loading

0 comments on commit e351605

Please sign in to comment.