153 lines
5.2 KiB
Python
Executable File
153 lines
5.2 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from asyncio import Queue, Task, Event, Lock
|
|
from typing import Callable, Dict, Optional
|
|
import websockets
|
|
from websockets.exceptions import InvalidStatusCode
|
|
|
|
|
|
class HomeAssistantAPI:
    """Client for the Home Assistant WebSocket API.

    Two background tasks share the socket: ``sending`` drains
    ``sending_queue`` onto it as JSON, and ``receiving`` dispatches incoming
    messages.  ``event`` messages go to the per-subscription queue in
    ``events``; every other message is stored in ``responses`` under its
    ``id`` and wakes the matching ``wait_for`` caller.
    """

    def __init__(self, token: str, url: str) -> None:
        self.token = token
        self.url = url
        # Next message id; read and incremented under msg_id_lock.
        self.msg_id = 1
        self.msg_id_lock = Lock()
        self.ws: Optional[websockets.WebSocketClientProtocol] = None
        self.receiver: Optional[Task] = None
        self.sender: Optional[Task] = None
        # Outgoing messages (dicts), serialised and sent by the sender task.
        self.sending_queue: Queue = Queue()
        # Set by auth() on "auth_ok"; API calls wait on it before sending.
        self.authenticated: Event = Event()
        # Subscription id -> queue receiving that subscription's "event" payloads.
        self.events: Dict[int, Queue] = {}
        # Message id -> response not yet collected by wait_for().
        self.responses: Dict[int, Dict] = {}
        # Message id -> Event a wait_for() caller is blocked on.
        self.response_events: Dict[int, Event] = {}
        self.response_lock: Lock = Lock()

    async def connect(self) -> bool:
        """Open the WebSocket, authenticate, and start the background tasks.

        Tasks and a socket left over from a previous call are shut down
        first, so a reconnect never leaves two sender tasks competing for
        ``sending_queue``.  Failing to open the socket (bad HTTP status,
        timeout, refused or unreachable host) is retried 5 times, 30 s apart.

        Returns:
            True once authenticated and receiving; False if every connection
            attempt failed or authentication was rejected.
        """
        retries = 5
        logging.info("Connect to home assistant...")
        await self.exit_loop()
        # API calls must block again until this new connection is authenticated.
        self.authenticated.clear()
        while True:
            try:
                self.ws = await websockets.connect(self.url)
            except (InvalidStatusCode, asyncio.TimeoutError, OSError) as err:
                # asyncio.TimeoutError only aliases the builtin TimeoutError from
                # Python 3.11 on; OSError covers refused/unreachable hosts.
                if retries > 0:
                    retries -= 1
                    await asyncio.sleep(30)
                    logging.info(f"Retry home assistant connection... ({retries})")
                    continue
                logging.error(f"Could not connect to Home Assistant: {err!r}")
                await self.exit_loop()
                return False
            # auth() queues the token and relies on the sender task to send it.
            self.sender = asyncio.create_task(self.sending())
            if not await self.auth():
                return False
            # Started only after auth() so it cannot consume the auth replies
            # that auth() reads from the socket itself.
            self.receiver = asyncio.create_task(self.receiving())
            return True

    async def wait_for_close(self):
        """Block until the current socket is closed."""
        await self.ws.wait_closed()

    async def receiving(self):
        """Dispatch incoming messages until the socket stops yielding them.

        Frames that are not valid JSON are logged and skipped.  When the loop
        ends for any reason (socket closed, error, or cancellation), every
        pending ``wait_for`` caller is woken.  Those callers then return None
        instead of blocking forever.
        """
        logging.debug("Start receiving")
        try:
            async for message in self.ws:
                try:
                    msg = json.loads(message)
                except json.JSONDecodeError:
                    logging.error(f"Received invalid JSON: {message!r}")
                    continue
                await self._dispatch(msg)
        finally:
            # No further responses can arrive; release everyone still waiting.
            for waiter in self.response_events.values():
                waiter.set()

    async def _dispatch(self, msg: Dict) -> None:
        """Route one decoded message to its subscription queue or waiting caller."""
        msg_id = msg.get("id")
        if msg.get("type") == "event":
            queue = self.events.get(msg_id)
            if queue is None:
                # Home Assistant nests event_type inside the "event" payload;
                # there is no top-level "event_type" key to read.
                event_type = msg.get("event", {}).get("event_type")
                logging.error(f"Received event for not subscribed id: {msg_id} {event_type}")
                return
            await queue.put(msg["event"])
            return
        async with self.response_lock:
            self.responses[msg_id] = msg
            waiter = self.response_events.get(msg_id)
            if waiter is not None:
                waiter.set()

    async def wait_for(self, idx: int) -> Optional[Dict]:
        """Wait for and return the response whose message id is ``idx``.

        Returns the response immediately if it already arrived.  Returns
        None (and logs an error) if the caller is woken without a response.
        For example, the receiver stopped.
        """
        async with self.response_lock:
            if idx in self.responses:
                return self.responses.pop(idx)
            event = self.response_events[idx] = Event()

        await event.wait()

        async with self.response_lock:
            self.response_events.pop(idx, None)
            if idx not in self.responses:
                logging.error("Response ID not found")
                return None
            return self.responses.pop(idx)

    async def exit_loop(self):
        """Cancel the sender/receiver tasks and close the socket, if present."""
        for task in (self.sender, self.receiver):
            if task is not None:
                task.cancel()
        if self.ws is not None:
            await self.ws.close()

    async def auth(self) -> bool:
        """Run the authentication handshake on a freshly opened socket.

        Expects ``auth_required``, queues the access token for the sender
        task, then reads the server's verdict.  On success, sets
        ``authenticated``.  On any failure, calls ``exit_loop`` (cancelling
        the tasks and closing the socket).

        Returns:
            True if the server answered ``auth_ok``, otherwise False.
        """
        msg = json.loads(await self.ws.recv())
        if msg.get("type") != "auth_required":
            logging.error("Authentication error: Not required")
            await self.exit_loop()
            return False

        await self.sending_queue.put({"type": "auth", "access_token": self.token})

        msg = json.loads(await self.ws.recv())
        if msg.get("type") == "auth_ok":
            logging.debug("Authenticated")
            self.authenticated.set()
            return True
        if msg.get("type") == "auth_invalid":
            logging.error("Auth failed")
        else:
            logging.error(f"Unknown answer for auth: {msg}")
        await self.exit_loop()
        return False

    async def sending(self):
        """Send queued messages as JSON; a falsy queue item ends the loop."""
        while msg := await self.sending_queue.get():
            await self.ws.send(json.dumps(msg))

    async def subscribe_event(self, event_type: str) -> int:
        """Subscribe to ``event_type`` once authenticated.

        Returns:
            The subscription id; matching events are delivered to
            ``self.events[id]``.
        """
        await self.authenticated.wait()

        logging.info(f"Subscribe to {event_type}")
        async with self.msg_id_lock:
            msg_id = self.msg_id
            self.msg_id += 1
            # Register the queue before the request is queued, so no event
            # for this subscription can arrive without a destination.
            self.events[msg_id] = Queue()
            # Queued under the lock so ids reach the socket in increasing order.
            await self.sending_queue.put({
                "id": msg_id,
                "type": "subscribe_events",
                "event_type": event_type,
            })
        return msg_id

    async def get_states(self) -> list:
        """Fetch the current state of every entity.

        Returns:
            The ``result`` list from Home Assistant.  Returns an empty list,
            and logs an error, if no response arrived or the response reports
            failure or carries no ``result``.
        """
        await self.authenticated.wait()
        async with self.msg_id_lock:
            message = {"id": self.msg_id, "type": "get_states"}
            self.msg_id += 1
            await self.sending_queue.put(message)

        response = await self.wait_for(message["id"])
        if response is None or not response.get("success", True) or "result" not in response:
            logging.error(f"get_states failed: {response}")
            return []
        return response["result"]

    async def get_device_state(self, entity_id: str) -> Optional[Dict]:
        """Return the state dict whose ``entity_id`` matches, or None if absent."""
        device_states = await self.get_states()
        return next(
            (state for state in device_states if state["entity_id"] == entity_id),
            None,
        )