Skip to content

Commit

Permalink
Stop Speedtest sensors update on startup if manual option is enabled (h…
Browse files Browse the repository at this point in the history
…ome-assistant#37403)

Co-authored-by: Paulus Schoutsen <[email protected]>
  • Loading branch information
engrbm87 and balloob committed Jul 6, 2020
1 parent 8916409 commit ddb049e
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 24 deletions.
28 changes: 20 additions & 8 deletions homeassistant/components/speedtestdotnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ async def async_setup_entry(hass, config_entry):
coordinator = SpeedTestDataCoordinator(hass, config_entry)
await coordinator.async_setup()

await coordinator.async_refresh()
if not coordinator.last_update_success:
raise ConfigEntryNotReady
if not config_entry.options[CONF_MANUAL]:
await coordinator.async_refresh()
if not coordinator.last_update_success:
raise ConfigEntryNotReady

hass.data[DOMAIN] = coordinator

Expand Down Expand Up @@ -115,24 +116,33 @@ def __init__(self, hass, config_entry):
),
)

def update_data(self):
"""Get the latest data from speedtest.net."""
server_list = self.api.get_servers()
def update_servers(self):
"""Update list of test servers."""
try:
server_list = self.api.get_servers()
except speedtest.ConfigRetrievalError:
return

self.servers[DEFAULT_SERVER] = {}
for server in sorted(
server_list.values(), key=lambda server: server[0]["country"]
):
self.servers[f"{server[0]['country']} - {server[0]['sponsor']}"] = server[0]

def update_data(self):
"""Get the latest data from speedtest.net."""
self.update_servers()

self.api.closest.clear()
if self.config_entry.options.get(CONF_SERVER_ID):
server_id = self.config_entry.options.get(CONF_SERVER_ID)
self.api.closest.clear()
self.api.get_servers(servers=[server_id])

self.api.get_best_server()
_LOGGER.debug(
"Executing speedtest.net speed test with server_id: %s", self.api.best["id"]
)
self.api.get_best_server()

self.api.download()
self.api.upload()
return self.api.results.dict()
Expand Down Expand Up @@ -170,6 +180,8 @@ async def request_update(call):

await self.async_set_options()

await self.hass.async_add_executor_job(self.update_servers)

self.hass.services.async_register(DOMAIN, SPEED_TEST_SERVICE, request_update)

self.config_entry.add_update_listener(options_updated_listener)
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/speedtestdotnet/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def async_step_init(self, user_input=None):

self._servers = self.hass.data[DOMAIN].servers

server_name = DEFAULT_SERVER
server = []
if self.config_entry.options.get(
CONF_SERVER_ID
) and not self.config_entry.options.get(CONF_SERVER_NAME):
Expand All @@ -94,7 +94,7 @@ async def async_step_init(self, user_input=None):
for (key, value) in self._servers.items()
if value.get("id") == self.config_entry.options[CONF_SERVER_ID]
]
server_name = server[0] if server else ""
server_name = server[0] if server else DEFAULT_SERVER

options = {
vol.Optional(
Expand Down
46 changes: 32 additions & 14 deletions homeassistant/components/speedtestdotnet/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import logging

from homeassistant.const import ATTR_ATTRIBUTION
from homeassistant.helpers.entity import Entity
from homeassistant.core import callback
from homeassistant.helpers.restore_state import RestoreEntity

from .const import (
ATTR_BYTES_RECEIVED,
Expand All @@ -11,6 +12,7 @@
ATTR_SERVER_ID,
ATTR_SERVER_NAME,
ATTRIBUTION,
CONF_MANUAL,
DEFAULT_NAME,
DOMAIN,
ICON,
Expand All @@ -32,7 +34,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
async_add_entities(entities)


class SpeedtestSensor(Entity):
class SpeedtestSensor(RestoreEntity):
"""Implementation of a speedtest.net sensor."""

def __init__(self, coordinator, sensor_type):
Expand All @@ -41,6 +43,7 @@ def __init__(self, coordinator, sensor_type):
self.coordinator = coordinator
self.type = sensor_type
self._unit_of_measurement = SENSOR_TYPES[self.type][1]
self._state = None

@property
def name(self):
Expand All @@ -55,14 +58,7 @@ def unique_id(self):
@property
def state(self):
"""Return the state of the device."""
state = None
if self.type == "ping":
state = self.coordinator.data["ping"]
elif self.type == "download":
state = round(self.coordinator.data["download"] / 10 ** 6, 2)
elif self.type == "upload":
state = round(self.coordinator.data["upload"] / 10 ** 6, 2)
return state
return self._state

@property
def unit_of_measurement(self):
Expand All @@ -82,6 +78,8 @@ def should_poll(self):
@property
def device_state_attributes(self):
"""Return the state attributes."""
if not self.coordinator.data:
return None
attributes = {
ATTR_ATTRIBUTION: ATTRIBUTION,
ATTR_SERVER_NAME: self.coordinator.data["server"]["name"],
Expand All @@ -98,10 +96,30 @@ def device_state_attributes(self):

async def async_added_to_hass(self):
"""Handle entity which will be added."""

self.async_on_remove(
self.coordinator.async_add_listener(self.async_write_ha_state)
)
await super().async_added_to_hass()
if self.coordinator.config_entry.options[CONF_MANUAL]:
state = await self.async_get_last_state()
if state:
self._state = state.state

@callback
def update():
"""Update state."""
self._update_state()
self.async_write_ha_state()

self.async_on_remove(self.coordinator.async_add_listener(update))
self._update_state()

def _update_state(self):
"""Update sensors state."""
if self.coordinator.data:
if self.type == "ping":
self._state = self.coordinator.data["ping"]
elif self.type == "download":
self._state = round(self.coordinator.data["download"] / 10 ** 6, 2)
elif self.type == "upload":
self._state = round(self.coordinator.data["upload"] / 10 ** 6, 2)

async def async_update(self):
"""Request coordinator to update data."""
Expand Down

0 comments on commit ddb049e

Please sign in to comment.