Skip to content

Commit

Permalink
Change assist satellite announce method signature (home-assistant#126299
Browse files Browse the repository at this point in the history
)
  • Loading branch information
balloob authored Sep 20, 2024
1 parent 41ffa8d commit 604c848
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 17 deletions.
2 changes: 2 additions & 0 deletions homeassistant/components/assist_satellite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature
from .entity import (
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
Expand All @@ -22,6 +23,7 @@

__all__ = [
"DOMAIN",
"AssistSatelliteAnnouncement",
"AssistSatelliteEntity",
"AssistSatelliteConfiguration",
"AssistSatelliteEntityDescription",
Expand Down
29 changes: 26 additions & 3 deletions homeassistant/components/assist_satellite/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enum import StrEnum
import logging
import time
from typing import Any, Final, final
from typing import Any, Final, Literal, final

from homeassistant.components import media_source, stt, tts
from homeassistant.components.assist_pipeline import (
Expand Down Expand Up @@ -86,6 +86,19 @@ class AssistSatelliteConfiguration:
"""Maximum number of simultaneous wake words allowed (0 for no limit)."""


@dataclass
class AssistSatelliteAnnouncement:
"""Announcement to be made."""

message: str
"""Message to be spoken."""

media_id: str
"""Media ID to be played."""

media_id_source: Literal["url", "media_id", "tts"]


class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite."""

Expand Down Expand Up @@ -174,10 +187,13 @@ async def async_internal_announce(
"""
await self._cancel_running_pipeline()

media_id_source: Literal["url", "media_id", "tts"] | None = None

if message is None:
message = ""

if not media_id:
media_id_source = "tts"
# Synthesize audio and get URL
pipeline_id = self._resolve_pipeline()
pipeline = async_get_pipeline(self.hass, pipeline_id)
Expand All @@ -198,13 +214,18 @@ async def async_internal_announce(
)

if media_source.is_media_source_id(media_id):
if not media_id_source:
media_id_source = "media_id"
media = await media_source.async_resolve_media(
self.hass,
media_id,
None,
)
media_id = media.url

if not media_id_source:
media_id_source = "url"

# Resolve to full URL
media_id = async_process_play_media_url(self.hass, media_id)

Expand All @@ -216,12 +237,14 @@ async def async_internal_announce(

try:
# Block until announcement is finished
await self.async_announce(message, media_id)
await self.async_announce(
AssistSatelliteAnnouncement(message, media_id, media_id_source)
)
finally:
self._is_announcing = False
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)

async def async_announce(self, message: str, media_id: str) -> None:
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on the satellite.
Should block until the announcement is done playing.
Expand Down
10 changes: 6 additions & 4 deletions homeassistant/components/esphome/assist_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,18 +313,20 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:

self.cli.send_voice_assistant_event(event_type, data_to_send)

async def async_announce(self, message: str, media_id: str) -> None:
async def async_announce(
self, announcement: assist_satellite.AssistSatelliteAnnouncement
) -> None:
"""Announce media on the satellite.
Should block until the announcement is done playing.
"""
_LOGGER.debug(
"Waiting for announcement to finished (message=%s, media_id=%s)",
message,
media_id,
announcement.message,
announcement.media_id,
)
await self.cli.send_voice_assistant_announcement_await_response(
media_id, _ANNOUNCEMENT_TIMEOUT_SEC, message
announcement.media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message
)

async def handle_pipeline_start(
Expand Down
5 changes: 3 additions & 2 deletions tests/components/assist_satellite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityFeature,
Expand Down Expand Up @@ -63,9 +64,9 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
self.events.append(event)

async def async_announce(self, message: str, media_id: str) -> None:
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on a device."""
self.announcements.append((message, media_id))
self.announcements.append(announcement)

@callback
def async_get_configuration(self) -> AssistSatelliteConfiguration:
Expand Down
23 changes: 15 additions & 8 deletions tests/components/assist_satellite/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
async_update_pipeline,
vad,
)
from homeassistant.components.assist_satellite import SatelliteBusyError
from homeassistant.components.assist_satellite import (
AssistSatelliteAnnouncement,
SatelliteBusyError,
)
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
from homeassistant.components.media_source import PlayMedia
from homeassistant.config_entries import ConfigEntry
Expand Down Expand Up @@ -159,18 +162,22 @@ async def async_pipeline_from_audio_stream(*args, **kwargs):
[
(
{"message": "Hello"},
("Hello", "https://www.home-assistant.io/resolved.mp3"),
AssistSatelliteAnnouncement(
"Hello", "https://www.home-assistant.io/resolved.mp3", "tts"
),
),
(
{
"message": "Hello",
"media_id": "http://example.com/bla.mp3",
"media_id": "media-source://bla",
},
("Hello", "http://example.com/bla.mp3"),
AssistSatelliteAnnouncement(
"Hello", "https://www.home-assistant.io/resolved.mp3", "media_id"
),
),
(
{"media_id": "http://example.com/bla.mp3"},
("", "http://example.com/bla.mp3"),
AssistSatelliteAnnouncement("", "http://example.com/bla.mp3", "url"),
),
],
)
Expand All @@ -195,10 +202,10 @@ async def test_announce(
original_announce = entity.async_announce
announce_started = asyncio.Event()

async def async_announce(message, media_id):
async def async_announce(announcement):
# Verify state change
assert entity.state == AssistSatelliteState.RESPONDING
await original_announce(message, media_id)
await original_announce(announcement)
announce_started.set()

def tts_generate_media_source_id(
Expand Down Expand Up @@ -249,7 +256,7 @@ async def test_announce_busy(
announce_started = asyncio.Event()
got_error = asyncio.Event()

async def async_announce(message, media_id):
async def async_announce(announcement):
announce_started.set()

# Block so we can do another announcement
Expand Down

0 comments on commit 604c848

Please sign in to comment.