Skip to content

Commit

Permalink
Improve error handling in ThriftClientHandler
Browse files Browse the repository at this point in the history
If transport fails, and there are no pending requests new requests
could potentially hang forever.
  • Loading branch information
dain committed Mar 11, 2018
1 parent f73f4ea commit 0ace248
Showing 1 changed file with 36 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@

import javax.annotation.concurrent.ThreadSafe;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -95,27 +95,36 @@ private void sendMessage(ChannelHandlerContext context, ThriftRequest thriftRequ
}
}

// if this connection is failed, immediately fail the request
TException channelError = this.channelError.get();
if (channelError != null) {
thriftRequest.failed(channelError);
requestBuffer.release();
return;
}

try {
ChannelFuture sendFuture = context.write(requestBuffer, promise);
sendFuture.addListener(future -> messageSent(context, sendFuture, requestHandler));
}
catch (Throwable t) {
onError(context, t);
onError(context, t, Optional.of(requestHandler));
requestBuffer.release();
}
}

private void messageSent(ChannelHandlerContext context, ChannelFuture future, RequestHandler requestHandler)
{
try {
if (!future.isSuccess()) {
onError(context, new TTransportException("Sending request failed", future.cause()));
onError(context, new TTransportException("Sending request failed", future.cause()), Optional.of(requestHandler));
return;
}

requestHandler.onRequestSent();
}
catch (Throwable t) {
onError(context, t);
onError(context, t, Optional.of(requestHandler));
}
}

Expand All @@ -135,21 +144,22 @@ public void channelRead(ChannelHandlerContext context, Object message)

private void messageReceived(ChannelHandlerContext context, ByteBuf response)
{
RequestHandler requestHandler = null;
try {
OptionalInt sequenceId = messageEncoding.extractResponseSequenceId(response.retainedDuplicate());
if (!sequenceId.isPresent()) {
throw new TTransportException("Could not find sequenceId in Thrift message");
}

RequestHandler requestHandler = pendingRequests.remove(sequenceId.getAsInt());
requestHandler = pendingRequests.remove(sequenceId.getAsInt());
if (requestHandler == null) {
throw new TTransportException("Unknown sequence id in response: " + sequenceId.getAsInt());
}

requestHandler.onResponseReceived(response.retainedDuplicate());
}
catch (Throwable t) {
onError(context, t);
onError(context, t, Optional.ofNullable(requestHandler));
}
finally {
response.release();
Expand All @@ -159,19 +169,17 @@ private void messageReceived(ChannelHandlerContext context, ByteBuf response)
@Override
public void exceptionCaught(ChannelHandlerContext context, Throwable cause)
{
onError(context, cause);
onError(context, cause, Optional.empty());
}

@Override
public void channelInactive(ChannelHandlerContext context)
throws Exception
{
if (!pendingRequests.isEmpty()) {
onError(context, new TTransportException("Client was disconnected by server"));
}
onError(context, new TTransportException("Client was disconnected by server"), Optional.empty());
}

private void onError(ChannelHandlerContext context, Throwable throwable)
private void onError(ChannelHandlerContext context, Throwable throwable, Optional<RequestHandler> currentRequest)
{
TException thriftException;
if (throwable instanceof TException) {
Expand All @@ -187,13 +195,20 @@ private void onError(ChannelHandlerContext context, Throwable throwable)
return;
}

// current request may have already been removed from pendingRequests, so notify it directly
currentRequest.ifPresent(request -> {
pendingRequests.remove(request.getSequenceId());
request.onChannelError(thriftException);
});

// notify all pending requests of the error
// Note while loop should not be necessary since this class should be single
// threaded, but it is better to be safe in cleanup code
while (!pendingRequests.isEmpty()) {
for (Iterator<RequestHandler> iterator = pendingRequests.values().iterator(); iterator.hasNext(); ) {
RequestHandler requestHandler = iterator.next();
iterator.remove();
requestHandler.onChannelError(thriftException);
}
pendingRequests.values().removeIf(request -> {
request.onChannelError(thriftException);
return true;
});
}

context.close();
Expand Down Expand Up @@ -258,6 +273,11 @@ public RequestHandler(ThriftRequest thriftRequest, int sequenceId)
this.sequenceId = sequenceId;
}

public int getSequenceId()
{
return sequenceId;
}

void registerRequestTimeout(EventExecutor executor)
{
try {
Expand Down

0 comments on commit 0ace248

Please sign in to comment.