Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ docs/_build/

# GitHub App credentials
gha-creds-*.json

# pickle files
*.p
*.pickle
Comment on lines +22 to +25
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated but i have a few of these in my repo and just wanted to stop them from filling up my git diff

23 changes: 15 additions & 8 deletions roborock/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from roborock.data import HomeDataDevice, HomeDataProduct
from roborock.diagnostics import redact_device_data
from roborock.exceptions import RoborockException
from roborock.roborock_message import RoborockMessage
from roborock.util import RoborockLoggerAdapter

from .traits import Trait
Expand Down Expand Up @@ -75,6 +74,7 @@ def __init__(
self._channel = channel
self._connect_task: asyncio.Task[None] | None = None
self._unsub: Callable[[], None] | None = None
self._v1_unsub: Callable[[], None] | None = None
self._ready_callbacks = CallbackList["RoborockDevice"]()
self._has_connected = False

Expand Down Expand Up @@ -196,15 +196,23 @@ async def connect(self) -> None:
"""Connect to the device using the appropriate protocol channel."""
if self._unsub:
raise ValueError("Already connected to the device")
unsub = await self._channel.subscribe(self._on_message)

if self.v1_properties is not None:
try:
# V1 layer subscribes to the channel and handles protocol updates.
# Note: V1Channel only allows one subscription, so the V1 layer
# is the sole subscriber for V1 devices.
self._v1_unsub = await self.v1_properties.subscribe_async(self._channel)
await self.v1_properties.discover_features()
except RoborockException:
unsub()
if self._v1_unsub:
self._v1_unsub()
raise
else:
# Non-V1 devices subscribe directly (no protocol update handling needed)
self._unsub = await self._channel.subscribe(lambda msg: None)

self._logger.info("Connected to device")
self._unsub = unsub

async def close(self) -> None:
"""Close all connections to the device."""
Expand All @@ -214,14 +222,13 @@ async def close(self) -> None:
await self._connect_task
except asyncio.CancelledError:
pass
if self._v1_unsub:
self._v1_unsub()
self._v1_unsub = None
if self._unsub:
self._unsub()
self._unsub = None

def _on_message(self, message: RoborockMessage) -> None:
"""Handle incoming messages from the device."""
self._logger.debug("Received message from device: %s", message)

def diagnostic_data(self) -> dict[str, Any]:
"""Return diagnostics information about the device."""
extra: dict[str, Any] = {}
Expand Down
14 changes: 8 additions & 6 deletions roborock/devices/rpc/v1_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,14 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab
loop = asyncio.get_running_loop()
self._reconnect_task = loop.create_task(self._background_reconnect())

if not self.is_local_connected:
# We were not able to connect locally, so fallback to MQTT and at least
# establish that connection explicitly. If this fails then raise an
# error and let the caller know we failed to subscribe.
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
self._logger.debug("V1Channel connected to device via MQTT")
# Always subscribe to MQTT to receive protocol updates (data points)
# even if we have a local connection. Protocol updates only come via cloud/MQTT.
# Local connection is used for RPC commands, but push notifications come via MQTT.
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
if self.is_local_connected:
self._logger.debug("V1Channel connected via local and MQTT (for protocol updates)")
else:
self._logger.debug("V1Channel connected via MQTT only")

def unsub() -> None:
"""Unsubscribe from all messages."""
Expand Down
82 changes: 82 additions & 0 deletions roborock/devices/traits/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
code in HomeDataProduct Schema that is required for the field to be supported.
"""

import json
import logging
from collections.abc import Callable
from dataclasses import dataclass, field, fields
from functools import cache
from typing import Any, get_args
Expand All @@ -61,8 +63,15 @@
from roborock.data.v1.v1_code_mappings import RoborockDockTypeCode
from roborock.devices.cache import DeviceCache
from roborock.devices.traits import Trait
from roborock.devices.transport.channel import Channel
from roborock.map.map_parser import MapParserConfig
from roborock.protocols.v1_protocol import V1RpcChannel
from roborock.roborock_message import (
ROBOROCK_DATA_STATUS_PROTOCOL,
RoborockDataProtocol,
RoborockMessage,
RoborockMessageProtocol,
)
from roborock.web_api import UserWebApiClient

from . import (
Expand Down Expand Up @@ -313,6 +322,79 @@ def as_dict(self) -> dict[str, Any]:
result[item.name] = data
return result

async def subscribe_async(self, channel: Channel) -> Callable[[], None]:
"""Subscribe to protocol updates from the channel.

This handles MQTT protocol updates for V1 devices, routing data point
updates to the appropriate traits.

Args:
channel: The channel to subscribe to for updates.

Returns:
A callable that can be used to unsubscribe from updates.
"""

def on_message(message: RoborockMessage) -> None:
self._handle_message(message)

return await channel.subscribe(on_message)

def _handle_message(self, message: RoborockMessage) -> None:
"""Handle incoming messages from the device.

Parses protocol updates and routes them to the appropriate traits.
"""
# Only process messages that can contain protocol updates
# RPC_RESPONSE (102), and GENERAL_RESPONSE (5)
if message.protocol not in {
RoborockMessageProtocol.RPC_RESPONSE,
RoborockMessageProtocol.GENERAL_RESPONSE,
}:
return

if not message.payload:
return

try:
payload = json.loads(message.payload.decode("utf-8"))
dps = payload.get("dps", {})

if not dps:
return

# Process each data point in the message
for data_point_number, data_point in dps.items():
# Skip RPC responses (102) as they're handled by the RPC channel
if data_point_number == "102":
continue

try:
data_protocol = RoborockDataProtocol(int(data_point_number))
_LOGGER.debug("Got device update for %s: %s", data_protocol.name, data_point)
self._handle_protocol_update(data_protocol, data_point)
except ValueError:
# Unknown protocol number
_LOGGER.debug(
f"Got unknown data protocol {data_point_number}, data: {data_point}. "
f"This may allow for faster updates in the future."
)
except (json.JSONDecodeError, UnicodeDecodeError, KeyError) as ex:
_LOGGER.debug("Failed to parse protocol message: %s", ex)

def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> None:
"""Handle a protocol update for a specific data protocol.

Args:
protocol: The data protocol number.
data_point: The data value for this protocol.
"""
# Handle status protocol updates
if protocol in ROBOROCK_DATA_STATUS_PROTOCOL and self.status:
if self.status.handle_protocol_update(protocol, data_point):
_LOGGER.debug("Updated status.%s to %s", protocol.name.lower(), data_point)
self.status.notify_update()


def create(
device_uid: str,
Expand Down
25 changes: 25 additions & 0 deletions roborock/devices/traits/v1/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
This is an internal library and should not be used directly by consumers.
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, fields
from typing import ClassVar, Self

Expand All @@ -15,6 +18,7 @@
_LOGGER = logging.getLogger(__name__)

V1ResponseData = dict | list | int | str
V1TraitUpdateCallback = Callable[[], None]


@dataclass
Expand Down Expand Up @@ -74,6 +78,7 @@ def __post_init__(self) -> None:
device setup code.
"""
self._rpc_channel = None
self._update_callbacks: list[V1TraitUpdateCallback] = []

@property
def rpc_channel(self) -> V1RpcChannel:
Expand All @@ -97,6 +102,26 @@ def _update_trait_values(self, new_data: RoborockBase) -> None:
new_value = getattr(new_data, field.name, None)
setattr(self, field.name, new_value)

def add_update_callback(self, callback: V1TraitUpdateCallback) -> Callable[[], None]:
"""Add a callback to be notified when the trait is updated.

The callback will be called whenever a protocol message updates the trait.
Callers should track which trait they subscribed to if needed.

Returns:
A callable that can be used to remove the callback.
"""
self._update_callbacks.append(callback)
return lambda: self._update_callbacks.remove(callback)

def notify_update(self) -> None:
"""Notify all registered callbacks that the trait has been updated."""
for callback in self._update_callbacks:
try:
callback()
except Exception: # noqa: BLE001
_LOGGER.exception("Error in trait update callback")


def _get_value_field(clazz: type[V1TraitMixin]) -> str:
"""Get the name of the field marked as the main value of the RoborockValueBase."""
Expand Down
21 changes: 19 additions & 2 deletions roborock/devices/traits/v1/status.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Self
from typing import Any, Self

from roborock.data import HomeDataProduct, ModelStatus, S7MaxVStatus, Status
from roborock.data import HomeDataProduct, ModelStatus, RoborockErrorCode, RoborockStateCode, S7MaxVStatus, Status
from roborock.devices.traits.v1 import common
from roborock.roborock_message import RoborockDataProtocol
from roborock.roborock_typing import RoborockCommand


Expand All @@ -22,3 +23,19 @@ def _parse_response(self, response: common.V1ResponseData) -> Self:
if isinstance(response, dict):
return status_type.from_dict(response)
raise ValueError(f"Unexpected status format: {response!r}")

def handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> bool:
"""Handle a protocol update for a specific data protocol."""
match protocol:
case RoborockDataProtocol.ERROR_CODE:
self.error_code = RoborockErrorCode(data_point)
case RoborockDataProtocol.STATE:
self.state = RoborockStateCode(data_point)
case RoborockDataProtocol.BATTERY:
self.battery = data_point
case RoborockDataProtocol.CHARGE_STATUS:
self.charge_status = data_point
case _:
# There is also fan power and water box mode, but for now those are skipped
return False
return True
Loading