from __future__ import annotations import asyncio import json import logging from asyncio import Queue, Task, Event, Lock from typing import Callable, Dict, Optional import websockets from websockets.exceptions import InvalidStatusCode class HomeAssistantAPI: def __init__(self, token: str, url: str) -> None: self.token = token self.msg_id = 1 self.msg_id_lock = Lock() self.ws: websockets.WebSocketClientProtocol = None self.url = url self.receiver: Optional[Task] = None self.sender: Optional[Task] = None self.sending_queue: Queue = Queue() self.authenticated: Event = Event() self.events: Dict[int, Queue] = {} self.responses: Dict[int, Dict] = {} self.response_events: Dict[int, Event] = {} self.response_lock: Lock = Lock() async def connect(self): retries = 5 logging.info("Connect to home assistant...") while True: try: self.ws = await websockets.connect(self.url) self.sender = asyncio.create_task(self.sending()) await self.auth() self.receiver = asyncio.create_task(self.receiving()) return True except (InvalidStatusCode, TimeoutError): if retries > 0: retries -= 1 await asyncio.sleep(30) logging.info(f"Retry home assistant connection... ({retries})") continue else: logging.error("Invalid status code while connecting to Home Assistant") await self.exit_loop() return False async def wait_for_close(self): await self.ws.wait_closed() async def receiving(self): logging.debug("Start receiving") async for message in self.ws: msg: Dict = json.loads(message) if msg["type"] == "event": if msg["id"] not in self.events.keys(): logging.error(f"Received event for not subscribted id: {msg['id']} {msg['event_type']}") continue await self.events[msg["id"]].put(msg["event"]) else: async with self.response_lock: self.responses[msg["id"]] = msg if msg["id"] in self.response_events.keys(): self.response_events[msg["id"]].set() async def wait_for(self, idx): async with self.response_lock: if idx in self.responses.keys(): msg = self.responses[idx] del self.responses[idx] return msg self.response_events[idx] = Event() await self.response_events[idx].wait() async with self.response_lock: del self.response_events[idx] if idx not in self.responses.keys(): logging.error("Response ID not found") return None msg = self.responses[idx] del self.responses[idx] return msg async def exit_loop(self): if self.sender is not None: self.sender.cancel() if self.receiver is not None: self.receiver.cancel() async def auth(self): msg = json.loads(await self.ws.recv()) if msg["type"] != "auth_required": logging.error("Authentication error: Not required") await self.exit_loop() response = { "type": "auth", "access_token": self.token } await self.sending_queue.put(response) msg = json.loads(await self.ws.recv()) if msg["type"] == "auth_invalid": logging.info("Auth failed") await self.exit_loop() elif msg["type"] == "auth_ok": logging.debug("Authenticated") self.authenticated.set() else: logging.error(f"Unknown answer for auth: {msg}") await self.exit_loop() async def sending(self): while msg := await self.sending_queue.get(): await self.ws.send(json.dumps(msg)) async def subscribe_event(self, event_type: str): await self.authenticated.wait() logging.info(f"Subscribe to {event_type}") async with self.msg_id_lock: msg_id = self.msg_id response = { "id": msg_id, "type": "subscribe_events", "event_type": event_type } self.events[msg_id] = Queue() self.msg_id += 1 await self.sending_queue.put(response) return msg_id async def get_states(self): await self.authenticated.wait() async with self.msg_id_lock: message = { "id": self.msg_id, "type": "get_states" } self.msg_id += 1 await self.sending_queue.put(message) response = await self.wait_for(message["id"]) # ToDo: Error handling return response["result"] async def get_device_state(self, entity_id: str): device_states = await self.get_states() for device_state in device_states: if device_state["entity_id"] == entity_id: return device_state return None