from __future__ import annotations import asyncio import json import logging from asyncio import Event, Lock, Queue, Task from typing import Any, Dict, Optional, cast from errors import HomeAssistantError from ha_types import HAEvent from websockets.client import connect as ws_connect from websockets.exceptions import InvalidStatusCode logger = logging.getLogger(__name__) class HomeAssistantAPI: def __init__(self, token: str, url: str) -> None: self.token = token self.msg_id = 1 self.msg_id_lock = Lock() self.ws: Any = None self.url = url self.receiver: Optional[Task[Any]] = None self.sender: Optional[Task[Any]] = None self.sending_queue: Queue[Dict[str, Any]] = Queue() self.authenticated: Event = Event() self.events: Dict[int, Queue[HAEvent]] = {} self.responses: Dict[int, Dict[str, Any]] = {} self.response_events: Dict[int, Event] = {} self.response_lock: Lock = Lock() async def connect(self): retries = 5 logger.info("Connect to Home Assistant") while True: try: self.ws = await ws_connect(self.url) self.sender = asyncio.create_task(self.sending()) await self.auth() self.receiver = asyncio.create_task(self.receiving()) return True except (InvalidStatusCode, TimeoutError): if retries > 0: retries -= 1 await asyncio.sleep(30) logger.warning( "Retry Home Assistant connection (%s retries left)", retries ) continue logger.error("Invalid status code while connecting to Home Assistant") await self.exit_loop() raise HomeAssistantError( "Invalid status code while connecting to Home Assistant" ) async def wait_for_close(self): await self.ws.wait_closed() async def receiving(self): logger.debug("Start receiving") async for message in self.ws: msg: Dict[str, Any] = json.loads(cast(str, message)) if msg["type"] == "event": if msg["id"] not in self.events.keys(): logger.warning( "Received event for not subscribed id: %s %s", msg["id"], msg.get("event_type"), ) continue await self.events[msg["id"]].put(msg["event"]) else: async with self.response_lock: self.responses[msg["id"]] = msg if msg["id"] in self.response_events.keys(): self.response_events[msg["id"]].set() async def wait_for(self, idx: int): async with self.response_lock: if idx in self.responses.keys(): msg = self.responses[idx] del self.responses[idx] return msg self.response_events[idx] = Event() await self.response_events[idx].wait() async with self.response_lock: del self.response_events[idx] if idx not in self.responses.keys(): logger.warning("Response ID not found") return None msg = self.responses[idx] del self.responses[idx] return msg async def exit_loop(self): if self.sender is not None: self.sender.cancel() if self.receiver is not None: self.receiver.cancel() async def close(self) -> None: await self.exit_loop() if self.ws is not None: try: await self.ws.close() except Exception: pass async def auth(self): msg = json.loads(await self.ws.recv()) if msg["type"] != "auth_required": await self.exit_loop() raise HomeAssistantError("Authentication error: Not required") response: Dict[str, Any] = {"type": "auth", "access_token": self.token} await self.sending_queue.put(response) msg = json.loads(await self.ws.recv()) if msg["type"] == "auth_invalid": await self.exit_loop() raise HomeAssistantError("Auth failed") elif msg["type"] == "auth_ok": logger.info("Authenticated") self.authenticated.set() else: await self.exit_loop() raise HomeAssistantError(f"Unknown answer for auth: {msg}") async def sending(self): while msg := await self.sending_queue.get(): await self.ws.send(json.dumps(msg)) async def subscribe_event(self, event_type: str): await self.authenticated.wait() logger.info("Subscribe to %s", event_type) async with self.msg_id_lock: msg_id = self.msg_id response: Dict[str, Any] = { "id": msg_id, "type": "subscribe_events", "event_type": event_type, } self.events[msg_id] = Queue() self.msg_id += 1 await self.sending_queue.put(response) return msg_id async def get_states(self) -> list[Dict[str, Any]]: await self.authenticated.wait() async with self.msg_id_lock: message: Dict[str, Any] = {"id": self.msg_id, "type": "get_states"} self.msg_id += 1 await self.sending_queue.put(message) response = await self.wait_for(cast(int, message["id"])) # ToDo: Error handling if response is None: return [] return cast(list[Dict[str, Any]], response.get("result", [])) async def get_device_state(self, entity_id: str) -> Optional[Dict[str, Any]]: device_states = await self.get_states() for device_state in device_states: if device_state["entity_id"] == entity_id: return device_state return None