Skip to content

Commit

Permalink
Close AMQP connection explicitly when no more links (Azure#4914)
Browse files Browse the repository at this point in the history
This fixes the following problem:

In certain situations EdgeHub can decide to close the connection of a device (or module). Such situations are when a SAS token expires or the device gets disabled on the azure portal, for example.
The current solution to close the device when it uses Amqp is to close the links in the Amqp session. The problem is that the underlying transport channel remains open in this case and the device is able to reopen the channel.
When EdgeHub closes a client that way, it also sets the device objects representing it (the related instances of classes like DeviceProxy) as closed. After that moment, for certain operations, like sending M2M messages, EdgeHub checks if a module has an active DeviceProxy, and if it does not, it does not send out the message. Interestingly, other operations (like transmitting twin results) do not do this check and the operation goes through, as the underlying link can be reopened.
Also, an operation initiated by the module (or device), e.g. sending a telemetry message, can go through even if the device proxy is closed.

Because of this bug, the following can happen:
1) The module works properly for a while
2) The token is about to expire, so the module sends a new one. Let's say that the processing of the token has problems (in a real-life scenario, talking to edged failed for some reason), so EdgeHub chooses to close the connection.
3) EdgeHub closes all the Amqp links, and sets the device proxy as closed.
4) The module wants to send a telemetry message; it still has the underlying transport, so it opens a new link for telemetry and sends the message.
5) EdgeHub processes the message and everything seems to be working.
6) An incoming message comes from somewhere which is routed to the module we are talking about. 
7) EdgeHub finds the module and sees that the device proxy is closed, so it does not send the message.

The message does not get lost, but it will not be sent out until the module gets restarted. Also, if there is a time-to-live value (or some other limit), the message will be dropped after a while.

This fix changes the logic from closing the links to closing the entire connection. As a result, the client (device or module) cannot reopen links but needs to create a new connection, which results in creating a new DeviceProxy with an appropriate state (=opened), so M2M messages (and other device-bound operations checking the state) can go through.
  • Loading branch information
vipeller authored May 10, 2021
1 parent 6c4269a commit 6c8134e
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void OnConnectionOpening(object sender, OpenEventArgs e)
amqpConnection.Extensions.Add(cbsNode);
}

IClientConnectionsHandler connectionHandler = new ClientConnectionsHandler(this.connectionProvider);
IClientConnectionsHandler connectionHandler = new ClientConnectionsHandler(this.connectionProvider, amqpConnection);
amqpConnection.Extensions.Add(connectionHandler);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
using System.Linq;
using System.Threading.Tasks;
using System.Web;
using Microsoft.Azure.Amqp;
using Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Device;
Expand All @@ -21,18 +22,21 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
/// </summary>
class ClientConnectionHandler : IConnectionHandler
{
readonly TimeSpan closeTimeout = TimeSpan.FromSeconds(60);
readonly IDictionary<LinkType, ILinkHandler> registry = new Dictionary<LinkType, ILinkHandler>();
readonly IIdentity identity;
readonly AmqpConnectionBase amqpConnection;

readonly AsyncLock initializationLock = new AsyncLock();
readonly AsyncLock registryUpdateLock = new AsyncLock();
readonly IConnectionProvider connectionProvider;
Option<IDeviceListener> deviceListener = Option.None<IDeviceListener>();

public ClientConnectionHandler(IIdentity identity, IConnectionProvider connectionProvider)
public ClientConnectionHandler(IIdentity identity, IConnectionProvider connectionProvider, AmqpConnectionBase amqpConnection)
{
this.identity = Preconditions.CheckNotNull(identity, nameof(identity));
this.connectionProvider = Preconditions.CheckNotNull(connectionProvider, nameof(connectionProvider));
this.amqpConnection = Preconditions.CheckNotNull(amqpConnection, nameof(amqpConnection));
}

public Task<IDeviceListener> GetDeviceListener()
Expand Down Expand Up @@ -129,19 +133,13 @@ public async Task RemoveLinkHandler(ILinkHandler linkHandler)
}
}

/// <summary>
/// Starts closing every link handler currently held in the registry and
/// returns a task that completes once all of those close operations finish.
/// </summary>
Task CloseAllLinks()
{
    // Take a snapshot of the registered handlers first, then issue one close per handler
    // using the shared default timeout.
    return Task.WhenAll(
        this.registry.Values
            .ToList()
            .Select(handler => handler.CloseAsync(Constants.DefaultTimeout)));
}

/// <summary>
/// Closes the bound device listener (if any), clears it, and then closes the
/// AMQP connection this handler was constructed with.
/// </summary>
/// <returns>A task that completes when the listener and the connection close calls have finished.</returns>
async Task CloseConnection()
{
// All three steps run while holding the initialization lock.
// NOTE(review): the lock stays held while awaiting the connection close (up to closeTimeout);
// confirm that closing the connection cannot re-enter code that needs this same lock.
using (await this.initializationLock.LockAsync())
{
// Close the current device listener, if one is set.
await this.deviceListener.ForEachAsync(d => d.CloseAsync());
// Drop the reference so the handler no longer holds the closed listener.
this.deviceListener = Option.None<IDeviceListener>();
// Close the whole AMQP connection (not just its links), bounded by closeTimeout.
await this.amqpConnection.CloseAsync(this.closeTimeout);
}
}

Expand All @@ -167,7 +165,7 @@ public Task CloseAsync(Exception ex)
if (this.isActive.GetAndSet(false))
{
Events.ClosingProxy(this.Identity, ex);
return this.clientConnectionHandler.CloseAllLinks();
return this.clientConnectionHandler.CloseConnection();
}

return Task.CompletedTask;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using System.Collections.Concurrent;
using Microsoft.Azure.Amqp;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;
Expand All @@ -10,13 +11,15 @@ class ClientConnectionsHandler : IClientConnectionsHandler
{
readonly ConcurrentDictionary<string, ClientConnectionHandler> connectionHandlers = new ConcurrentDictionary<string, ClientConnectionHandler>();
readonly IConnectionProvider connectionProvider;
readonly AmqpConnection amqpConnection;

public ClientConnectionsHandler(IConnectionProvider connectionProvider)
public ClientConnectionsHandler(IConnectionProvider connectionProvider, AmqpConnection amqpConnection)
{
this.connectionProvider = Preconditions.CheckNotNull(connectionProvider, nameof(connectionProvider));
this.amqpConnection = Preconditions.CheckNotNull(amqpConnection, nameof(amqpConnection));
}

public IConnectionHandler GetConnectionHandler(IIdentity identity) =>
this.connectionHandlers.GetOrAdd(identity.Id, i => new ClientConnectionHandler(identity, this.connectionProvider));
this.connectionHandlers.GetOrAdd(identity.Id, i => new ClientConnectionHandler(identity, this.connectionProvider, this.amqpConnection));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Azure.Amqp;
using Microsoft.Azure.Amqp.Framing;
using Microsoft.Azure.Amqp.Transport;
using Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Device;
Expand All @@ -23,12 +26,14 @@ public void ConnectionHandlerCtorTest()
{
// Arrange
var identity = Mock.Of<IIdentity>();
var connectionPovider = Mock.Of<IConnectionProvider>();
var connectionProvider = Mock.Of<IConnectionProvider>();
var amqpConnection = new AmqpTestConnection();

// Act / Assert
Assert.NotNull(new ClientConnectionHandler(identity, connectionPovider));
Assert.Throws<ArgumentNullException>(() => new ClientConnectionHandler(null, connectionPovider));
Assert.Throws<ArgumentNullException>(() => new ClientConnectionHandler(identity, null));
Assert.NotNull(new ClientConnectionHandler(identity, connectionProvider, amqpConnection));
Assert.Throws<ArgumentNullException>(() => new ClientConnectionHandler(null, connectionProvider, amqpConnection));
Assert.Throws<ArgumentNullException>(() => new ClientConnectionHandler(identity, null, amqpConnection));
Assert.Throws<ArgumentNullException>(() => new ClientConnectionHandler(identity, connectionProvider, null));
}

[Fact]
Expand All @@ -42,7 +47,9 @@ public async Task GetDeviceListenerTest()
.Callback<IDeviceProxy>(d => deviceProxy = d);

var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener));
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

// Act
var tasks = new List<Task<IDeviceListener>>();
Expand Down Expand Up @@ -77,8 +84,9 @@ public async Task RegisterC2DMessageSenderTest()
.Callback<IDeviceProxy>(d => deviceProxy = d);

var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener));
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

IMessage receivedMessage = null;
var c2DLinkHandler = new Mock<ISendingLinkHandler>();
Expand Down Expand Up @@ -113,8 +121,9 @@ public async Task RegisterModuleMessageSenderTest()
.Callback<IDeviceProxy>(d => deviceProxy = d);

var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener));
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

IMessage receivedMessage = null;
var moduleMessageLinkHandler = new Mock<ISendingLinkHandler>();
Expand Down Expand Up @@ -149,8 +158,9 @@ public async Task RegisterMethodInvokerTest()
.Callback<IDeviceProxy>(d => deviceProxy = d);

var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener));
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

IMessage receivedMessage = null;
var methodSendingLinkHandler = new Mock<ISendingLinkHandler>();
Expand Down Expand Up @@ -185,8 +195,9 @@ public async Task RegisterDesiredPropertiesUpdateSenderTest()
.Callback<IDeviceProxy>(d => deviceProxy = d);

var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener));
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

IMessage receivedMessage = null;
var twinSendingLinkHandler = new Mock<ISendingLinkHandler>();
Expand Down Expand Up @@ -214,11 +225,13 @@ public async Task CloseOnRemovingAllLinksTest()
// Arrange
var deviceListener = new Mock<IDeviceListener>();
deviceListener.Setup(d => d.CloseAsync()).Returns(Task.CompletedTask);
deviceListener.Setup(d => d.BindDeviceProxy(It.IsAny<IDeviceProxy>()));

var identity = Mock.Of<IIdentity>(i => i.Id == "d1/m1");
var connectionProvider = Mock.Of<IConnectionProvider>(c => c.GetDeviceListenerAsync(identity, Option.None<string>()) == Task.FromResult(deviceListener.Object));
deviceListener.Setup(d => d.BindDeviceProxy(It.IsAny<IDeviceProxy>()));
var amqpConnection = new AmqpTestConnection();

var connectionHandler = new ClientConnectionHandler(identity, connectionProvider);
var connectionHandler = new ClientConnectionHandler(identity, connectionProvider, amqpConnection);

var eventsLinkHandler = Mock.Of<ILinkHandler>(l => l.Type == LinkType.Events);
string twinCorrelationId = Guid.NewGuid().ToString();
Expand Down Expand Up @@ -249,6 +262,7 @@ public async Task CloseOnRemovingAllLinksTest()

// Assert
deviceListener.Verify(d => d.CloseAsync(), Times.Once);
Assert.True(amqpConnection.CloseCalled);

// Act
await connectionHandler.GetDeviceListener();
Expand All @@ -257,4 +271,61 @@ public async Task CloseOnRemovingAllLinksTest()
deviceListener.Verify(d => d.BindDeviceProxy(It.IsAny<IDeviceProxy>()), Times.Exactly(2));
}
}

/// <summary>
/// Test double derived from <see cref="AmqpConnectionBase"/>. Every overridden hook is a
/// no-op except <see cref="CloseInternal"/>, which records that it was invoked so tests
/// can assert on <see cref="CloseCalled"/>.
/// </summary>
class AmqpTestConnection : AmqpConnectionBase
{
    public AmqpTestConnection()
        : base("test", new TestTransport(), new AmqpConnectionSettings(), false)
    {
    }

    /// <summary>
    /// Gets a value indicating whether <see cref="CloseInternal"/> has run on this instance.
    /// </summary>
    public bool CloseCalled { get; private set; }

    // NOTE(review): returning true presumably signals that the open completed synchronously;
    // confirm against the base class contract.
    protected override bool OpenInternal()
    {
        return true;
    }

    protected override bool CloseInternal()
    {
        // Remember the close request; tests read this flag afterwards.
        this.CloseCalled = true;
        return true;
    }

    protected override void AbortInternal()
    {
        // Intentionally empty: nothing to abort in the test double.
    }

    protected override void OnProtocolHeader(ProtocolHeader header)
    {
        // Intentionally empty: incoming protocol headers are ignored.
    }

    protected override void OnFrameBuffer(ByteBuffer buffer)
    {
        // Intentionally empty: incoming frames are ignored.
    }
}

/// <summary>
/// Minimal <see cref="TransportBase"/> stub for unit tests. It reports fixed endpoint names,
/// performs no real I/O, and records whether <see cref="CloseInternal"/> was invoked.
/// </summary>
class TestTransport : TransportBase
{
    public TestTransport()
        : base("test")
    {
    }

    public override string LocalEndPoint => "localhost";

    public override string RemoteEndPoint => "remotehost";

    /// <summary>
    /// Gets a value indicating whether <see cref="CloseInternal"/> has run on this instance.
    /// </summary>
    public bool CloseCalled { get; private set; }

    // No data is ever read; the stub just returns a fixed value.
    public override bool ReadAsync(TransportAsyncCallbackArgs args) => true;

    public override void SetMonitor(ITransportMonitor usageMeter)
    {
        // Intentionally empty: usage monitoring is not needed in tests.
    }

    // No data is ever written; the stub just returns a fixed value.
    public override bool WriteAsync(TransportAsyncCallbackArgs args) => false;

    protected override void AbortInternal()
    {
        // Intentionally empty: nothing to abort in the stub.
    }

    protected override bool CloseInternal()
    {
        // Previously CloseCalled was declared but never assigned, so it always read false.
        // Record the call so the property reflects what actually happened.
        this.CloseCalled = true;
        return true;
    }
}
}

0 comments on commit 6c8134e

Please sign in to comment.