Skip to content

Commit

Permalink
Improve async_track_template_result callback typing (#97135)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Jul 24, 2023
1 parent c0da6b8 commit 582499a
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 72 deletions.
13 changes: 6 additions & 7 deletions homeassistant/components/bayesian/binary_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConditionError, TemplateError
from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv
Expand Down Expand Up @@ -259,14 +259,13 @@ def async_threshold_sensor_state_listener(

@callback
def _async_template_result_changed(
event: Event | None, updates: list[TrackTemplateResult]
event: EventType[EventStateChangedData] | None,
updates: list[TrackTemplateResult],
) -> None:
track_template_result = updates.pop()
template = track_template_result.template
result = track_template_result.result
entity: str | None = (
None if event is None else event.data.get(CONF_ENTITY_ID)
)
entity_id = None if event is None else event.data["entity_id"]
if isinstance(result, TemplateError):
_LOGGER.error(
"TemplateError('%s') while processing template '%s' in entity '%s'",
Expand All @@ -283,8 +282,8 @@ def _async_template_result_changed(
observation.observed = observed

# in some cases a template may update because of the absence of an entity
if entity is not None:
observation.entity_id = entity
if entity_id is not None:
observation.entity_id = entity_id

self.current_observations[observation.id] = observation

Expand Down
21 changes: 14 additions & 7 deletions homeassistant/components/template/trigger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Offer template automation rules."""
from datetime import timedelta
import logging
from typing import Any

import voluptuous as vol

Expand All @@ -8,13 +10,15 @@
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.event import (
EventStateChangedData,
TrackTemplate,
TrackTemplateResult,
async_call_later,
async_track_template_result,
)
from homeassistant.helpers.template import Template, result_as_boolean
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, EventType

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,7 +63,10 @@ async def async_attach_trigger(
)

@callback
def template_listener(event, updates):
def template_listener(
event: EventType[EventStateChangedData] | None,
updates: list[TrackTemplateResult],
) -> None:
"""Listen for state changes and calls action."""
nonlocal delay_cancel, armed
result = updates.pop().result
Expand Down Expand Up @@ -88,9 +95,9 @@ def template_listener(event, updates):
# Fire!
armed = False

entity_id = event and event.data.get("entity_id")
from_s = event and event.data.get("old_state")
to_s = event and event.data.get("new_state")
entity_id = event and event.data["entity_id"]
from_s = event and event.data["old_state"]
to_s = event and event.data["new_state"]

if entity_id is not None:
description = f"{entity_id} via template"
Expand All @@ -110,7 +117,7 @@ def template_listener(event, updates):
}

@callback
def call_action(*_):
def call_action(*_: Any) -> None:
"""Call action with right context."""
nonlocal trigger_variables
hass.async_run_hass_job(
Expand All @@ -124,7 +131,7 @@ def call_action(*_):
return

try:
period = cv.positive_time_period(
period: timedelta = cv.positive_time_period(
template.render_complex(time_delta, {"trigger": template_variables})
)
except (exceptions.TemplateError, vol.Invalid) as ex:
Expand Down
6 changes: 5 additions & 1 deletion homeassistant/components/universal/media_player.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from homeassistant.helpers.event import (
EventStateChangedData,
TrackTemplate,
TrackTemplateResult,
async_track_state_change_event,
async_track_template_result,
)
Expand Down Expand Up @@ -192,7 +193,10 @@ def _async_on_dependency_update(
self.async_schedule_update_ha_state(True)

@callback
def _async_on_template_update(event, updates):
def _async_on_template_update(
event: EventType[EventStateChangedData] | None,
updates: list[TrackTemplateResult],
) -> None:
"""Update state when template state changes."""
for data in updates:
template = data.template
Expand Down
5 changes: 4 additions & 1 deletion homeassistant/components/websocket_api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from homeassistant.helpers import config_validation as cv, entity, template
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import (
EventStateChangedData,
TrackTemplate,
TrackTemplateResult,
async_track_template_result,
Expand All @@ -37,6 +38,7 @@
json_dumps,
)
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.helpers.typing import EventType
from homeassistant.loader import (
Integration,
IntegrationNotFound,
Expand Down Expand Up @@ -535,7 +537,8 @@ async def handle_render_template(

@callback
def _template_listener(
event: Event | None, updates: list[TrackTemplateResult]
event: EventType[EventStateChangedData] | None,
updates: list[TrackTemplateResult],
) -> None:
nonlocal info
track_template_result = updates.pop()
Expand Down
6 changes: 2 additions & 4 deletions homeassistant/helpers/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,9 +888,7 @@ def __init__(
self,
hass: HomeAssistant,
track_templates: Sequence[TrackTemplate],
action: Callable[
[EventType[EventStateChangedData] | None, list[TrackTemplateResult]], None
],
action: TrackTemplateResultListener,
has_super_template: bool = False,
) -> None:
"""Handle removal / refresh of tracker init."""
Expand Down Expand Up @@ -1209,7 +1207,7 @@ def _apply_update(
EventType[EventStateChangedData] | None,
list[TrackTemplateResult],
],
None,
Coroutine[Any, Any, None] | None,
]
"""Type for the listener for template results.
Expand Down
14 changes: 9 additions & 5 deletions homeassistant/helpers/template_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
SensorEntity,
)
from homeassistant.const import (
ATTR_ENTITY_ID,
ATTR_ENTITY_PICTURE,
ATTR_FRIENDLY_NAME,
ATTR_ICON,
Expand All @@ -33,7 +32,12 @@

from . import config_validation as cv
from .entity import Entity
from .event import TrackTemplate, TrackTemplateResult, async_track_template_result
from .event import (
EventStateChangedData,
TrackTemplate,
TrackTemplateResult,
async_track_template_result,
)
from .script import Script, _VarsType
from .template import (
Template,
Expand All @@ -42,7 +46,7 @@
render_complex,
result_as_boolean,
)
from .typing import ConfigType
from .typing import ConfigType, EventType

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -327,14 +331,14 @@ def add_template_attribute(
@callback
def _handle_results(
self,
event: Event | None,
event: EventType[EventStateChangedData] | None,
updates: list[TrackTemplateResult],
) -> None:
"""Call back the results to the attributes."""
if event:
self.async_set_context(event.context)

entity_id = event and event.data.get(ATTR_ENTITY_ID)
entity_id = event and event.data["entity_id"]

if entity_id and entity_id == self.entity_id:
self._self_ref_update_count += 1
Expand Down
Loading

0 comments on commit 582499a

Please sign in to comment.