diff --git a/.gitignore b/.gitignore index de3ccf9c..f724ac37 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,7 @@ docs/_build/ # GitHub App credentials gha-creds-*.json + +# pickle files +*.p +*.pickle diff --git a/roborock/devices/device.py b/roborock/devices/device.py index ca1fbf14..12a3c2f3 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -15,7 +15,6 @@ from roborock.data import HomeDataDevice, HomeDataProduct from roborock.diagnostics import redact_device_data from roborock.exceptions import RoborockException -from roborock.roborock_message import RoborockMessage from roborock.util import RoborockLoggerAdapter from .traits import Trait @@ -75,6 +74,7 @@ def __init__( self._channel = channel self._connect_task: asyncio.Task[None] | None = None self._unsub: Callable[[], None] | None = None + self._v1_unsub: Callable[[], None] | None = None self._ready_callbacks = CallbackList["RoborockDevice"]() self._has_connected = False @@ -196,15 +196,23 @@ async def connect(self) -> None: """Connect to the device using the appropriate protocol channel.""" if self._unsub: raise ValueError("Already connected to the device") - unsub = await self._channel.subscribe(self._on_message) + if self.v1_properties is not None: try: + # V1 layer subscribes to the channel and handles protocol updates. + # Note: V1Channel only allows one subscription, so the V1 layer + # is the sole subscriber for V1 devices. + self._v1_unsub = await self.v1_properties.subscribe_async(self._channel) await self.v1_properties.discover_features() except RoborockException: - unsub() + if self._v1_unsub: + self._v1_unsub() raise + else: + # Non-V1 devices subscribe directly (no protocol update handling needed) + self._unsub = await self._channel.subscribe(lambda msg: None) + self._logger.info("Connected to device") - self._unsub = unsub async def close(self) -> None: """Close all connections to the device.""" @@ -214,14 +222,13 @@ async def close(self) -> None: await self._connect_task except asyncio.CancelledError: pass + if self._v1_unsub: + self._v1_unsub() + self._v1_unsub = None if self._unsub: self._unsub() self._unsub = None - def _on_message(self, message: RoborockMessage) -> None: - """Handle incoming messages from the device.""" - self._logger.debug("Received message from device: %s", message) - def diagnostic_data(self) -> dict[str, Any]: """Return diagnostics information about the device.""" extra: dict[str, Any] = {} diff --git a/roborock/devices/rpc/v1_channel.py b/roborock/devices/rpc/v1_channel.py index d1b4ee24..608c9ecc 100644 --- a/roborock/devices/rpc/v1_channel.py +++ b/roborock/devices/rpc/v1_channel.py @@ -305,12 +305,14 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab loop = asyncio.get_running_loop() self._reconnect_task = loop.create_task(self._background_reconnect()) - if not self.is_local_connected: - # We were not able to connect locally, so fallback to MQTT and at least - # establish that connection explicitly. If this fails then raise an - # error and let the caller know we failed to subscribe. - self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) - self._logger.debug("V1Channel connected to device via MQTT") + # Always subscribe to MQTT to receive protocol updates (data points) + # even if we have a local connection. Protocol updates only come via cloud/MQTT. + # Local connection is used for RPC commands, but push notifications come via MQTT. + self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) + if self.is_local_connected: + self._logger.debug("V1Channel connected via local and MQTT (for protocol updates)") + else: + self._logger.debug("V1Channel connected via MQTT only") def unsub() -> None: """Unsubscribe from all messages.""" diff --git a/roborock/devices/traits/v1/__init__.py b/roborock/devices/traits/v1/__init__.py index 7438c6b0..db5b047b 100644 --- a/roborock/devices/traits/v1/__init__.py +++ b/roborock/devices/traits/v1/__init__.py @@ -52,7 +52,9 @@ code in HomeDataProduct Schema that is required for the field to be supported. """ +import json import logging +from collections.abc import Callable from dataclasses import dataclass, field, fields from functools import cache from typing import Any, get_args @@ -61,8 +63,15 @@ from roborock.data.v1.v1_code_mappings import RoborockDockTypeCode from roborock.devices.cache import DeviceCache from roborock.devices.traits import Trait +from roborock.devices.transport.channel import Channel from roborock.map.map_parser import MapParserConfig from roborock.protocols.v1_protocol import V1RpcChannel +from roborock.roborock_message import ( + ROBOROCK_DATA_STATUS_PROTOCOL, + RoborockDataProtocol, + RoborockMessage, + RoborockMessageProtocol, +) from roborock.web_api import UserWebApiClient from . import ( @@ -313,6 +322,79 @@ def as_dict(self) -> dict[str, Any]: result[item.name] = data return result + async def subscribe_async(self, channel: Channel) -> Callable[[], None]: + """Subscribe to protocol updates from the channel. + + This handles MQTT protocol updates for V1 devices, routing data point + updates to the appropriate traits. + + Args: + channel: The channel to subscribe to for updates. + + Returns: + A callable that can be used to unsubscribe from updates. + """ + + def on_message(message: RoborockMessage) -> None: + self._handle_message(message) + + return await channel.subscribe(on_message) + + def _handle_message(self, message: RoborockMessage) -> None: + """Handle incoming messages from the device. + + Parses protocol updates and routes them to the appropriate traits. + """ + # Only process messages that can contain protocol updates + # RPC_RESPONSE (102), and GENERAL_RESPONSE (5) + if message.protocol not in { + RoborockMessageProtocol.RPC_RESPONSE, + RoborockMessageProtocol.GENERAL_RESPONSE, + }: + return + + if not message.payload: + return + + try: + payload = json.loads(message.payload.decode("utf-8")) + dps = payload.get("dps", {}) + + if not dps: + return + + # Process each data point in the message + for data_point_number, data_point in dps.items(): + # Skip RPC responses (102) as they're handled by the RPC channel + if data_point_number == "102": + continue + + try: + data_protocol = RoborockDataProtocol(int(data_point_number)) + _LOGGER.debug("Got device update for %s: %s", data_protocol.name, data_point) + self._handle_protocol_update(data_protocol, data_point) + except ValueError: + # Unknown protocol number + _LOGGER.debug( + f"Got unknown data protocol {data_point_number}, data: {data_point}. " + f"This may allow for faster updates in the future." + ) + except (json.JSONDecodeError, UnicodeDecodeError, KeyError) as ex: + _LOGGER.debug("Failed to parse protocol message: %s", ex) + + def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> None: + """Handle a protocol update for a specific data protocol. + + Args: + protocol: The data protocol number. + data_point: The data value for this protocol. + """ + # Handle status protocol updates + if protocol in ROBOROCK_DATA_STATUS_PROTOCOL and self.status: + if self.status.handle_protocol_update(protocol, data_point): + _LOGGER.debug("Updated status.%s to %s", protocol.name.lower(), data_point) + self.status.notify_update() + def create( device_uid: str, diff --git a/roborock/devices/traits/v1/common.py b/roborock/devices/traits/v1/common.py index 63ae2e20..cdcedf71 100644 --- a/roborock/devices/traits/v1/common.py +++ b/roborock/devices/traits/v1/common.py @@ -3,8 +3,11 @@ This is an internal library and should not be used directly by consumers. """ +from __future__ import annotations + import logging from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass, fields from typing import ClassVar, Self @@ -15,6 +18,7 @@ _LOGGER = logging.getLogger(__name__) V1ResponseData = dict | list | int | str +V1TraitUpdateCallback = Callable[[], None] @dataclass @@ -74,6 +78,7 @@ def __post_init__(self) -> None: device setup code. """ self._rpc_channel = None + self._update_callbacks: list[V1TraitUpdateCallback] = [] @property def rpc_channel(self) -> V1RpcChannel: @@ -97,6 +102,26 @@ def _update_trait_values(self, new_data: RoborockBase) -> None: new_value = getattr(new_data, field.name, None) setattr(self, field.name, new_value) + def add_update_callback(self, callback: V1TraitUpdateCallback) -> Callable[[], None]: + """Add a callback to be notified when the trait is updated. + + The callback will be called whenever a protocol message updates the trait. + Callers should track which trait they subscribed to if needed. + + Returns: + A callable that can be used to remove the callback. + """ + self._update_callbacks.append(callback) + return lambda: self._update_callbacks.remove(callback) + + def notify_update(self) -> None: + """Notify all registered callbacks that the trait has been updated.""" + for callback in self._update_callbacks: + try: + callback() + except Exception: # noqa: BLE001 + _LOGGER.exception("Error in trait update callback") + def _get_value_field(clazz: type[V1TraitMixin]) -> str: """Get the name of the field marked as the main value of the RoborockValueBase.""" diff --git a/roborock/devices/traits/v1/status.py b/roborock/devices/traits/v1/status.py index 08cd0c45..4f60bdba 100644 --- a/roborock/devices/traits/v1/status.py +++ b/roborock/devices/traits/v1/status.py @@ -1,7 +1,8 @@ -from typing import Self +from typing import Any, Self -from roborock.data import HomeDataProduct, ModelStatus, S7MaxVStatus, Status +from roborock.data import HomeDataProduct, ModelStatus, RoborockErrorCode, RoborockStateCode, S7MaxVStatus, Status from roborock.devices.traits.v1 import common +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand @@ -22,3 +23,19 @@ def _parse_response(self, response: common.V1ResponseData) -> Self: if isinstance(response, dict): return status_type.from_dict(response) raise ValueError(f"Unexpected status format: {response!r}") + + def handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> bool: + """Handle a protocol update for a specific data protocol.""" + match protocol: + case RoborockDataProtocol.ERROR_CODE: + self.error_code = RoborockErrorCode(data_point) + case RoborockDataProtocol.STATE: + self.state = RoborockStateCode(data_point) + case RoborockDataProtocol.BATTERY: + self.battery = data_point + case RoborockDataProtocol.CHARGE_STATUS: + self.charge_status = data_point + case _: + # There is also fan power and water box mode, but for now those are skipped + return False + return True