169 lines
6.0 KiB
Python
Executable File
169 lines
6.0 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from asyncio import Event, Lock, Queue, Task
|
|
from typing import Any, Dict, Optional, cast
|
|
|
|
from errors import HomeAssistantError
|
|
from ha_types import HAEvent
|
|
from websockets.client import connect as ws_connect
|
|
from websockets.exceptions import InvalidStatusCode
|
|
|
|
logger = logging.getLogger(__name__)


class HomeAssistantAPI:
    """Client for the Home Assistant websocket API.

    Outgoing messages are serialized by a ``sending`` task that drains
    ``sending_queue``; incoming frames are dispatched by a ``receiving``
    task, either to a per-subscription event queue (``events``) or into
    ``responses``, where :meth:`wait_for` picks them up.
    """

    def __init__(self, token: str, url: str) -> None:
        # Access token sent in the "auth" message.
        self.token = token
        # Next message id. Guarded by msg_id_lock; request methods also
        # enqueue while holding it so ids leave in increasing order.
        self.msg_id = 1
        self.msg_id_lock = Lock()
        # Websocket connection, set by connect().
        self.ws: Any = None
        self.url = url
        # Background tasks started by connect().
        self.receiver: Optional[Task[Any]] = None
        self.sender: Optional[Task[Any]] = None
        # Outgoing messages, drained in FIFO order by sending().
        self.sending_queue: Queue[Dict[str, Any]] = Queue()
        # Set once the auth handshake succeeded; request methods wait on it.
        self.authenticated: Event = Event()
        # Subscription id -> queue receiving the "event" payloads for it.
        self.events: Dict[int, Queue[HAEvent]] = {}
        # Message id -> non-event message not yet claimed by wait_for().
        self.responses: Dict[int, Dict[str, Any]] = {}
        # Message id -> Event a wait_for() call is currently blocked on.
        self.response_events: Dict[int, Event] = {}
        self.response_lock: Lock = Lock()

    async def connect(self) -> bool:
        """Open the websocket, authenticate, and start the receive loop.

        Attempts failing with an HTTP status error or a timeout are retried
        up to 5 times, sleeping 30 seconds before each retry.

        Returns:
            True once authenticated and the receiver task is running.

        Raises:
            HomeAssistantError: when all retries are exhausted, or when
                authentication fails (raised by :meth:`auth`).
        """
        retries = 5
        logger.info("Connect to Home Assistant")
        while True:
            try:
                self.ws = await ws_connect(self.url)
                # The sender must already run: auth() pushes the credentials
                # through sending_queue instead of writing to the socket.
                self.sender = asyncio.create_task(self.sending())
                await self.auth()
                self.receiver = asyncio.create_task(self.receiving())
                return True
            # Before Python 3.11 asyncio.TimeoutError is not the builtin
            # TimeoutError; list both so timeouts are retried everywhere.
            except (InvalidStatusCode, asyncio.TimeoutError, TimeoutError):
                # Don't leave a sender task from a half-finished attempt
                # running alongside the next one.
                await self.exit_loop()
                if retries > 0:
                    retries -= 1
                    await asyncio.sleep(30)
                    logger.warning(
                        "Retry Home Assistant connection (%s retries left)", retries
                    )
                    continue
                logger.error("Invalid status code while connecting to Home Assistant")
                raise HomeAssistantError(
                    "Invalid status code while connecting to Home Assistant"
                )

    async def wait_for_close(self) -> None:
        """Block until the websocket connection is closed."""
        await self.ws.wait_closed()

    async def receiving(self) -> None:
        """Dispatch incoming frames until the socket closes or the task is cancelled.

        ``"event"`` frames go to the queue registered for their id; every
        other frame is stored in ``responses`` and wakes a matching
        :meth:`wait_for`. On exit, all pending waiters are woken so they
        return ``None`` instead of blocking forever.
        """
        logger.debug("Start receiving")
        try:
            async for message in self.ws:
                msg: Dict[str, Any] = json.loads(cast(str, message))
                msg_id = msg["id"]
                if msg["type"] == "event":
                    queue = self.events.get(msg_id)
                    if queue is None:
                        logger.warning(
                            "Received event for not subscribed id: %s %s",
                            msg_id,
                            # event_type is nested inside the "event" payload.
                            (msg.get("event") or {}).get("event_type"),
                        )
                        continue
                    await queue.put(msg["event"])
                else:
                    async with self.response_lock:
                        # Stored unconditionally: the response may arrive
                        # before the requester has called wait_for().
                        self.responses[msg_id] = msg
                        waiter = self.response_events.get(msg_id)
                        if waiter is not None:
                            waiter.set()
        finally:
            # Connection gone (or we were cancelled): release every waiter.
            # Plain synchronous loop, so it is safe even during cancellation.
            for waiter in list(self.response_events.values()):
                waiter.set()

    async def wait_for(self, idx: int) -> Optional[Dict[str, Any]]:
        """Return the response message with id ``idx``, waiting if necessary.

        Returns:
            The message dict, or ``None`` if the waiter was woken without a
            response being stored (e.g. the receive loop ended).
        """
        event = Event()
        async with self.response_lock:
            if idx in self.responses:
                return self.responses.pop(idx)
            self.response_events[idx] = event
        try:
            await event.wait()
        finally:
            # Also runs if the caller is cancelled, so no Event is leaked.
            # No await happens between check and pop, so no lock is needed.
            self.response_events.pop(idx, None)
        async with self.response_lock:
            msg = self.responses.pop(idx, None)
        if msg is None:
            logger.warning("Response ID not found")
        return msg

    async def exit_loop(self) -> None:
        """Cancel the sender and receiver tasks, if they were started."""
        for task in (self.sender, self.receiver):
            if task is not None:
                task.cancel()

    async def close(self) -> None:
        """Stop the background tasks and close the websocket (best effort)."""
        await self.exit_loop()
        if self.ws is not None:
            try:
                await self.ws.close()
            except Exception:
                # Shutting down anyway; keep the error visible for debugging
                # instead of discarding it silently.
                logger.debug("Ignoring error while closing websocket", exc_info=True)

    async def auth(self) -> None:
        """Run the auth handshake on a freshly opened socket.

        Frames are read directly from the socket (the receiver task is not
        running yet), while the credentials are sent via ``sending_queue``,
        so the sender task must already be running.

        Raises:
            HomeAssistantError: if the server does not request auth, rejects
                the token, or answers with an unexpected message.
        """
        msg = json.loads(await self.ws.recv())
        if msg["type"] != "auth_required":
            await self.exit_loop()
            raise HomeAssistantError("Authentication error: Not required")
        response: Dict[str, Any] = {"type": "auth", "access_token": self.token}
        await self.sending_queue.put(response)
        msg = json.loads(await self.ws.recv())
        if msg["type"] == "auth_invalid":
            await self.exit_loop()
            raise HomeAssistantError("Auth failed")
        elif msg["type"] == "auth_ok":
            logger.info("Authenticated")
            self.authenticated.set()
        else:
            await self.exit_loop()
            raise HomeAssistantError(f"Unknown answer for auth: {msg}")

    async def sending(self) -> None:
        """Send queued messages to the socket as JSON, in FIFO order.

        Runs until cancelled; a falsy queue item (e.g. ``{}``) also ends it.
        """
        while msg := await self.sending_queue.get():
            await self.ws.send(json.dumps(msg))

    async def subscribe_event(self, event_type: str) -> int:
        """Subscribe to ``event_type`` and return the subscription id.

        Events for the subscription are delivered to ``self.events[id]``.
        """
        await self.authenticated.wait()

        logger.info("Subscribe to %s", event_type)
        async with self.msg_id_lock:
            msg_id = self.msg_id
            self.msg_id += 1
            # Register the queue before the request can be sent, so no event
            # for this id is ever dropped as "not subscribed".
            self.events[msg_id] = Queue()
            # Enqueue under the lock so ids reach the socket in increasing
            # order (the sending queue is FIFO).
            await self.sending_queue.put(
                {"id": msg_id, "type": "subscribe_events", "event_type": event_type}
            )
        # NOTE(review): the "result" reply to this request is stored in
        # self.responses and never claimed, so one entry accumulates per
        # subscription.
        return msg_id

    async def get_states(self) -> list[Dict[str, Any]]:
        """Fetch all entity states.

        Returns:
            The ``result`` list of the reply; ``[]`` if no reply was
            received, the request failed, or the result was empty/null.
        """
        await self.authenticated.wait()
        async with self.msg_id_lock:
            msg_id = self.msg_id
            self.msg_id += 1
            message: Dict[str, Any] = {"id": msg_id, "type": "get_states"}
            # Enqueue under the lock so ids reach the socket in order.
            await self.sending_queue.put(message)

        response = await self.wait_for(msg_id)
        if response is None:
            return []
        if response.get("success") is False:
            logger.warning("get_states failed: %s", response.get("error"))
            return []
        # "or []" also covers an explicit null result.
        return cast(list[Dict[str, Any]], response.get("result") or [])

    async def get_device_state(self, entity_id: str) -> Optional[Dict[str, Any]]:
        """Return the state dict for ``entity_id``, or ``None`` if not present."""
        device_states = await self.get_states()
        return next(
            (state for state in device_states if state.get("entity_id") == entity_id),
            None,
        )
|