Skip to content

Commit

Permalink
KAFKA-1512 Add per-ip connection limits.
Browse files Browse the repository at this point in the history
  • Loading branch information
jkreps committed Jul 16, 2014
1 parent b428d8c commit 8e444a3
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 46 deletions.
146 changes: 103 additions & 43 deletions core/src/main/scala/kafka/network/SocketServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import java.net._
import java.io._
import java.nio.channels._

import scala.collection._

import kafka.common.KafkaException
import kafka.metrics.KafkaMetricsGroup
import kafka.utils._
Expand All @@ -41,7 +43,9 @@ class SocketServer(val brokerId: Int,
val maxQueuedRequests: Int,
val sendBufferSize: Int,
val recvBufferSize: Int,
val maxRequestSize: Int = Int.MaxValue) extends Logging with KafkaMetricsGroup {
val maxRequestSize: Int = Int.MaxValue,
val maxConnectionsPerIp: Int = Int.MaxValue,
val maxConnectionsPerIpOverrides: Map[String, Int] = Map[String, Int]()) extends Logging with KafkaMetricsGroup {
this.logIdent = "[Socket Server on Broker " + brokerId + "], "
private val time = SystemTime
private val processors = new Array[Processor](numProcessorThreads)
Expand All @@ -55,17 +59,23 @@ class SocketServer(val brokerId: Int,
* Start the socket server
*/
def startup() {
val quotas = new ConnectionQuotas(maxConnectionsPerIp, maxConnectionsPerIpOverrides)
for(i <- 0 until numProcessorThreads) {
processors(i) = new Processor(i, time, maxRequestSize, aggregateIdleMeter,
newMeter("NetworkProcessor-" + i + "-IdlePercent", "percent", TimeUnit.NANOSECONDS),
numProcessorThreads, requestChannel)
processors(i) = new Processor(i,
time,
maxRequestSize,
aggregateIdleMeter,
newMeter("NetworkProcessor-" + i + "-IdlePercent", "percent", TimeUnit.NANOSECONDS),
numProcessorThreads,
requestChannel,
quotas)
Utils.newThread("kafka-network-thread-%d-%d".format(port, i), processors(i), false).start()
}
// register the processor threads for notification of responses
requestChannel.addResponseListener((id:Int) => processors(id).wakeup())

// start accepting connections
this.acceptor = new Acceptor(host, port, processors, sendBufferSize, recvBufferSize)
this.acceptor = new Acceptor(host, port, processors, sendBufferSize, recvBufferSize, quotas)
Utils.newThread("kafka-socket-acceptor", acceptor, false).start()
acceptor.awaitStartup
info("Started")
Expand All @@ -87,7 +97,7 @@ class SocketServer(val brokerId: Int,
/**
* A base class with some helper variables and methods
*/
private[kafka] abstract class AbstractServerThread extends Runnable with Logging {
private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQuotas) extends Runnable with Logging {

protected val selector = Selector.open();
private val startupLatch = new CountDownLatch(1)
Expand Down Expand Up @@ -131,13 +141,48 @@ private[kafka] abstract class AbstractServerThread extends Runnable with Logging
*/
def wakeup() = selector.wakeup()

/**
 * Cancel the given selection key and close its underlying socket channel.
 * A null key is ignored.
 */
def close(key: SelectionKey) {
  Option(key).foreach { k =>
    // detach any buffered state before tearing the connection down
    k.attach(null)
    close(k.channel.asInstanceOf[SocketChannel])
    swallowError(k.cancel())
  }
}

/**
 * Close the given socket channel and release its per-IP connection-quota slot.
 * A null channel is ignored.
 */
def close(channel: SocketChannel) {
  Option(channel).foreach { ch =>
    debug("Closing connection from " + ch.socket.getRemoteSocketAddress())
    // give the quota slot back before the socket is actually closed
    connectionQuotas.dec(ch.socket.getInetAddress)
    swallowError(ch.socket().close())
    swallowError(ch.close())
  }
}

/**
 * Close every connection currently registered with this thread's selector.
 */
def closeAll() {
  val keys = this.selector.keys().iterator()
  while(keys.hasNext)
    close(keys.next())
}

}

/**
 * Thread that accepts and configures new connections. Only one of these is needed.
*/
private[kafka] class Acceptor(val host: String, val port: Int, private val processors: Array[Processor],
val sendBufferSize: Int, val recvBufferSize: Int) extends AbstractServerThread {
private[kafka] class Acceptor(val host: String,
val port: Int,
private val processors: Array[Processor],
val sendBufferSize: Int,
val recvBufferSize: Int,
connectionQuotas: ConnectionQuotas) extends AbstractServerThread(connectionQuotas) {
val serverChannel = openServerSocket(host, port)

/**
Expand All @@ -158,14 +203,14 @@ private[kafka] class Acceptor(val host: String, val port: Int, private val proce
key = iter.next
iter.remove()
if(key.isAcceptable)
accept(key, processors(currentProcessor))
else
throw new IllegalStateException("Unrecognized key state for acceptor thread.")
accept(key, processors(currentProcessor))
else
throw new IllegalStateException("Unrecognized key state for acceptor thread.")

// round robin to the next processor thread
currentProcessor = (currentProcessor + 1) % processors.length
// round robin to the next processor thread
currentProcessor = (currentProcessor + 1) % processors.length
} catch {
case e: Throwable => error("Error in acceptor", e)
case e: Throwable => error("Error while accepting connection", e)
}
}
}
Expand All @@ -187,6 +232,7 @@ private[kafka] class Acceptor(val host: String, val port: Int, private val proce
new InetSocketAddress(host, port)
val serverChannel = ServerSocketChannel.open()
serverChannel.configureBlocking(false)
serverChannel.socket().setReceiveBufferSize(recvBufferSize)
try {
serverChannel.socket.bind(socketAddress)
info("Awaiting socket connections on %s:%d.".format(socketAddress.getHostName, port))
Expand All @@ -202,19 +248,24 @@ private[kafka] class Acceptor(val host: String, val port: Int, private val proce
*/
def accept(key: SelectionKey, processor: Processor) {
val serverSocketChannel = key.channel().asInstanceOf[ServerSocketChannel]
serverSocketChannel.socket().setReceiveBufferSize(recvBufferSize)

val socketChannel = serverSocketChannel.accept()
socketChannel.configureBlocking(false)
socketChannel.socket().setTcpNoDelay(true)
socketChannel.socket().setSendBufferSize(sendBufferSize)
try {
connectionQuotas.inc(socketChannel.socket().getInetAddress)
socketChannel.configureBlocking(false)
socketChannel.socket().setTcpNoDelay(true)
socketChannel.socket().setSendBufferSize(sendBufferSize)

debug("Accepted connection from %s on %s. sendBufferSize [actual|requested]: [%d|%d] recvBufferSize [actual|requested]: [%d|%d]"
.format(socketChannel.socket.getInetAddress, socketChannel.socket.getLocalSocketAddress,
debug("Accepted connection from %s on %s. sendBufferSize [actual|requested]: [%d|%d] recvBufferSize [actual|requested]: [%d|%d]"
.format(socketChannel.socket.getInetAddress, socketChannel.socket.getLocalSocketAddress,
socketChannel.socket.getSendBufferSize, sendBufferSize,
socketChannel.socket.getReceiveBufferSize, recvBufferSize))

processor.accept(socketChannel)
processor.accept(socketChannel)
} catch {
case e: TooManyConnectionsException =>
info("Rejected connection from %s, address already has the configured maximum of %d connections.".format(e.ip, e.count))
close(socketChannel)
}
}

}
Expand All @@ -229,7 +280,8 @@ private[kafka] class Processor(val id: Int,
val aggregateIdleMeter: Meter,
val idleMeter: Meter,
val totalProcessorThreads: Int,
val requestChannel: RequestChannel) extends AbstractServerThread {
val requestChannel: RequestChannel,
connectionQuotas: ConnectionQuotas) extends AbstractServerThread(connectionQuotas) {

private val newConnections = new ConcurrentLinkedQueue[SocketChannel]()

Expand Down Expand Up @@ -324,26 +376,6 @@ private[kafka] class Processor(val id: Int,
}
}
}

private def close(key: SelectionKey) {
val channel = key.channel.asInstanceOf[SocketChannel]
debug("Closing connection from " + channel.socket.getRemoteSocketAddress())
swallowError(channel.socket().close())
swallowError(channel.close())
key.attach(null)
swallowError(key.cancel())
}

/*
* Close all open connections
*/
private def closeAll() {
val iter = this.selector.keys().iterator()
while (iter.hasNext) {
val key = iter.next()
close(key)
}
}

/**
* Queue up a new connection for reading
Expand Down Expand Up @@ -419,3 +451,31 @@ private[kafka] class Processor(val id: Int,
private def channelFor(key: SelectionKey) = key.channel().asInstanceOf[SocketChannel]

}

/**
 * Tracks the number of open connections per remote address and enforces a
 * configurable per-address limit.
 *
 * @param defaultMax the maximum number of connections allowed from any single
 *                   address that has no explicit override
 * @param overrideQuotas per-address limits keyed by IP or hostname; each key is
 *                       resolved via InetAddress.getByName at construction time,
 *                       so an unresolvable name fails fast here — presumably
 *                       intentional, but note it performs a DNS lookup
 */
class ConnectionQuotas(val defaultMax: Int, overrideQuotas: Map[String, Int]) {
  // resolve hostname keys once up front so runtime lookups are by InetAddress
  private val overrides = overrideQuotas.map(entry => (InetAddress.getByName(entry._1), entry._2))
  // current open-connection count per address; all access guarded by its monitor
  private val counts = mutable.Map[InetAddress, Int]()

  /**
   * Record a new connection from the given address.
   *
   * The count is incremented before the limit check, so on rejection the caller
   * is expected to close the connection (which calls dec()) to rebalance.
   *
   * @throws TooManyConnectionsException if the address is already at its limit
   */
  def inc(addr: InetAddress) {
    counts synchronized {
      val count = counts.getOrElse(addr, 0)
      counts.put(addr, count + 1)
      val max = overrides.getOrElse(addr, defaultMax)
      if(count >= max)
        throw new TooManyConnectionsException(addr, max)
    }
  }

  /**
   * Record the closing of a connection from the given address, dropping the
   * map entry entirely when the count reaches zero.
   *
   * @throws IllegalArgumentException if no connection from this address is tracked
   */
  def dec(addr: InetAddress) {
    counts synchronized {
      // was counts.get(addr).get — a bare Option.get that would fail with an
      // uninformative NoSuchElementException on an untracked address
      val count = counts.getOrElse(addr,
        throw new IllegalArgumentException("Attempted to decrement connection count for address with no connections: " + addr))
      if(count == 1)
        counts.remove(addr)
      else
        counts.put(addr, count - 1)
    }
  }

}

/** Thrown by ConnectionQuotas.inc when a new connection from `ip` would exceed its limit `count`. */
class TooManyConnectionsException(val ip: InetAddress, val count: Int) extends KafkaException("Too many connections from %s (maximum = %d)".format(ip, count))
6 changes: 6 additions & 0 deletions core/src/main/scala/kafka/server/KafkaConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ class KafkaConfig private (val props: VerifiableProperties) extends ZKConfig(pro

/* the maximum number of bytes in a socket request */
val socketRequestMaxBytes: Int = props.getIntInRange("socket.request.max.bytes", 100*1024*1024, (1, Int.MaxValue))

/* the maximum number of connections we allow from each IP address */
val maxConnectionsPerIp: Int = props.getIntInRange("max.connections.per.ip", Int.MaxValue, (1, Int.MaxValue))

/* per-IP or per-hostname overrides of the default maximum number of connections */
val maxConnectionsPerIpOverrides = props.getMap("max.connections.per.ip.overrides").map(entry => (entry._1, entry._2.toInt))

/*********** Log Configuration ***********/

Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/kafka/server/KafkaServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime) extends Logg
config.queuedMaxRequests,
config.socketSendBufferBytes,
config.socketReceiveBufferBytes,
config.socketRequestMaxBytes)
config.socketRequestMaxBytes,
config.maxConnectionsPerIp)
socketServer.startup()

replicaManager = new ReplicaManager(config, time, zkClient, kafkaScheduler, logManager, isShuttingDown)
Expand Down
20 changes: 18 additions & 2 deletions core/src/test/scala/unit/kafka/network/SocketServerTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class SocketServerTest extends JUnitSuite {
maxQueuedRequests = 50,
sendBufferSize = 300000,
recvBufferSize = 300000,
maxRequestSize = 50)
maxRequestSize = 50,
maxConnectionsPerIp = 5)
server.startup()

def sendRequest(socket: Socket, id: Short, request: Array[Byte]) {
Expand Down Expand Up @@ -75,7 +76,7 @@ class SocketServerTest extends JUnitSuite {
def cleanup() {
server.shutdown()
}

@Test
def simpleRequest() {
val socket = connect()
Expand Down Expand Up @@ -139,4 +140,19 @@ class SocketServerTest extends JUnitSuite {
// doing a subsequent send should throw an exception as the connection should be closed.
sendRequest(socket, 0, bytes)
}

@Test
def testMaxConnectionsPerIp() {
  // open exactly the allowed number of connections and hold them all open
  val conns = (0 until server.maxConnectionsPerIp).map(_ => connect())

  // one more connection should be rejected by the server: either the connect/
  // send fails with an IOException, or the server closes the socket (EOF)
  try {
    val conn = connect()
    sendRequest(conn, 100, "hello".getBytes)
    assertEquals(-1, conn.getInputStream().read())
  } catch {
    case e: IOException => // expected: the server refused the connection
  }
}
}

0 comments on commit 8e444a3

Please sign in to comment.