From 50dcf119f7853aa33e60bf0cfac6e26059014075 Mon Sep 17 00:00:00 2001 From: Luke Date: Sun, 4 Jan 2026 22:06:41 -0500 Subject: [PATCH 1/3] feat: add protocol updates --- .gitignore | 4 ++ roborock/devices/device.py | 81 ++++++++++++++++++++++++++-- roborock/devices/rpc/v1_channel.py | 14 ++--- roborock/devices/traits/v1/common.py | 21 ++++++++ 4 files changed, 111 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index de3ccf9c..f724ac37 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,7 @@ docs/_build/ # GitHub App credentials gha-creds-*.json + +# pickle files +*.p +*.pickle diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 9026c4a7..40f581e6 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -6,16 +6,22 @@ import asyncio import datetime +import json import logging from abc import ABC from collections.abc import Callable from typing import Any from roborock.callbacks import CallbackList -from roborock.data import HomeDataDevice, HomeDataProduct +from roborock.data import HomeDataDevice, HomeDataProduct, RoborockErrorCode, RoborockStateCode from roborock.diagnostics import redact_device_data from roborock.exceptions import RoborockException -from roborock.roborock_message import RoborockMessage +from roborock.roborock_message import ( + ROBOROCK_DATA_STATUS_PROTOCOL, + RoborockDataProtocol, + RoborockMessage, + RoborockMessageProtocol, +) from roborock.util import RoborockLoggerAdapter from .traits import Trait @@ -219,8 +225,77 @@ async def close(self) -> None: self._unsub = None def _on_message(self, message: RoborockMessage) -> None: - """Handle incoming messages from the device.""" + """Handle incoming messages from the device. + + Note: Protocol updates (data points) are only sent via cloud/MQTT, not local connection. + """ self._logger.debug("Received message from device: %s", message) + if self.v1_properties is None: + # Ensure we are only doing below logic for set-up V1 devices. + return + + # Only process messages that can contain protocol updates + # RPC_RESPONSE (102), GENERAL_REQUEST (4), and GENERAL_RESPONSE (5) + if message.protocol not in { + RoborockMessageProtocol.RPC_RESPONSE, + RoborockMessageProtocol.GENERAL_RESPONSE, + }: + return + + if not message.payload: + return + + try: + payload = json.loads(message.payload.decode()) + dps = payload.get("dps", {}) + + if not dps: + return + + # Process each data point in the message + for data_point_number, data_point in dps.items(): + # Skip RPC responses (102) as they're handled by the RPC channel + if data_point_number == "102": + continue + + try: + data_protocol = RoborockDataProtocol(int(data_point_number)) + self._logger.debug(f"Got device update for {data_protocol.name}: {data_point}") + self._handle_protocol_update(data_protocol, data_point) + except ValueError: + # Unknown protocol number + self._logger.debug( + f"Got unknown data protocol {data_point_number}, data: {data_point}. " + f"This may allow for faster updates in the future." + ) + except (json.JSONDecodeError, UnicodeDecodeError, KeyError) as ex: + self._logger.debug(f"Failed to parse protocol message: {ex}") + + def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> None: + """Handle a protocol update for a specific data protocol. + + Args: + protocol: The data protocol number. + data_point: The data value for this protocol. + """ + # Handle status protocol updates + if protocol in ROBOROCK_DATA_STATUS_PROTOCOL and self.v1_properties and self.v1_properties.status: + # Update the specific field in the status trait + match protocol: + case RoborockDataProtocol.ERROR_CODE: + self.v1_properties.status.error_code = RoborockErrorCode(data_point) + case RoborockDataProtocol.STATE: + self.v1_properties.status.state = RoborockStateCode(data_point) + case RoborockDataProtocol.BATTERY: + self.v1_properties.status.battery = data_point + case RoborockDataProtocol.CHARGE_STATUS: + self.v1_properties.status.charge_status = data_point + case _: + # There is also fan power and water box mode, but for now those are skipped + return + + self._logger.debug("Updated status.%s to %s", protocol.name.lower(), data_point) + self.v1_properties.status.notify_update() def diagnostic_data(self) -> dict[str, Any]: """Return diagnostics information about the device.""" diff --git a/roborock/devices/rpc/v1_channel.py b/roborock/devices/rpc/v1_channel.py index d1b4ee24..608c9ecc 100644 --- a/roborock/devices/rpc/v1_channel.py +++ b/roborock/devices/rpc/v1_channel.py @@ -305,12 +305,14 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab loop = asyncio.get_running_loop() self._reconnect_task = loop.create_task(self._background_reconnect()) - if not self.is_local_connected: - # We were not able to connect locally, so fallback to MQTT and at least - # establish that connection explicitly. If this fails then raise an - # error and let the caller know we failed to subscribe. - self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) - self._logger.debug("V1Channel connected to device via MQTT") + # Always subscribe to MQTT to receive protocol updates (data points) + # even if we have a local connection. Protocol updates only come via cloud/MQTT. + # Local connection is used for RPC commands, but push notifications come via MQTT. + self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) + if self.is_local_connected: + self._logger.debug("V1Channel connected via local and MQTT (for protocol updates)") + else: + self._logger.debug("V1Channel connected via MQTT only") def unsub() -> None: """Unsubscribe from all messages.""" diff --git a/roborock/devices/traits/v1/common.py b/roborock/devices/traits/v1/common.py index 63ae2e20..7313e245 100644 --- a/roborock/devices/traits/v1/common.py +++ b/roborock/devices/traits/v1/common.py @@ -3,11 +3,15 @@ This is an internal library and should not be used directly by consumers. """ +from __future__ import annotations + import logging from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass, fields from typing import ClassVar, Self +from roborock.callbacks import CallbackList from roborock.data import RoborockBase from roborock.protocols.v1_protocol import V1RpcChannel from roborock.roborock_typing import RoborockCommand @@ -15,6 +19,7 @@ _LOGGER = logging.getLogger(__name__) V1ResponseData = dict | list | int | str +V1TraitUpdateCallback = Callable[["V1TraitMixin"], None] @dataclass @@ -74,6 +79,7 @@ def __post_init__(self) -> None: device setup code. """ self._rpc_channel = None + self._update_callbacks: CallbackList[V1TraitMixin] = CallbackList() @property def rpc_channel(self) -> V1RpcChannel: @@ -97,6 +103,21 @@ def _update_trait_values(self, new_data: RoborockBase) -> None: new_value = getattr(new_data, field.name, None) setattr(self, field.name, new_value) + def add_update_callback(self, callback: V1TraitUpdateCallback) -> Callable[[], None]: + """Add a callback to be notified when the trait is updated. + + The callback will be called with the updated trait instance whenever + a protocol message updates the trait. + + Returns: + A callable that can be used to remove the callback. + """ + return self._update_callbacks.add_callback(callback) + + def notify_update(self) -> None: + """Notify all registered callbacks that the trait has been updated.""" + self._update_callbacks(self) + def _get_value_field(clazz: type[V1TraitMixin]) -> str: """Get the name of the field marked as the main value of the RoborockValueBase.""" From 4a23a4bf2cc65746ee004dd30fa17b4a4154a8e7 Mon Sep 17 00:00:00 2001 From: Luke Date: Wed, 28 Jan 2026 19:44:18 -0500 Subject: [PATCH 2/3] chore: minor refactoring --- roborock/devices/device.py | 29 ++++++++-------------------- roborock/devices/traits/v1/status.py | 21 ++++++++++++++++++-- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 5c66c181..12b25bd6 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -13,7 +13,7 @@ from typing import Any from roborock.callbacks import CallbackList -from roborock.data import HomeDataDevice, HomeDataProduct, RoborockErrorCode, RoborockStateCode +from roborock.data import HomeDataDevice, HomeDataProduct from roborock.diagnostics import redact_device_data from roborock.exceptions import RoborockException from roborock.roborock_message import ( @@ -235,7 +235,7 @@ def _on_message(self, message: RoborockMessage) -> None: return # Only process messages that can contain protocol updates - # RPC_RESPONSE (102), GENERAL_REQUEST (4), and GENERAL_RESPONSE (5) + # RPC_RESPONSE (102), and GENERAL_RESPONSE (5) if message.protocol not in { RoborockMessageProtocol.RPC_RESPONSE, RoborockMessageProtocol.GENERAL_RESPONSE, @@ -246,7 +246,7 @@ def _on_message(self, message: RoborockMessage) -> None: return try: - payload = json.loads(message.payload.decode()) + payload = json.loads(message.payload.decode("utf-8")) dps = payload.get("dps", {}) if not dps: @@ -260,7 +260,7 @@ def _on_message(self, message: RoborockMessage) -> None: try: data_protocol = RoborockDataProtocol(int(data_point_number)) - self._logger.debug(f"Got device update for {data_protocol.name}: {data_point}") + self._logger.debug("Got device update for %s: %s", data_protocol.name, data_point) self._handle_protocol_update(data_protocol, data_point) except ValueError: # Unknown protocol number @@ -269,7 +269,7 @@ def _on_message(self, message: RoborockMessage) -> None: f"This may allow for faster updates in the future." ) except (json.JSONDecodeError, UnicodeDecodeError, KeyError) as ex: - self._logger.debug(f"Failed to parse protocol message: {ex}") + self._logger.debug("Failed to parse protocol message: %s", ex) def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> None: """Handle a protocol update for a specific data protocol. @@ -280,22 +280,9 @@ def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: An """ # Handle status protocol updates if protocol in ROBOROCK_DATA_STATUS_PROTOCOL and self.v1_properties and self.v1_properties.status: - # Update the specific field in the status trait - match protocol: - case RoborockDataProtocol.ERROR_CODE: - self.v1_properties.status.error_code = RoborockErrorCode(data_point) - case RoborockDataProtocol.STATE: - self.v1_properties.status.state = RoborockStateCode(data_point) - case RoborockDataProtocol.BATTERY: - self.v1_properties.status.battery = data_point - case RoborockDataProtocol.CHARGE_STATUS: - self.v1_properties.status.charge_status = data_point - case _: - # There is also fan power and water box mode, but for now those are skipped - return - - self._logger.debug("Updated status.%s to %s", protocol.name.lower(), data_point) - self.v1_properties.status.notify_update() + if self.v1_properties.status.handle_protocol_update(protocol, data_point): + self._logger.debug("Updated status.%s to %s", protocol.name.lower(), data_point) + self.v1_properties.status.notify_update() def diagnostic_data(self) -> dict[str, Any]: """Return diagnostics information about the device.""" diff --git a/roborock/devices/traits/v1/status.py b/roborock/devices/traits/v1/status.py index 08cd0c45..4f60bdba 100644 --- a/roborock/devices/traits/v1/status.py +++ b/roborock/devices/traits/v1/status.py @@ -1,7 +1,8 @@ -from typing import Self +from typing import Any, Self -from roborock.data import HomeDataProduct, ModelStatus, S7MaxVStatus, Status +from roborock.data import HomeDataProduct, ModelStatus, RoborockErrorCode, RoborockStateCode, S7MaxVStatus, Status from roborock.devices.traits.v1 import common +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand @@ -22,3 +23,19 @@ def _parse_response(self, response: common.V1ResponseData) -> Self: if isinstance(response, dict): return status_type.from_dict(response) raise ValueError(f"Unexpected status format: {response!r}") + + def handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> bool: + """Handle a protocol update for a specific data protocol.""" + match protocol: + case RoborockDataProtocol.ERROR_CODE: + self.error_code = RoborockErrorCode(data_point) + case RoborockDataProtocol.STATE: + self.state = RoborockStateCode(data_point) + case RoborockDataProtocol.BATTERY: + self.battery = data_point + case RoborockDataProtocol.CHARGE_STATUS: + self.charge_status = data_point + case _: + # There is also fan power and water box mode, but for now those are skipped + return False + return True From cdf6f3bd6dd9b0aabe510abc66278a5eb52265d0 Mon Sep 17 00:00:00 2001 From: Luke Date: Wed, 28 Jan 2026 20:37:53 -0500 Subject: [PATCH 3/3] chore: move logic to be owned by v1 instead of device --- roborock/devices/device.py | 85 +++++--------------------- roborock/devices/traits/v1/__init__.py | 82 +++++++++++++++++++++++++ roborock/devices/traits/v1/common.py | 18 +++--- 3 files changed, 108 insertions(+), 77 deletions(-) diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 12b25bd6..12a3c2f3 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -6,7 +6,6 @@ import asyncio import datetime -import json import logging from abc import ABC from collections.abc import Callable @@ -16,12 +15,6 @@ from roborock.data import HomeDataDevice, HomeDataProduct from roborock.diagnostics import redact_device_data from roborock.exceptions import RoborockException -from roborock.roborock_message import ( - ROBOROCK_DATA_STATUS_PROTOCOL, - RoborockDataProtocol, - RoborockMessage, - RoborockMessageProtocol, -) from roborock.util import RoborockLoggerAdapter from .traits import Trait @@ -81,6 +74,7 @@ def __init__( self._channel = channel self._connect_task: asyncio.Task[None] | None = None self._unsub: Callable[[], None] | None = None + self._v1_unsub: Callable[[], None] | None = None self._ready_callbacks = CallbackList["RoborockDevice"]() self._has_connected = False @@ -202,15 +196,23 @@ async def connect(self) -> None: """Connect to the device using the appropriate protocol channel.""" if self._unsub: raise ValueError("Already connected to the device") - unsub = await self._channel.subscribe(self._on_message) + if self.v1_properties is not None: try: + # V1 layer subscribes to the channel and handles protocol updates. + # Note: V1Channel only allows one subscription, so the V1 layer + # is the sole subscriber for V1 devices. + self._v1_unsub = await self.v1_properties.subscribe_async(self._channel) await self.v1_properties.discover_features() except RoborockException: - unsub() + if self._v1_unsub: + self._v1_unsub() raise + else: + # Non-V1 devices subscribe directly (no protocol update handling needed) + self._unsub = await self._channel.subscribe(lambda msg: None) + self._logger.info("Connected to device") - self._unsub = unsub async def close(self) -> None: """Close all connections to the device.""" @@ -220,70 +222,13 @@ async def close(self) -> None: await self._connect_task except asyncio.CancelledError: pass + if self._v1_unsub: + self._v1_unsub() + self._v1_unsub = None if self._unsub: self._unsub() self._unsub = None - def _on_message(self, message: RoborockMessage) -> None: - """Handle incoming messages from the device. - - Note: Protocol updates (data points) are only sent via cloud/MQTT, not local connection. - """ - self._logger.debug("Received message from device: %s", message) - if self.v1_properties is None: - # Ensure we are only doing below logic for set-up V1 devices. - return - - # Only process messages that can contain protocol updates - # RPC_RESPONSE (102), and GENERAL_RESPONSE (5) - if message.protocol not in { - RoborockMessageProtocol.RPC_RESPONSE, - RoborockMessageProtocol.GENERAL_RESPONSE, - }: - return - - if not message.payload: - return - - try: - payload = json.loads(message.payload.decode("utf-8")) - dps = payload.get("dps", {}) - - if not dps: - return - - # Process each data point in the message - for data_point_number, data_point in dps.items(): - # Skip RPC responses (102) as they're handled by the RPC channel - if data_point_number == "102": - continue - - try: - data_protocol = RoborockDataProtocol(int(data_point_number)) - self._logger.debug("Got device update for %s: %s", data_protocol.name, data_point) - self._handle_protocol_update(data_protocol, data_point) - except ValueError: - # Unknown protocol number - self._logger.debug( - f"Got unknown data protocol {data_point_number}, data: {data_point}. " - f"This may allow for faster updates in the future." - ) - except (json.JSONDecodeError, UnicodeDecodeError, KeyError) as ex: - self._logger.debug("Failed to parse protocol message: %s", ex) - - def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> None: - """Handle a protocol update for a specific data protocol. - - Args: - protocol: The data protocol number. - data_point: The data value for this protocol. - """ - # Handle status protocol updates - if protocol in ROBOROCK_DATA_STATUS_PROTOCOL and self.v1_properties and self.v1_properties.status: - if self.v1_properties.status.handle_protocol_update(protocol, data_point): - self._logger.debug("Updated status.%s to %s", protocol.name.lower(), data_point) - self.v1_properties.status.notify_update() - def diagnostic_data(self) -> dict[str, Any]: """Return diagnostics information about the device.""" extra: dict[str, Any] = {} diff --git a/roborock/devices/traits/v1/__init__.py b/roborock/devices/traits/v1/__init__.py index 7438c6b0..db5b047b 100644 --- a/roborock/devices/traits/v1/__init__.py +++ b/roborock/devices/traits/v1/__init__.py @@ -52,7 +52,9 @@ code in HomeDataProduct Schema that is required for the field to be supported. """ +import json import logging +from collections.abc import Callable from dataclasses import dataclass, field, fields from functools import cache from typing import Any, get_args @@ -61,8 +63,15 @@ from roborock.data.v1.v1_code_mappings import RoborockDockTypeCode from roborock.devices.cache import DeviceCache from roborock.devices.traits import Trait +from roborock.devices.transport.channel import Channel from roborock.map.map_parser import MapParserConfig from roborock.protocols.v1_protocol import V1RpcChannel +from roborock.roborock_message import ( + ROBOROCK_DATA_STATUS_PROTOCOL, + RoborockDataProtocol, + RoborockMessage, + RoborockMessageProtocol, +) from roborock.web_api import UserWebApiClient from . import ( @@ -313,6 +322,79 @@ def as_dict(self) -> dict[str, Any]: result[item.name] = data return result + async def subscribe_async(self, channel: Channel) -> Callable[[], None]: + """Subscribe to protocol updates from the channel. + + This handles MQTT protocol updates for V1 devices, routing data point + updates to the appropriate traits. + + Args: + channel: The channel to subscribe to for updates. + + Returns: + A callable that can be used to unsubscribe from updates. + """ + + def on_message(message: RoborockMessage) -> None: + self._handle_message(message) + + return await channel.subscribe(on_message) + + def _handle_message(self, message: RoborockMessage) -> None: + """Handle incoming messages from the device. + + Parses protocol updates and routes them to the appropriate traits. + """ + # Only process messages that can contain protocol updates + # RPC_RESPONSE (102), and GENERAL_RESPONSE (5) + if message.protocol not in { + RoborockMessageProtocol.RPC_RESPONSE, + RoborockMessageProtocol.GENERAL_RESPONSE, + }: + return + + if not message.payload: + return + + try: + payload = json.loads(message.payload.decode("utf-8")) + dps = payload.get("dps", {}) + + if not dps: + return + + # Process each data point in the message + for data_point_number, data_point in dps.items(): + # Skip RPC responses (102) as they're handled by the RPC channel + if data_point_number == "102": + continue + + try: + data_protocol = RoborockDataProtocol(int(data_point_number)) + _LOGGER.debug("Got device update for %s: %s", data_protocol.name, data_point) + self._handle_protocol_update(data_protocol, data_point) + except ValueError: + # Unknown protocol number + _LOGGER.debug( + f"Got unknown data protocol {data_point_number}, data: {data_point}. " + f"This may allow for faster updates in the future." + ) + except (json.JSONDecodeError, UnicodeDecodeError, KeyError) as ex: + _LOGGER.debug("Failed to parse protocol message: %s", ex) + + def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> None: + """Handle a protocol update for a specific data protocol. + + Args: + protocol: The data protocol number. + data_point: The data value for this protocol. + """ + # Handle status protocol updates + if protocol in ROBOROCK_DATA_STATUS_PROTOCOL and self.status: + if self.status.handle_protocol_update(protocol, data_point): + _LOGGER.debug("Updated status.%s to %s", protocol.name.lower(), data_point) + self.status.notify_update() + def create( device_uid: str, diff --git a/roborock/devices/traits/v1/common.py b/roborock/devices/traits/v1/common.py index 7313e245..cdcedf71 100644 --- a/roborock/devices/traits/v1/common.py +++ b/roborock/devices/traits/v1/common.py @@ -11,7 +11,6 @@ from dataclasses import dataclass, fields from typing import ClassVar, Self -from roborock.callbacks import CallbackList from roborock.data import RoborockBase from roborock.protocols.v1_protocol import V1RpcChannel from roborock.roborock_typing import RoborockCommand @@ -19,7 +18,7 @@ _LOGGER = logging.getLogger(__name__) V1ResponseData = dict | list | int | str -V1TraitUpdateCallback = Callable[["V1TraitMixin"], None] +V1TraitUpdateCallback = Callable[[], None] @dataclass @@ -79,7 +78,7 @@ def __post_init__(self) -> None: device setup code. """ self._rpc_channel = None - self._update_callbacks: CallbackList[V1TraitMixin] = CallbackList() + self._update_callbacks: list[V1TraitUpdateCallback] = [] @property def rpc_channel(self) -> V1RpcChannel: @@ -106,17 +105,22 @@ def _update_trait_values(self, new_data: RoborockBase) -> None: def add_update_callback(self, callback: V1TraitUpdateCallback) -> Callable[[], None]: """Add a callback to be notified when the trait is updated. - The callback will be called with the updated trait instance whenever - a protocol message updates the trait. + The callback will be called whenever a protocol message updates the trait. + Callers should track which trait they subscribed to if needed. Returns: A callable that can be used to remove the callback. """ - return self._update_callbacks.add_callback(callback) + self._update_callbacks.append(callback) + return lambda: self._update_callbacks.remove(callback) def notify_update(self) -> None: """Notify all registered callbacks that the trait has been updated.""" - self._update_callbacks(self) + for callback in self._update_callbacks: + try: + callback() + except Exception: # noqa: BLE001 + _LOGGER.exception("Error in trait update callback") def _get_value_field(clazz: type[V1TraitMixin]) -> str: