2022-12-03 21:51:26 +01:00

153 lines
5.2 KiB
Python
Executable File

from __future__ import annotations
import asyncio
import json
import logging
from asyncio import Queue, Task, Event, Lock
from typing import Callable, Dict, Optional
import websockets
from websockets.exceptions import InvalidStatusCode
class HomeAssistantAPI:
    """Minimal asyncio client for the Home Assistant WebSocket API.

    Two background tasks do the I/O: ``sending`` drains ``sending_queue``
    onto the socket, and ``receiving`` dispatches incoming frames. Frames of
    type ``event`` go to the per-subscription queue in ``events``; every other
    frame is stored in ``responses`` under its ``id`` and wakes any
    ``wait_for`` caller blocked on that id.
    """

    def __init__(self, token: str, url: str) -> None:
        """Store connection settings; no I/O happens until ``connect()``.

        Args:
            token: Access token sent in the ``auth`` message.
            url: WebSocket URL passed to ``websockets.connect``.
        """
        self.token = token
        # Next command id. It is read, incremented and the command queued
        # while holding msg_id_lock, so ids enter the send queue in order.
        self.msg_id = 1
        self.msg_id_lock = Lock()
        self.ws: Optional[websockets.WebSocketClientProtocol] = None
        self.url = url
        self.receiver: Optional[Task] = None
        self.sender: Optional[Task] = None
        # Outgoing JSON-serialisable messages, consumed by sending().
        self.sending_queue: Queue = Queue()
        # Set once the server answered our auth message with "auth_ok".
        self.authenticated: Event = Event()
        # Subscription id -> queue of the "event" payloads received for it.
        self.events: Dict[int, Queue] = {}
        # Message id -> non-event frame not yet claimed by wait_for().
        self.responses: Dict[int, Dict] = {}
        # Message id -> Event a wait_for() caller is blocked on.
        self.response_events: Dict[int, Event] = {}
        self.response_lock: Lock = Lock()

    async def connect(self, retries: int = 5, retry_delay: float = 30) -> bool:
        """Open the socket, authenticate and start the receiver task.

        Args:
            retries: How many extra attempts to make after a failed connect.
            retry_delay: Seconds to sleep between attempts.

        Returns:
            True when connected and authenticated; False when every attempt
            failed or the server rejected authentication.
        """
        logging.info("Connect to home assistant...")
        while True:
            try:
                self.ws = await websockets.connect(self.url)
                # The sender must already run: auth() queues its message.
                self.sender = asyncio.create_task(self.sending())
                if not await self.auth():
                    # auth() has already logged the reason and shut down.
                    return False
                self.receiver = asyncio.create_task(self.receiving())
                return True
            except (InvalidStatusCode, asyncio.TimeoutError, OSError) as err:
                # asyncio.TimeoutError is not the builtin TimeoutError before
                # Python 3.11; OSError covers refused/unreachable hosts.
                # Tear down whatever this attempt started (e.g. the sender
                # task) so retries do not pile up orphaned tasks.
                await self.exit_loop()
                if retries <= 0:
                    logging.error(f"Could not connect to Home Assistant: {err!r}")
                    return False
                retries -= 1
                await asyncio.sleep(retry_delay)
                logging.info(f"Retry home assistant connection... ({retries})")

    async def wait_for_close(self):
        """Block until the websocket connection is closed."""
        await self.ws.wait_closed()

    async def receiving(self) -> None:
        """Dispatch incoming frames until the socket iteration ends.

        ``event`` frames are put on ``self.events[id]``. Events for unknown
        ids are logged and dropped. All other frames are stored in
        ``self.responses`` and the matching ``wait_for`` waiter, if any, is
        woken. Frames nobody waits for remain in ``self.responses``.
        """
        logging.debug("Start receiving")
        async for message in self.ws:
            msg: Dict = json.loads(message)
            msg_id = msg["id"]
            if msg["type"] == "event":
                queue = self.events.get(msg_id)
                if queue is None:
                    # event_type is part of the nested "event" payload, not
                    # the top-level frame; indexing the frame raised KeyError
                    # and killed this task.
                    event_type = msg.get("event", {}).get("event_type")
                    logging.error(f"Received event for not subscribted id: {msg_id} {event_type}")
                    continue
                await queue.put(msg["event"])
            else:
                async with self.response_lock:
                    self.responses[msg_id] = msg
                    waiter = self.response_events.get(msg_id)
                    if waiter is not None:
                        waiter.set()

    async def wait_for(self, idx: int) -> Optional[Dict]:
        """Return and forget the non-event frame whose id is ``idx``.

        Returns at once if the frame already arrived; otherwise blocks until
        receiving() stores it. Returns None (and logs) if woken without a
        stored frame.
        """
        async with self.response_lock:
            if idx in self.responses:
                return self.responses.pop(idx)
            # Registered under the lock, so receiving() cannot store the
            # frame between the check above and this registration.
            waiter = self.response_events[idx] = Event()
        await waiter.wait()
        async with self.response_lock:
            del self.response_events[idx]
            msg = self.responses.pop(idx, None)
        if msg is None:
            logging.error("Response ID not found")
        return msg

    async def exit_loop(self) -> None:
        """Cancel the sender/receiver tasks and close the websocket."""
        if self.sender is not None:
            self.sender.cancel()
        if self.receiver is not None:
            self.receiver.cancel()
        if self.ws is not None:
            # Without this the connection stays open after a failed auth or
            # an exhausted retry loop.
            await self.ws.close()

    async def auth(self) -> bool:
        """Perform the auth handshake on a freshly opened socket.

        Reads the greeting, queues the ``auth`` message for the sender task,
        then reads the verdict. It reads the socket directly, so it must run
        before receiving() is started.

        Returns:
            True (and sets ``self.authenticated``) on ``auth_ok``. Otherwise
            logs, calls exit_loop() and returns False.
        """
        msg = json.loads(await self.ws.recv())
        if msg["type"] != "auth_required":
            logging.error("Authentication error: Not required")
            await self.exit_loop()
            # Stop here: continuing would send the token and read a socket
            # that exit_loop() has just shut down.
            return False
        await self.sending_queue.put({
            "type": "auth",
            "access_token": self.token,
        })
        msg = json.loads(await self.ws.recv())
        if msg["type"] == "auth_ok":
            logging.debug("Authenticated")
            self.authenticated.set()
            return True
        if msg["type"] == "auth_invalid":
            logging.error("Auth failed")
        else:
            logging.error(f"Unknown answer for auth: {msg}")
        await self.exit_loop()
        return False

    async def sending(self) -> None:
        """Serialise queued messages to JSON and send them.

        Runs until a falsy item (e.g. None) is taken from the queue.
        """
        while msg := await self.sending_queue.get():
            await self.ws.send(json.dumps(msg))

    async def subscribe_event(self, event_type: str) -> int:
        """Subscribe to ``event_type`` and return the subscription id.

        Event payloads are delivered to ``self.events[<returned id>]``. The
        queue is registered before the request is queued, so events cannot
        arrive for an id the receiver does not know yet.
        """
        await self.authenticated.wait()
        logging.info(f"Subscribe to {event_type}")
        async with self.msg_id_lock:
            msg_id = self.msg_id
            self.msg_id += 1
            self.events[msg_id] = Queue()
            await self.sending_queue.put({
                "id": msg_id,
                "type": "subscribe_events",
                "event_type": event_type,
            })
        return msg_id

    async def get_states(self) -> Optional[list]:
        """Fetch all entity states with a ``get_states`` command.

        Returns:
            The ``result`` field of the response, or None (after logging)
            when no response was found or it reports ``success`` false.
        """
        await self.authenticated.wait()
        async with self.msg_id_lock:
            msg_id = self.msg_id
            self.msg_id += 1
            await self.sending_queue.put({"id": msg_id, "type": "get_states"})
        response = await self.wait_for(msg_id)
        if response is None or not response.get("success", True):
            # A failed command has no usable "result" (Home Assistant sends
            # success: false with an "error" object instead).
            logging.error(f"get_states failed: {response}")
            return None
        return response.get("result")

    async def get_device_state(self, entity_id: str) -> Optional[Dict]:
        """Return the state object whose ``entity_id`` matches.

        Returns None when no entity matches or the states could not be
        fetched.
        """
        device_states = await self.get_states()
        if not device_states:
            return None
        return next(
            (state for state in device_states if state["entity_id"] == entity_id),
            None,
        )