Skip to content

Commit

Permalink
SslHandler flushing with TCP Fast Open fix (netty#11077)
Browse files Browse the repository at this point in the history
Motivation:
SslHandler owns the responsibility to flush non-application data
(e.g. handshake, renegotiation, etc.) to the socket. However when
TCP Fast Open is supported but the client_hello cannot be written
in the SYN the client_hello may not always be flushed. SslHandler
may not wrap/flush previously written/flushed data in the event
it was not able to be wrapped due to NEED_UNWRAP state being
encountered in wrap (e.g. peer initiated renegotiation).

Modifications:
- SslHandler to flush in channelActive() if TFO is enabled and
  the client_hello cannot be written in the SYN.
- SslHandler to wrap application data after non-application data
  wrap and handshake status is FINISHED.
- SocketSslEchoTest only flushes when writes are done, and waits
  for the handshake to complete before writing.

Result:
SslHandler flushes handshake data for TFO, and previously flushed
application data after peer initiated renegotiation finishes.
  • Loading branch information
Scottmitch authored Mar 14, 2021
1 parent 3f97501 commit 0b0c234
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 119 deletions.
54 changes: 32 additions & 22 deletions handler/src/main/java/io/netty/handler/ssl/SslHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelOutboundBuffer;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
Expand Down Expand Up @@ -976,7 +977,7 @@ private boolean wrapNonAppData(final ChannelHandlerContext ctx, boolean inUnwrap
out = allocateOutNetBuf(ctx, 2048, 1);
}
SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out);

HandshakeStatus status = result.getHandshakeStatus();
if (result.bytesProduced() > 0) {
ctx.write(out).addListener(new ChannelFutureListener() {
@Override
Expand All @@ -988,12 +989,18 @@ public void operationComplete(ChannelFuture future) {
}
});
if (inUnwrap) {
// We may be here because we read data and discovered the remote peer initiated a renegotiation
// and this write is to complete the new handshake. The user may have previously done a
// writeAndFlush which wasn't able to wrap data due to needing the pending handshake, so we
// attempt to wrap application data here if any is pending.
if (status == HandshakeStatus.FINISHED && !pendingUnencryptedWrites.isEmpty()) {
wrap(ctx, true);
}
needsFlush = true;
}
out = null;
}

HandshakeStatus status = result.getHandshakeStatus();
switch (status) {
case FINISHED:
setHandshakeSuccess();
Expand Down Expand Up @@ -1790,40 +1797,36 @@ public void run() {
* marked as success by this method
*/
private boolean setHandshakeSuccessIfStillHandshaking() {
if (!handshakePromise.isDone()) {
setHandshakeSuccess();
return true;
}
return false;
return setHandshakeSuccess();
}

/**
* Notify all the handshake futures about the successfully handshake
* @return {@code true} if {@link #handshakePromise} was set successfully and a {@link SslHandshakeCompletionEvent}
* was fired. {@code false} otherwise.
*/
private void setHandshakeSuccess() {
boolean notified = handshakePromise.trySuccess(ctx.channel());
SSLSession session = engine.getSession();

// There seems to be a bug in the SSLEngineImpl that is part of the OpenJDK that results in returning
// HandshakeStatus.FINISHED multiple times which is not expected. This only happens in TLSv1.3 so lets
// ensure we only notify once in this case.
//
// This is safe as TLSv1.3 does not support renegotiation and so we should never see two handshake events.
if (notified || !SslUtils.PROTOCOL_TLS_V1_3.equals(session.getProtocol())) {
private boolean setHandshakeSuccess() {
if (readDuringHandshake && !ctx.channel().config().isAutoRead()) {
readDuringHandshake = false;
ctx.read();
}
// Our control flow may invoke this method multiple times for a single FINISHED event. For example
// wrapNonAppData may drain pendingUnencryptedWrites in wrap which transitions to handshake from FINISHED to
// NOT_HANDSHAKING which invokes setHandshakeSuccessIfStillHandshaking, and then wrapNonAppData also directly
// invokes this method.
if (handshakePromise.trySuccess(ctx.channel())) {
if (logger.isDebugEnabled()) {
SSLSession session = engine.getSession();
logger.debug(
"{} HANDSHAKEN: protocol:{} cipher suite:{}",
ctx.channel(),
session.getProtocol(),
session.getCipherSuite());
}
ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
return true;
}

if (readDuringHandshake && !ctx.channel().config().isAutoRead()) {
readDuringHandshake = false;
ctx.read();
}
return false;
}

/**
Expand Down Expand Up @@ -1967,6 +1970,11 @@ public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
// With TCP Fast Open, we write to the outbound buffer before the TCP connect is established.
// The buffer will then be flushed as part of establishing the connection, saving us a round-trip.
startHandshakeProcessing(active);
// If we weren't able to include client_hello in the TCP SYN (e.g. no token, disabled at the OS) we have to
// flush pending data in the outbound buffer later in channelActive().
final ChannelOutboundBuffer outboundBuffer;
needsFlush |= fastOpen && ((outboundBuffer = channel.unsafe().outboundBuffer()) == null ||
outboundBuffer.totalPendingWriteBytes() > 0);
}
}

Expand All @@ -1980,6 +1988,8 @@ private void startHandshakeProcessing(boolean flushAtEnd) {
handshake(flushAtEnd);
}
applyHandshakeTimeout();
} else if (needsFlush) {
forceFlush(ctx);
}
}

Expand Down
10 changes: 7 additions & 3 deletions handler/src/test/java/io/netty/handler/ssl/RenegotiateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ public void testRenegotiateServer() throws Throwable {
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(context.newHandler(ch.alloc()));
SslHandler handler = context.newHandler(ch.alloc());
handler.setHandshakeTimeoutMillis(0);
ch.pipeline().addLast(handler);
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
private boolean renegotiate;

Expand All @@ -79,9 +81,9 @@ public void userEventTriggered(
public void operationComplete(Future<Channel> future) throws Exception {
if (!future.isSuccess()) {
error.compareAndSet(null, future.cause());
latch.countDown();
ctx.close();
}
latch.countDown();
}
});
} else {
Expand All @@ -108,7 +110,9 @@ public void operationComplete(Future<Channel> future) throws Exception {
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(clientContext.newHandler(ch.alloc()));
SslHandler handler = clientContext.newHandler(ch.alloc());
handler.setHandshakeTimeoutMillis(0);
ch.pipeline().addLast(handler);
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.SimpleChannelInboundHandler;
Expand Down Expand Up @@ -52,10 +53,12 @@
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import javax.net.ssl.SSLEngine;

import static org.hamcrest.MatcherAssert.assertThat;
Expand Down Expand Up @@ -255,8 +258,7 @@ public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {

sb.childHandler(new ChannelInitializer<Channel>() {
@Override
@SuppressWarnings("deprecation")
public void initChannel(Channel sch) throws Exception {
public void initChannel(Channel sch) {
serverChannel = sch;

if (serverUsesDelegatedTaskExecutor) {
Expand All @@ -265,6 +267,7 @@ public void initChannel(Channel sch) throws Exception {
} else {
serverSslHandler = serverCtx.newHandler(sch.alloc());
}
serverSslHandler.setHandshakeTimeoutMillis(0);

sch.pipeline().addLast("ssl", serverSslHandler);
if (useChunkedWriteHandler) {
Expand All @@ -274,10 +277,10 @@ public void initChannel(Channel sch) throws Exception {
}
});

final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
cb.handler(new ChannelInitializer<Channel>() {
@Override
@SuppressWarnings("deprecation")
public void initChannel(Channel sch) throws Exception {
public void initChannel(Channel sch) {
clientChannel = sch;

if (clientUsesDelegatedTaskExecutor) {
Expand All @@ -286,12 +289,22 @@ public void initChannel(Channel sch) throws Exception {
} else {
clientSslHandler = clientCtx.newHandler(sch.alloc());
}
clientSslHandler.setHandshakeTimeoutMillis(0);

sch.pipeline().addLast("ssl", clientSslHandler);
if (useChunkedWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler());
}
sch.pipeline().addLast("handler", clientHandler);
sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SslHandshakeCompletionEvent) {
clientHandshakeEventLatch.countDown();
}
ctx.fireUserEventTriggered(evt);
}
});
}
});

Expand All @@ -300,9 +313,12 @@ public void initChannel(Channel sch) throws Exception {

final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();

// Wait for the handshake to complete before we flush anything. SslHandler should flush non-application data.
clientHandshakeFuture.sync();
clientHandshakeEventLatch.await();

clientChannel.writeAndFlush(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
clientSendCounter.set(FIRST_MESSAGE_SIZE);
clientHandshakeFuture.sync();

boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
Future<Channel> renegoFuture = null;
Expand Down Expand Up @@ -457,21 +473,21 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (!autoRead) {
ctx.read();
}
ctx.fireChannelActive();
}

@Override
public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
try {
ctx.flush();
} finally {
if (!autoRead) {
ctx.read();
}
// We intentionally do not ctx.flush() here because we want to verify the SslHandler correctly flushing
// non-application and previously flushed writes internally.
if (!autoRead) {
ctx.read();
}
ctx.fireChannelReadComplete();
}

@Override
public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
if (handshakeEvt.cause() != null) {
Expand All @@ -481,6 +497,7 @@ public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) thro
negoCounter.incrementAndGet();
logStats("HANDSHAKEN");
}
ctx.fireUserEventTriggered(evt);
}

@Override
Expand Down Expand Up @@ -528,7 +545,7 @@ private class EchoServerHandler extends EchoHandler {
}

@Override
public final void channelRegistered(ChannelHandlerContext ctx) throws Exception {
public final void channelRegistered(ChannelHandlerContext ctx) {
renegoFuture = null;
}

Expand All @@ -546,7 +563,7 @@ public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception
if (useCompositeByteBuf) {
buf = Unpooled.compositeBuffer().addComponent(true, buf);
}
ctx.write(buf);
ctx.writeAndFlush(buf);

recvCounter.addAndGet(actual.length);

Expand Down
9 changes: 3 additions & 6 deletions transport-native-epoll/src/main/c/netty_epoll_native.c
Original file line number Diff line number Diff line change
Expand Up @@ -492,13 +492,10 @@ static jboolean netty_epoll_native_isSupportingRecvmmsg(JNIEnv* env, jclass claz
return JNI_TRUE;
}

static jboolean netty_epoll_native_isSupportingTcpFastopen(JNIEnv* env, jclass clazz) {
static jint netty_epoll_native_tcpFastopenMode(JNIEnv* env, jclass clazz) {
int fastopen = 0;
getSysctlValue("/proc/sys/net/ipv4/tcp_fastopen", &fastopen);
if (fastopen > 0) {
return JNI_TRUE;
}
return JNI_FALSE;
return fastopen;
}

static jint netty_epoll_native_epollet(JNIEnv* env, jclass clazz) {
Expand Down Expand Up @@ -577,7 +574,7 @@ static const JNINativeMethod statically_referenced_fixed_method_table[] = {
{ "tcpMd5SigMaxKeyLen", "()I", (void *) netty_epoll_native_tcpMd5SigMaxKeyLen },
{ "isSupportingSendmmsg", "()Z", (void *) netty_epoll_native_isSupportingSendmmsg },
{ "isSupportingRecvmmsg", "()Z", (void *) netty_epoll_native_isSupportingRecvmmsg },
{ "isSupportingTcpFastopen", "()Z", (void *) netty_epoll_native_isSupportingTcpFastopen },
{ "tcpFastopenMode", "()I", (void *) netty_epoll_native_tcpFastopenMode },
{ "kernelVersion", "()Ljava/lang/String;", (void *) netty_epoll_native_kernelVersion }
};
static const jint statically_referenced_fixed_method_table_size = sizeof(statically_referenced_fixed_method_table) / sizeof(statically_referenced_fixed_method_table[0]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.Map;

import static io.netty.channel.epoll.LinuxSocket.newSocketStream;
import static io.netty.channel.epoll.Native.IS_SUPPORTING_TCP_FASTOPEN_SERVER;
import static io.netty.channel.unix.NativeInetAddress.address;

/**
Expand Down Expand Up @@ -68,8 +69,9 @@ protected boolean isCompatible(EventLoop loop) {
@Override
protected void doBind(SocketAddress localAddress) throws Exception {
super.doBind(localAddress);
if (Native.IS_SUPPORTING_TCP_FASTOPEN && config.getTcpFastopen() > 0) {
socket.setTcpFastOpen(config.getTcpFastopen());
final int tcpFastopen;
if (IS_SUPPORTING_TCP_FASTOPEN_SERVER && (tcpFastopen = config.getTcpFastopen()) > 0) {
socket.setTcpFastOpen(tcpFastopen);
}
socket.listen(config.getBacklog());
active = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.concurrent.Executor;

import static io.netty.channel.epoll.LinuxSocket.newSocketStream;
import static io.netty.channel.epoll.Native.IS_SUPPORTING_TCP_FASTOPEN_CLIENT;

/**
* {@link SocketChannel} implementation that uses linux EPOLL Edge-Triggered Mode for
Expand Down Expand Up @@ -116,7 +117,7 @@ protected AbstractEpollUnsafe newUnsafe() {

@Override
boolean doConnect0(SocketAddress remote) throws Exception {
if (Native.IS_SUPPORTING_TCP_FASTOPEN && config.isTcpFastOpenConnect()) {
if (IS_SUPPORTING_TCP_FASTOPEN_CLIENT && config.isTcpFastOpenConnect()) {
ChannelOutboundBuffer outbound = unsafe().outboundBuffer();
outbound.addFlush();
Object curr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
import static io.netty.channel.epoll.NativeStaticallyReferencedJniMethods.epollrdhup;
import static io.netty.channel.epoll.NativeStaticallyReferencedJniMethods.isSupportingRecvmmsg;
import static io.netty.channel.epoll.NativeStaticallyReferencedJniMethods.isSupportingSendmmsg;
import static io.netty.channel.epoll.NativeStaticallyReferencedJniMethods.isSupportingTcpFastopen;
import static io.netty.channel.epoll.NativeStaticallyReferencedJniMethods.kernelVersion;
import static io.netty.channel.epoll.NativeStaticallyReferencedJniMethods.tcpFastopenMode;
import static io.netty.channel.epoll.NativeStaticallyReferencedJniMethods.tcpMd5SigMaxKeyLen;
import static io.netty.channel.unix.Errors.ioResult;
import static io.netty.channel.unix.Errors.newIOException;
Expand Down Expand Up @@ -97,7 +97,27 @@ public void run() {
public static final boolean IS_SUPPORTING_SENDMMSG = isSupportingSendmmsg();
static final boolean IS_SUPPORTING_RECVMMSG = isSupportingRecvmmsg();
static final boolean IS_SUPPORTING_UDP_SEGMENT = isSupportingUdpSegment();
public static final boolean IS_SUPPORTING_TCP_FASTOPEN = isSupportingTcpFastopen();
private static final int TFO_ENABLED_CLIENT_MASK = 0x1;
private static final int TFO_ENABLED_SERVER_MASK = 0x2;
private static final int TCP_FASTOPEN_MODE = tcpFastopenMode();
/**
* <a href ="https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt">tcp_fastopen</a> client mode enabled
* state.
*/
static final boolean IS_SUPPORTING_TCP_FASTOPEN_CLIENT =
(TCP_FASTOPEN_MODE & TFO_ENABLED_CLIENT_MASK) == TFO_ENABLED_CLIENT_MASK;
/**
* <a href ="https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt">tcp_fastopen</a> server mode enabled
* state.
*/
static final boolean IS_SUPPORTING_TCP_FASTOPEN_SERVER =
(TCP_FASTOPEN_MODE & TFO_ENABLED_SERVER_MASK) == TFO_ENABLED_SERVER_MASK;
/**
* @deprecated Use {@link #IS_SUPPORTING_TCP_FASTOPEN_CLIENT} or {@link #IS_SUPPORTING_TCP_FASTOPEN_SERVER}.
*/
@Deprecated
public static final boolean IS_SUPPORTING_TCP_FASTOPEN = IS_SUPPORTING_TCP_FASTOPEN_CLIENT ||
IS_SUPPORTING_TCP_FASTOPEN_SERVER;
public static final int TCP_MD5SIG_MAXKEYLEN = tcpMd5SigMaxKeyLen();
public static final String KERNEL_VERSION = kernelVersion();

Expand Down
Loading

0 comments on commit 0b0c234

Please sign in to comment.