Skip to content

Commit

Permalink
Add TestingPooledByteBufAllocator to catch leaks
Browse files Browse the repository at this point in the history
The TestingPooledByteBufAllocator tracks all buffers and when closed
ensures that all buffers have zero outstanding references.
  • Loading branch information
dain committed May 4, 2018
1 parent cdb9f9b commit 97cb22e
Show file tree
Hide file tree
Showing 13 changed files with 212 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.airlift.drift.integration.scribe.drift.DriftLogEntry;
import io.airlift.drift.integration.scribe.drift.DriftScribe;
import io.airlift.drift.transport.client.DriftClientConfig;
import io.airlift.drift.transport.netty.buffer.TestingPooledByteBufAllocator;
import io.airlift.drift.transport.netty.client.DriftNettyClientConfig;
import io.airlift.drift.transport.netty.client.DriftNettyClientModule;
import io.airlift.drift.transport.netty.client.DriftNettyConnectionFactoryConfig;
Expand Down Expand Up @@ -80,9 +81,11 @@ private static int logNettyDriftClient(
.setTrustCertificate(ClientTestUtils.getCertificateChainFile())
.setSslEnabled(secure);

try (DriftNettyMethodInvokerFactory<String> methodInvokerFactory = new DriftNettyMethodInvokerFactory<>(
new DriftNettyConnectionFactoryConfig().setConnectionPoolEnabled(true),
clientIdentity -> config)) {
try (TestingPooledByteBufAllocator testingAllocator = new TestingPooledByteBufAllocator();
DriftNettyMethodInvokerFactory<String> methodInvokerFactory = new DriftNettyMethodInvokerFactory<>(
new DriftNettyConnectionFactoryConfig().setConnectionPoolEnabled(true),
clientIdentity -> config,
testingAllocator)) {
DriftClientFactoryManager<String> clientFactoryManager = new DriftClientFactoryManager<>(CODEC_MANAGER, methodInvokerFactory);
DriftClientFactory proxyFactory = clientFactoryManager.createDriftClientFactory("clientIdentity", addressSelector, NORMAL_RESULT);

Expand Down Expand Up @@ -116,7 +119,8 @@ private static int logNettyStaticDriftClient(
.setTrustCertificate(ClientTestUtils.getCertificateChainFile())
.setSslEnabled(secure);

try (DriftNettyMethodInvokerFactory<?> methodInvokerFactory = createStaticDriftNettyMethodInvokerFactory(config)) {
try (TestingPooledByteBufAllocator testingAllocator = new TestingPooledByteBufAllocator();
DriftNettyMethodInvokerFactory<?> methodInvokerFactory = createStaticDriftNettyMethodInvokerFactory(config, testingAllocator)) {
DriftClientFactory proxyFactory = new DriftClientFactory(CODEC_MANAGER, methodInvokerFactory, addressSelector, NORMAL_RESULT);

DriftScribe scribe = proxyFactory.createDriftClient(DriftScribe.class, Optional.empty(), filters, new DriftClientConfig()).get();
Expand Down Expand Up @@ -149,9 +153,11 @@ private static int logNettyDriftClientAsync(
.setTrustCertificate(ClientTestUtils.getCertificateChainFile())
.setSslEnabled(secure);

try (DriftNettyMethodInvokerFactory<String> methodInvokerFactory = new DriftNettyMethodInvokerFactory<>(
new DriftNettyConnectionFactoryConfig().setConnectionPoolEnabled(true),
clientIdentity -> config)) {
try (TestingPooledByteBufAllocator testingAllocator = new TestingPooledByteBufAllocator();
DriftNettyMethodInvokerFactory<String> methodInvokerFactory = new DriftNettyMethodInvokerFactory<>(
new DriftNettyConnectionFactoryConfig().setConnectionPoolEnabled(true),
clientIdentity -> config,
testingAllocator)) {
DriftClientFactoryManager<String> proxyFactoryManager = new DriftClientFactoryManager<>(CODEC_MANAGER, methodInvokerFactory);
DriftClientFactory proxyFactory = proxyFactoryManager.createDriftClientFactory("myFactory", addressSelector, NORMAL_RESULT);

Expand All @@ -178,7 +184,9 @@ private static int logNettyClientBinder(
return 0;
}

return logDriftClientBinder(address, headerValue, entries, new DriftNettyClientModule(), filters, transport, protocol, secure);
try (TestingPooledByteBufAllocator testingAllocator = new TestingPooledByteBufAllocator()) {
return logDriftClientBinder(address, headerValue, entries, new DriftNettyClientModule(testingAllocator), filters, transport, protocol, secure);
}
}

private static boolean isValidConfiguration(Transport transport, Protocol protocol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.airlift.drift.server.DriftServer;
import io.airlift.drift.server.DriftService;
import io.airlift.drift.server.stats.NullMethodInvocationStatsFactory;
import io.airlift.drift.transport.netty.buffer.TestingPooledByteBufAllocator;
import io.airlift.drift.transport.netty.codec.Protocol;
import io.airlift.drift.transport.netty.codec.Transport;
import io.airlift.drift.transport.netty.server.DriftNettyServerConfig;
Expand Down Expand Up @@ -110,8 +111,9 @@ private static void testDriftServer(DriftService service, Consumer<HostAndPort>
.setSslEnabled(true)
.setTrustCertificate(ClientTestUtils.getCertificateChainFile())
.setKey(ClientTestUtils.getPrivateKeyFile());
TestingPooledByteBufAllocator testingAllocator = new TestingPooledByteBufAllocator();
DriftServer driftServer = new DriftServer(
new DriftNettyServerTransportFactory(config),
new DriftNettyServerTransportFactory(config, testingAllocator),
CODEC_MANAGER,
new NullMethodInvocationStatsFactory(),
ImmutableSet.of(service),
Expand All @@ -125,6 +127,7 @@ private static void testDriftServer(DriftService service, Consumer<HostAndPort>
}
finally {
driftServer.shutdown();
testingAllocator.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.airlift.drift.integration.scribe.drift.DriftResultCode;
import io.airlift.drift.transport.MethodMetadata;
import io.airlift.drift.transport.ParameterMetadata;
import io.airlift.drift.transport.netty.buffer.TestingPooledByteBufAllocator;
import io.airlift.drift.transport.netty.codec.Protocol;
import io.airlift.drift.transport.netty.codec.Transport;
import io.airlift.drift.transport.netty.server.DriftNettyServerConfig;
Expand Down Expand Up @@ -60,14 +61,12 @@ public class TestClientsWithDriftNettyServerTransport
{
@Test
public void testDriftServer()
throws Exception
{
testDriftServer(ImmutableList.of());
}

@Test
public void testHandlersWithDriftServer()
throws Exception
{
TestFilter firstFilter = new TestFilter();
TestFilter secondFilter = new TestFilter();
Expand All @@ -80,7 +79,6 @@ public void testHandlersWithDriftServer()
}

private static int testDriftServer(List<MethodInvocationFilter> filters)
throws Exception
{
TestServerMethodInvoker methodInvoker = new TestServerMethodInvoker();

Expand All @@ -102,13 +100,13 @@ private static int testDriftServer(List<MethodInvocationFilter> filters)
}

private static int testDriftServer(ServerMethodInvoker methodInvoker, List<ToIntFunction<HostAndPort>> clients)
throws Exception
{
DriftNettyServerConfig config = new DriftNettyServerConfig()
.setSslEnabled(true)
.setTrustCertificate(ClientTestUtils.getCertificateChainFile())
.setKey(ClientTestUtils.getPrivateKeyFile());
ServerTransport serverTransport = new DriftNettyServerTransportFactory(config).createServerTransport(methodInvoker);
TestingPooledByteBufAllocator testingAllocator = new TestingPooledByteBufAllocator();
ServerTransport serverTransport = new DriftNettyServerTransportFactory(config, testingAllocator).createServerTransport(methodInvoker);
try {
serverTransport.start();

Expand All @@ -122,6 +120,7 @@ private static int testDriftServer(ServerMethodInvoker methodInvoker, List<ToInt
}
finally {
serverTransport.shutdown();
testingAllocator.close();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.airlift.drift.integration.guice.EchoService.EmptyOptionalException;
import io.airlift.drift.integration.guice.EchoService.NullValueException;
import io.airlift.drift.integration.scribe.drift.DriftLogEntry;
import io.airlift.drift.transport.netty.buffer.TestingPooledByteBufAllocator;
import io.airlift.drift.transport.netty.client.DriftNettyClientModule;
import io.airlift.drift.transport.netty.server.DriftNettyServerModule;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -50,9 +51,10 @@ public void test()
{
int port = findUnusedPort();

TestingPooledByteBufAllocator testingAllocator = new TestingPooledByteBufAllocator();
Bootstrap bootstrap = new Bootstrap(
new DriftNettyServerModule(),
new DriftNettyClientModule(),
new DriftNettyServerModule(testingAllocator),
new DriftNettyClientModule(testingAllocator),
binder -> {
driftServerBinder(binder).bindService(EchoServiceHandler.class);
driftServerBinder(binder).bindService(MismatchServiceHandler.class);
Expand Down Expand Up @@ -80,6 +82,7 @@ public void test()
}
finally {
lifeCycleManager.stop();
testingAllocator.close();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright (C) 2018 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.airlift.drift.transport.netty.buffer;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.PooledByteBufAllocator;

import javax.annotation.concurrent.GuardedBy;

import java.io.Closeable;
import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import static com.google.common.collect.ImmutableList.toImmutableList;

public class TestingPooledByteBufAllocator
extends PooledByteBufAllocator
implements Closeable
{
public TestingPooledByteBufAllocator()
{
super(false);
}

@GuardedBy("this")
private final List<WeakReference<ByteBuf>> trackedBuffers = new ArrayList<>();

@Override
protected ByteBuf newHeapBuffer(int initialCapacity, int maxCapacity)
{
return track(super.newHeapBuffer(initialCapacity, maxCapacity));
}

@Override
protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity)
{
return track(super.newDirectBuffer(initialCapacity, maxCapacity));
}

@Override
public CompositeByteBuf compositeHeapBuffer(int maxNumComponents)
{
return track(super.compositeHeapBuffer(maxNumComponents));
}

@Override
public CompositeByteBuf compositeDirectBuffer(int maxNumComponents)
{
return track(super.compositeDirectBuffer(maxNumComponents));
}

public synchronized List<ByteBuf> getReferencedBuffers()
{
return trackedBuffers.stream()
.map(WeakReference::get)
.filter(Objects::nonNull)
.filter(byteBuf -> byteBuf.refCnt() > 0)
.collect(toImmutableList());
}

private synchronized CompositeByteBuf track(CompositeByteBuf byteBuf)
{
trackedBuffers.add(new WeakReference<>(byteBuf));
trackedBuffers.removeIf(byteBufWeakReference -> byteBufWeakReference.get() == null);
return byteBuf;
}

private synchronized ByteBuf track(ByteBuf byteBuf)
{
trackedBuffers.add(new WeakReference<>(byteBuf));
trackedBuffers.removeIf(byteBufWeakReference -> byteBufWeakReference.get() == null);
return byteBuf;
}

@Override
public void close()
{
List<ByteBuf> referencedBuffers = getReferencedBuffers();
if (!referencedBuffers.isEmpty()) {
throw new AssertionError("LEAK");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.airlift.drift.protocol.TTransportException;
import io.airlift.drift.transport.netty.ssl.SslContextFactory;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
Expand All @@ -30,6 +31,7 @@
import java.net.InetSocketAddress;

import static com.google.common.primitives.Ints.saturatedCast;
import static io.netty.channel.ChannelOption.ALLOCATOR;
import static io.netty.channel.ChannelOption.CONNECT_TIMEOUT_MILLIS;
import static java.util.Objects.requireNonNull;

Expand All @@ -38,11 +40,13 @@ class ConnectionFactory
{
private final EventLoopGroup group;
private final SslContextFactory sslContextFactory;
private final ByteBufAllocator allocator;

ConnectionFactory(EventLoopGroup group, SslContextFactory sslContextFactory)
ConnectionFactory(EventLoopGroup group, SslContextFactory sslContextFactory, ByteBufAllocator allocator)
{
this.group = requireNonNull(group, "group is null");
this.sslContextFactory = requireNonNull(sslContextFactory, "sslContextFactory is null");
this.allocator = requireNonNull(allocator, "allocator is null");
}

@Override
Expand All @@ -52,6 +56,7 @@ public Future<Channel> getConnection(ConnectionParameters connectionParameters,
Bootstrap bootstrap = new Bootstrap()
.group(group)
.channel(NioSocketChannel.class)
.option(ALLOCATOR, allocator)
.option(CONNECT_TIMEOUT_MILLIS, saturatedCast(connectionParameters.getConnectTimeout().toMillis()))
.handler(new ThriftClientInitializer(
connectionParameters.getTransport(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.airlift.drift.transport.netty.client;

import com.google.common.annotations.VisibleForTesting;
import com.google.inject.Binder;
import com.google.inject.Injector;
import com.google.inject.Key;
Expand All @@ -23,6 +24,7 @@
import com.google.inject.TypeLiteral;
import io.airlift.drift.transport.client.DriftClientConfig;
import io.airlift.drift.transport.client.MethodInvokerFactory;
import io.netty.buffer.ByteBufAllocator;

import javax.annotation.PreDestroy;
import javax.inject.Inject;
Expand All @@ -32,10 +34,24 @@

import static com.google.common.base.Preconditions.checkState;
import static io.airlift.configuration.ConfigBinder.configBinder;
import static java.util.Objects.requireNonNull;

public class DriftNettyClientModule
implements Module
{
private final ByteBufAllocator allocator;

public DriftNettyClientModule()
{
this(ByteBufAllocator.DEFAULT);
}

@VisibleForTesting
public DriftNettyClientModule(ByteBufAllocator allocator)
{
this.allocator = requireNonNull(allocator, "allocator is null");
}

@Override
public void configure(Binder binder)
{
Expand All @@ -46,6 +62,7 @@ public void configure(Binder binder)
}
});

binder.bind(ByteBufAllocator.class).toInstance(allocator);
binder.bind(new TypeLiteral<MethodInvokerFactory<Annotation>>() {})
.toProvider(MethodInvokerFactoryProvider.class)
.in(Scopes.SINGLETON);
Expand Down Expand Up @@ -85,7 +102,8 @@ public MethodInvokerFactory<Annotation> get()

factory = new DriftNettyMethodInvokerFactory<>(
injector.getInstance(DriftNettyConnectionFactoryConfig.class),
annotation -> injector.getInstance(Key.get(DriftNettyClientConfig.class, annotation)));
annotation -> injector.getInstance(Key.get(DriftNettyClientConfig.class, annotation)),
injector.getInstance(ByteBufAllocator.class));

return factory;
}
Expand Down
Loading

0 comments on commit 97cb22e

Please sign in to comment.