Skip to content

Commit

Permalink
Replace messages generator with iterator class that implements len()
Browse files Browse the repository at this point in the history
  • Loading branch information
empicano committed Aug 1, 2024
1 parent 650699a commit becacae
Showing 1 changed file with 38 additions and 32 deletions.
70 changes: 38 additions & 32 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from types import TracebackType
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Expand Down Expand Up @@ -125,7 +125,7 @@ class Will:


class Client:
"""The async context manager that manages the connection to the broker.
"""Asynchronous context manager for the connection to the MQTT broker.
Args:
hostname: The hostname or IP address of the remote broker.
Expand Down Expand Up @@ -320,10 +320,6 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
timeout = 10
self.timeout = timeout

@property
def messages(self) -> AsyncGenerator[Message, None]:
return self._messages()

@property
def identifier(self) -> str:
"""Return the client identifier.
Expand All @@ -333,6 +329,42 @@ def identifier(self) -> str:
"""
return self._client._client_id.decode() # noqa: SLF001

class MessagesIterator:
"""Dynamic view of the message queue."""

def __init__(self, client: Client) -> None:
self._client = client

def __aiter__(self) -> AsyncIterator[Message]:
return self

async def __anext__(self) -> Message:
# Wait until we either (1) receive a message or (2) disconnect
task = self._client._loop.create_task(self._client._queue.get()) # noqa: SLF001
try:
done, _ = await asyncio.wait(
(task, self._client._disconnected), # noqa: SLF001
return_when=asyncio.FIRST_COMPLETED,
)
# If the asyncio.wait is cancelled, we must also cancel the queue task
except asyncio.CancelledError:
task.cancel()
raise
# When we receive a message, return it
if task in done:
return task.result()
# If we disconnect from the broker, stop the generator with an exception
task.cancel()
msg = "Disconnected during message iteration"
raise MqttError(msg)

def __len__(self) -> int:
return self._client._queue.qsize() # noqa: SLF001

@property
def messages(self) -> MessagesIterator:
return self.MessagesIterator(self)

@property
def _pending_calls(self) -> Generator[int, None, None]:
"""Yield all message IDs with pending calls."""
Expand Down Expand Up @@ -456,32 +488,6 @@ async def publish( # noqa: PLR0913
# Wait for confirmation
await self._wait_for(confirmation.wait(), timeout=timeout)

async def _messages(self) -> AsyncGenerator[Message, None]:
"""Async generator that yields messages from the underlying message queue."""
while True:
# Wait until we either:
# 1. Receive a message
# 2. Disconnect from the broker
task = self._loop.create_task(self._queue.get())
try:
done, _ = await asyncio.wait(
(task, self._disconnected), return_when=asyncio.FIRST_COMPLETED
)
except asyncio.CancelledError:
# If the asyncio.wait is cancelled, we must make sure
# to also cancel the underlying tasks.
task.cancel()
raise
if task in done:
# We received a message. Return the result.
yield task.result()
else:
# We were disconnected from the broker
task.cancel()
# Stop the generator with an exception
msg = "Disconnected during message iteration"
raise MqttError(msg)

async def _wait_for(
self, fut: Awaitable[T], timeout: float | None, **kwargs: Any
) -> T:
Expand Down

0 comments on commit becacae

Please sign in to comment.