Skip to content

Commit

Permalink
GH-663: Fix a race in IoSession creation
Browse files Browse the repository at this point in the history
IoSessions are created asynchronously through an IoConnector. When an
IoConnector is closed, all sessions created through it are also to be
closed.

Because session creation is asynchronous, it was possible that a newly
created session registered on an already closed IoConnector, and then
would never be closed.

Prevent this by forcibly closing any newly created IoSession if the
IoConnector is already closed when the new session tries to register
with the connector.

Add a test that verifies that the connect future the client code sees
does not provide a session but is fulfilled either by an exception or
by having been cancelled.
  • Loading branch information
tomaswolf committed Mar 1, 2025
1 parent c497904 commit 0fed512
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
## Bug Fixes

* [GH-650](https://github.com/apache/mina-sshd/issues/650) Use the correct key from a user certificate in server-side pubkey auth
* [GH-663](https://github.com/apache/mina-sshd/issues/663) Fix racy `IoSession` creation
* [GH-664](https://github.com/apache/mina-sshd/issues/664) Skip MAC negotiation if an AEAD cipher was negotiated

## New Features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,10 @@ protected void onCompleted(Void result, Object attachment) {

handler.sessionCreated(session);
sessionId = session.getId();
sessions.put(sessionId, session);
future.setSession(session);
IoSession registered = mapSession(session);
if (registered == session) {
future.setSession(session);
}
if (session != future.getSession()) {
session.close(true);
throw new CancellationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public abstract class Nio2Service extends AbstractInnerCloseable implements IoSe
private final AsynchronousChannelGroup group;
private final ExecutorService executor;
private IoServiceEventListener eventListener;
private boolean noMoreSessions;

protected Nio2Service(PropertyResolver propertyResolver, IoHandler handler, AsynchronousChannelGroup group,
ExecutorService resumeTasks) {
Expand Down Expand Up @@ -127,7 +128,7 @@ public void dispose() {
@Override
protected Closeable getInnerCloseable() {
return builder()
.parallel(toString(), sessions.values())
.parallel(toString(), snapshot())
.build();
}

Expand All @@ -140,6 +141,23 @@ public void sessionClosed(Nio2Session session) {
unmapSession(session.getId());
}

private Collection<IoSession> snapshot() {
synchronized (this) {
noMoreSessions = true;
}
return sessions.values();
}

protected IoSession mapSession(IoSession session) {
synchronized (this) {
if (noMoreSessions) {
return null;
}
sessions.put(session.getId(), session);
return session;
}
}

protected void unmapSession(Long sessionId) {
if (sessionId != null) {
IoSession ioSession = sessions.remove(sessionId);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd.common.io;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.common.future.SshFutureListener;
import org.apache.sshd.common.util.Readable;
import org.apache.sshd.util.test.BaseTestSupport;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Tests for low-level connections.
*/
class IoConnectionTest extends BaseTestSupport {

private static final Logger LOG = LoggerFactory.getLogger(IoConnectionTest.class);

@Test
void connectorRace() throws Exception {
CountDownLatch connectionMade = new CountDownLatch(1);
CountDownLatch connectorClosing = new CountDownLatch(1);
CountDownLatch futureTriggered = new CountDownLatch(1);
CountDownLatch ioSessionClosed = new CountDownLatch(1);
AtomicReference<IoSession> session = new AtomicReference<>();
AtomicBoolean connectorIsClosing = new AtomicBoolean();
AtomicBoolean sessionWaited = new AtomicBoolean();

SshClient client = setupTestClient();
IoServiceFactory serviceFactory = DefaultIoServiceFactoryFactory.getDefaultIoServiceFactoryFactoryInstance()
.create(client);
IoConnector connector = serviceFactory.createConnector(new IoHandler() {

@Override
public void sessionCreated(org.apache.sshd.common.io.IoSession session) throws Exception {
connectionMade.countDown();
sessionWaited.set(connectorClosing.await(5, TimeUnit.SECONDS));
}

@Override
public void sessionClosed(org.apache.sshd.common.io.IoSession session) throws Exception {
ioSessionClosed.countDown();
}

@Override
public void exceptionCaught(org.apache.sshd.common.io.IoSession session, Throwable cause) throws Exception {
// Nothing
}

@Override
public void messageReceived(org.apache.sshd.common.io.IoSession session, Readable message) throws Exception {
// Nothing; we're not actually sending or receiving data.
}
});
NioSocketAcceptor acceptor = startEchoServer();
try {
InetSocketAddress connectAddress = new InetSocketAddress(InetAddress.getByName(TEST_LOCALHOST),
acceptor.getLocalAddress().getPort());
IoConnectFuture future = connector.connect(connectAddress, null, null);
connectionMade.await(5, TimeUnit.SECONDS);
connector.close();
connectorClosing.countDown();
future.addListener(new SshFutureListener<IoConnectFuture>() {

@Override
public void operationComplete(IoConnectFuture future) {
session.set(future.getSession());
connectorIsClosing.set(!connector.isOpen());
futureTriggered.countDown();
}
});
assertTrue(futureTriggered.await(5, TimeUnit.SECONDS));
Throwable error = future.getException();
if (error != null) {
LOG.info("{}: Connect future was terminated exceptionally: {} ", getCurrentTestName(), error);
error.printStackTrace();
} else if (future.isCanceled()) {
LOG.info("{}: Connect future was canceled", getCurrentTestName());
}
assertEquals(0, connectionMade.getCount());
assertTrue(sessionWaited.get());
assertNull(session.get());
assertTrue(connectorIsClosing.get());
// Since sessionCreated() was called we also expect sessionClosed() to get called eventually.
assertTrue(ioSessionClosed.await(5, TimeUnit.SECONDS));
} finally {
acceptor.dispose(false);
}
}

private NioSocketAcceptor startEchoServer() throws IOException {
NioSocketAcceptor acceptor = new NioSocketAcceptor();
acceptor.setHandler(new IoHandlerAdapter() {

@Override
public void messageReceived(org.apache.mina.core.session.IoSession session, Object message) throws Exception {
IoBuffer recv = (IoBuffer) message;
IoBuffer sent = IoBuffer.allocate(recv.remaining());
sent.put(recv);
sent.flip();
session.write(sent);
}
});
acceptor.setReuseAddress(true);
acceptor.bind(new InetSocketAddress(0));
return acceptor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public void setSession(org.apache.sshd.common.io.IoSession session) {
Throwable t = cf.getException();
if (t != null) {
future.setException(t);
} else if (cf.isCanceled()) {
} else if (cf.isCanceled() || !isOpen()) {
IoSession ioSession = createdSession.getAndSet(null);
CancelFuture cancellation = future.cancel();
if (ioSession != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,10 @@
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.apache.sshd.common.future.CloseFuture;
import org.apache.sshd.common.io.IoAcceptor;
import org.apache.sshd.common.io.IoHandler;
Expand All @@ -64,10 +62,9 @@ public class NettyIoAcceptor extends NettyIoService implements IoAcceptor {
protected final Map<SocketAddress, Channel> boundAddresses = new ConcurrentHashMap<>();

public NettyIoAcceptor(NettyIoServiceFactory factory, IoHandler handler) {
super(factory, handler);
super(factory, handler, "sshd-acceptor-channels");

Boolean reuseaddr = CoreModuleProperties.SOCKET_REUSEADDR.getRequired(factory.manager);
channelGroup = new DefaultChannelGroup("sshd-acceptor-channels", GlobalEventExecutor.INSTANCE);
bootstrap.group(factory.eventLoopGroup)
.channel(NioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG, CoreModuleProperties.SOCKET_BACKLOG.getRequired(factory.manager))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.apache.sshd.common.AttributeRepository;
import org.apache.sshd.common.future.CancelFuture;
import org.apache.sshd.common.io.DefaultIoConnectFuture;
Expand All @@ -51,8 +49,7 @@ public class NettyIoConnector extends NettyIoService implements IoConnector {
private static final LoggingHandler LOGGING_TRACE = new LoggingHandler(NettyIoConnector.class, LogLevel.TRACE);

public NettyIoConnector(NettyIoServiceFactory factory, IoHandler handler) {
super(factory, handler);
channelGroup = new DefaultChannelGroup("sshd-connector-channels", GlobalEventExecutor.INSTANCE);
super(factory, handler, "sshd-connector-channels");
}

@Override
Expand Down
41 changes: 38 additions & 3 deletions sshd-netty/src/main/java/org/apache/sshd/netty/NettyIoService.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@

package org.apache.sshd.netty;

import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import io.netty.channel.Channel;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.apache.sshd.common.io.IoConnectFuture;
import org.apache.sshd.common.io.IoHandler;
import org.apache.sshd.common.io.IoService;
Expand All @@ -44,16 +49,46 @@ public abstract class NettyIoService extends AbstractCloseable implements IoServ

protected final AtomicLong sessionSeq = new AtomicLong();
protected final Map<Long, IoSession> sessions = new ConcurrentHashMap<>();
protected ChannelGroup channelGroup;
protected final ChannelGroup channelGroup;
protected final NettyIoServiceFactory factory;
protected final IoHandler handler;
private boolean noMoreSessions;

private IoServiceEventListener eventListener;

protected NettyIoService(NettyIoServiceFactory factory, IoHandler handler) {
protected NettyIoService(NettyIoServiceFactory factory, IoHandler handler, String channelGroupName) {
this.factory = Objects.requireNonNull(factory, "No factory instance provided");
this.handler = Objects.requireNonNull(handler, "No I/O handler provied");
this.eventListener = factory.getIoServiceEventListener();
this.channelGroup = new DefaultChannelGroup(Objects.requireNonNull(channelGroupName, "No channel group name"),
GlobalEventExecutor.INSTANCE);
}

@Override
protected void doCloseImmediately() {
synchronized (this) {
noMoreSessions = true;
}
channelGroup.close();
super.doCloseImmediately();
}

protected void registerChannel(Channel channel) throws CancellationException {
synchronized (this) {
if (noMoreSessions) {
throw new CancellationException("NettyIoService closed");
}
channelGroup.add(channel);
}
}

protected void mapSession(IoSession session) throws CancellationException {
synchronized (this) {
if (noMoreSessions) {
throw new CancellationException("NettyIoService closed; cannot register new session");
}
sessions.put(session.getId(), session);
}
}

@Override
Expand All @@ -68,6 +103,6 @@ public void setIoServiceEventListener(IoServiceEventListener listener) {

@Override
public Map<Long, IoSession> getManagedSessions() {
return sessions;
return Collections.unmodifiableMap(sessions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,6 @@ protected void doCloseImmediately() {
protected void channelActive(ChannelHandlerContext ctx) throws Exception {
context = ctx;
Channel channel = ctx.channel();
service.channelGroup.add(channel);
service.sessions.put(id, NettyIoSession.this);
prev = context.newPromise().setSuccess();
remoteAddr = channel.remoteAddress();
// If handler.sessionCreated() propagates an exception, we'll have a NettyIoSession without SSH session. We'll
Expand All @@ -254,15 +252,17 @@ protected void channelActive(ChannelHandlerContext ctx) throws Exception {
Attribute<IoConnectFuture> connectFuture = channel.attr(NettyIoService.CONNECT_FUTURE_KEY);
IoConnectFuture future = connectFuture.get();
try {
service.registerChannel(channel);
handler.sessionCreated(NettyIoSession.this);
service.mapSession(this);
if (future != null) {
future.setSession(NettyIoSession.this);
if (future.getSession() != NettyIoSession.this) {
close(true);
}
}
} catch (Throwable e) {
log.warn("channelActive(session={}): could not create SSH session ({}); closing", this, e.getClass().getName(), e);
warn("channelActive(session={}): could not create SSH session ({}); closing", this, e.getClass().getName(), e);
try {
if (future != null) {
future.setException(e);
Expand Down

0 comments on commit 0fed512

Please sign in to comment.