From a0512e25d38e88ce3b61542f3dc34cdf29ea139c Mon Sep 17 00:00:00 2001 From: yuanhaonan Date: Thu, 29 Jan 2026 18:58:24 +0800 Subject: [PATCH 1/2] update device controller for both gripper and hand --- embodichain/lab/devices/device_controller.py | 113 +++++++++++++++---- 1 file changed, 93 insertions(+), 20 deletions(-) diff --git a/embodichain/lab/devices/device_controller.py b/embodichain/lab/devices/device_controller.py index 2ae48acc..a9bf6203 100644 --- a/embodichain/lab/devices/device_controller.py +++ b/embodichain/lab/devices/device_controller.py @@ -71,10 +71,21 @@ def __init__( # Joint mapping state: maps device joint name -> robot joint index self._joint_mapping: Dict[str, int] = {} self._mapping_initialized = False + self._binary_control_keys = set() # Device keys that need binary control logger.log_info(f"Device Controller initialized for robot: {robot.uid}") logger.log_info(f" Robot has {len(robot.joint_names)} joints") + def _get_gripper_joints(self) -> List[str]: + """Get gripper joint names from control_parts.""" + if self.robot.control_parts: + gripper_joints = [] + for part_name, joint_names in self.robot.control_parts.items(): + if "eef" in part_name.lower(): + gripper_joints.extend(joint_names) + return gripper_joints + return [] + def add_device( self, device: Device, device_name: str, set_active: bool = False ) -> None: @@ -165,29 +176,55 @@ def _map_device_to_robot( try: robot_joint_names = self.robot.joint_names - # Build joint mapping on first call if not self._mapping_initialized: self._build_joint_mapping(device_data, robot_joint_names) + # Get robot joint limits + qpos_limits = self.robot.body_data.qpos_limits[0] # [num_joints, 2] + # Extract joint values based on mapping joint_values = [] joint_indices = [] mapped_joints = {} - for device_joint, robot_idx in self._joint_mapping.items(): + gripper_threshold = 0.99 + + for device_joint, robot_indices in self._joint_mapping.items(): if device_joint in device_data: - value = device_data[device_joint] - joint_values.append(value) - joint_indices.append(robot_idx) - mapped_joints[robot_joint_names[robot_idx]] = value + vr_value = device_data[device_joint] + + # Binary control for grippers, direct value for other joints + if device_joint in self._binary_control_keys: + idx = robot_indices[0] + joint_min, joint_max = ( + qpos_limits[idx, 0].item(), + qpos_limits[idx, 1].item(), + ) + robot_value = ( + joint_min if vr_value >= gripper_threshold else joint_max + ) + + # Log state change + state = "CLOSE" if vr_value >= gripper_threshold else "OPEN" + if not hasattr(self, "_last_gripper_state"): + self._last_gripper_state = {} + if self._last_gripper_state.get(device_joint) != state: + logger.log_info( + f"{device_joint}: VR={vr_value:.3f} -> {state} (robot={robot_value:.3f}, limit=[{joint_min:.3f}, {joint_max:.3f}])" + ) + self._last_gripper_state[device_joint] = state + else: + robot_value = vr_value + + # One device value may correspond to multiple robot joints (e.g., two fingers of gripper) + for robot_idx in robot_indices: + joint_values.append(robot_value) + joint_indices.append(robot_idx) + mapped_joints[robot_joint_names[robot_idx]] = robot_value if len(joint_indices) == 0: return None - # Return as dict if requested - if as_dict: - return mapped_joints - # Convert to tensor and create full action device_tensor = torch.tensor( joint_values, dtype=torch.float32, device=self.robot.device @@ -210,30 +247,65 @@ def _map_device_to_robot( except Exception as e: logger.log_error(f"Error mapping device data to robot action: {e}") + import traceback + + logger.log_error(traceback.format_exc()) return None def _build_joint_mapping( self, device_data: Dict[str, float], robot_joint_names: List[str] ) -> None: - """Build mapping from device joints to robot joints. - - Args: - device_data: Device joint data to determine available joints. - robot_joint_names: Robot joint names. - """ + """Build mapping from device joints to robot joints.""" + # Store: device_joint_name -> [robot_joint_indices] self._joint_mapping = {} + logger.log_info("=" * 60) + logger.log_info("Building joint mapping...") + logger.log_info( + f"VR device data keys ({len(device_data)} total): {sorted(device_data.keys())}" + ) + logger.log_info( + f"Robot joint names ({len(robot_joint_names)} total): {robot_joint_names}" + ) + logger.log_info(f"Robot control_parts: {self.robot.control_parts}") + + # Direct name matching for robot_idx, robot_joint in enumerate(robot_joint_names): - # Direct name matching (can be extended with more sophisticated mapping) if robot_joint in device_data: - self._joint_mapping[robot_joint] = robot_idx + self._joint_mapping.setdefault(robot_joint, []).append(robot_idx) + + # Special gripper mapping: LEFT_GRIPPER/RIGHT_GRIPPER -> multiple finger joints + # Only for parallel grippers (< 5 joints), not dexterous hands (>= 5 joints) + if self.robot.control_parts: + for part_name, joint_names in self.robot.control_parts.items(): + if "eef" in part_name.lower() and len(joint_names) < 5: + device_key = ( + f"{'LEFT' if 'left' in part_name.lower() else 'RIGHT'}_GRIPPER" + ) + if ( + device_key in device_data + and device_key not in robot_joint_names + ): + indices = [ + robot_joint_names.index(j) + for j in joint_names + if j in robot_joint_names + ] + if indices: + self._joint_mapping[device_key] = indices + self._binary_control_keys.add(device_key) + logger.log_info( + f" Gripper mapping: '{device_key}' -> {joint_names} (indices: {indices})" + ) self._mapping_initialized = True + total_mappings = sum(len(v) for v in self._joint_mapping.values()) logger.log_info( - f"Joint mapping initialized: {len(self._joint_mapping)} joints mapped" + f"Joint mapping initialized: {len(self._joint_mapping)} device joints mapped to {total_mappings} robot joints" ) - logger.log_info(f" Mapped joints: {list(self._joint_mapping.keys())}") + logger.log_info(f" Device joints: {list(self._joint_mapping.keys())}") + logger.log_info("=" * 60) def _enforce_joint_limits(self, action: torch.Tensor) -> torch.Tensor: """Enforce robot joint limits on action. @@ -257,6 +329,7 @@ def reset(self) -> None: # Reset mapping self._joint_mapping = {} self._mapping_initialized = False + self._binary_control_keys.clear() logger.log_info("Device Controller reset") From d9f33ebe7eb84174845bb4bb076c20b130bdf3a5 Mon Sep 17 00:00:00 2001 From: yuanhaonan Date: Fri, 30 Jan 2026 11:02:15 +0800 Subject: [PATCH 2/2] wip --- embodichain/lab/devices/device.py | 19 +++ embodichain/lab/devices/device_controller.py | 158 +++---------------- 2 files changed, 43 insertions(+), 134 deletions(-) diff --git a/embodichain/lab/devices/device.py b/embodichain/lab/devices/device.py index 4e75cd9c..3aba0197 100644 --- a/embodichain/lab/devices/device.py +++ b/embodichain/lab/devices/device.py @@ -15,6 +15,11 @@ # ---------------------------------------------------------------------------- import abc # for abstract base class definitions +from typing import TYPE_CHECKING, Optional, Dict +import torch + +if TYPE_CHECKING: + from embodichain.lab.sim.objects import Robot class Device(metaclass=abc.ABCMeta): @@ -42,3 +47,17 @@ def stop_control(self): def get_controller_state(self): """Returns the current state of the device, a dictionary of pos, orn, grasp, and reset.""" raise NotImplementedError + + def map_to_robot( + self, robot: "Robot", device_data: Dict[str, float] + ) -> Optional[torch.Tensor]: + """Map device input to robot action (optional, device-specific). + + Args: + robot: Robot instance to control. + device_data: Device input data. + + Returns: + Robot action tensor [num_envs, num_joints], or None if not implemented. + """ + return None # Default: no custom mapping diff --git a/embodichain/lab/devices/device_controller.py b/embodichain/lab/devices/device_controller.py index a9bf6203..d950b861 100644 --- a/embodichain/lab/devices/device_controller.py +++ b/embodichain/lab/devices/device_controller.py @@ -68,11 +68,6 @@ def __init__( if device is not None: self.add_device(device, device_name, set_active=True) - # Joint mapping state: maps device joint name -> robot joint index - self._joint_mapping: Dict[str, int] = {} - self._mapping_initialized = False - self._binary_control_keys = set() # Device keys that need binary control - logger.log_info(f"Device Controller initialized for robot: {robot.uid}") logger.log_info(f" Robot has {len(robot.joint_names)} joints") @@ -158,69 +153,39 @@ def get_action( if device_data is None: return None - # Map device data to robot action - return self._map_device_to_robot(device_data, as_dict=as_dict) + # Priority 1: Use device-specific mapping if available + action = device.map_to_robot(self.robot, device_data) - def _map_device_to_robot( - self, device_data: Dict[str, float], as_dict: bool = False - ) -> Union[torch.Tensor, Dict[str, float], None]: - """Map device input to robot action. + # Priority 2: Use generic mapping as fallback + if action is None: + action = self._generic_mapping(device_data) + + return action + + def _generic_mapping(self, device_data: Dict[str, float]) -> Optional[torch.Tensor]: + """Generic device-to-robot mapping (simple direct name matching). + + This is a fallback for devices that don't implement custom mapping. + Only supports direct joint name matching (e.g., keyboard, simple controllers). Args: device_data: Device joint data (joint_name -> value). - as_dict: Whether to return as dict instead of tensor. Returns: - Robot action or None if mapping failed. + Robot action tensor or None if no joints matched. """ try: robot_joint_names = self.robot.joint_names + qpos_limits = self.robot.body_data.qpos_limits[0] - if not self._mapping_initialized: - self._build_joint_mapping(device_data, robot_joint_names) - - # Get robot joint limits - qpos_limits = self.robot.body_data.qpos_limits[0] # [num_joints, 2] - - # Extract joint values based on mapping joint_values = [] joint_indices = [] - mapped_joints = {} - - gripper_threshold = 0.99 - - for device_joint, robot_indices in self._joint_mapping.items(): - if device_joint in device_data: - vr_value = device_data[device_joint] - - # Binary control for grippers, direct value for other joints - if device_joint in self._binary_control_keys: - idx = robot_indices[0] - joint_min, joint_max = ( - qpos_limits[idx, 0].item(), - qpos_limits[idx, 1].item(), - ) - robot_value = ( - joint_min if vr_value >= gripper_threshold else joint_max - ) - - # Log state change - state = "CLOSE" if vr_value >= gripper_threshold else "OPEN" - if not hasattr(self, "_last_gripper_state"): - self._last_gripper_state = {} - if self._last_gripper_state.get(device_joint) != state: - logger.log_info( - f"{device_joint}: VR={vr_value:.3f} -> {state} (robot={robot_value:.3f}, limit=[{joint_min:.3f}, {joint_max:.3f}])" - ) - self._last_gripper_state[device_joint] = state - else: - robot_value = vr_value - - # One device value may correspond to multiple robot joints (e.g., two fingers of gripper) - for robot_idx in robot_indices: - joint_values.append(robot_value) - joint_indices.append(robot_idx) - mapped_joints[robot_joint_names[robot_idx]] = robot_value + + # Direct name matching only + for robot_idx, robot_joint in enumerate(robot_joint_names): + if robot_joint in device_data: + joint_values.append(device_data[robot_joint]) + joint_indices.append(robot_idx) if len(joint_indices) == 0: return None @@ -234,91 +199,21 @@ def _map_device_to_robot( ) # Get current robot qpos - current_qpos = self.robot.get_qpos() # [num_envs, num_joints] + current_qpos = self.robot.get_qpos() # Create action by updating controlled joints action = current_qpos.clone() action[:, indices_tensor] = device_tensor.unsqueeze(0) # Enforce joint limits - action = self._enforce_joint_limits(action) + action = torch.clamp(action, qpos_limits[:, 0], qpos_limits[:, 1]) return action except Exception as e: - logger.log_error(f"Error mapping device data to robot action: {e}") - import traceback - - logger.log_error(traceback.format_exc()) + logger.log_error(f"Error in generic device mapping: {e}") return None - def _build_joint_mapping( - self, device_data: Dict[str, float], robot_joint_names: List[str] - ) -> None: - """Build mapping from device joints to robot joints.""" - # Store: device_joint_name -> [robot_joint_indices] - self._joint_mapping = {} - - logger.log_info("=" * 60) - logger.log_info("Building joint mapping...") - logger.log_info( - f"VR device data keys ({len(device_data)} total): {sorted(device_data.keys())}" - ) - logger.log_info( - f"Robot joint names ({len(robot_joint_names)} total): {robot_joint_names}" - ) - logger.log_info(f"Robot control_parts: {self.robot.control_parts}") - - # Direct name matching - for robot_idx, robot_joint in enumerate(robot_joint_names): - if robot_joint in device_data: - self._joint_mapping.setdefault(robot_joint, []).append(robot_idx) - - # Special gripper mapping: LEFT_GRIPPER/RIGHT_GRIPPER -> multiple finger joints - # Only for parallel grippers (< 5 joints), not dexterous hands (>= 5 joints) - if self.robot.control_parts: - for part_name, joint_names in self.robot.control_parts.items(): - if "eef" in part_name.lower() and len(joint_names) < 5: - device_key = ( - f"{'LEFT' if 'left' in part_name.lower() else 'RIGHT'}_GRIPPER" - ) - if ( - device_key in device_data - and device_key not in robot_joint_names - ): - indices = [ - robot_joint_names.index(j) - for j in joint_names - if j in robot_joint_names - ] - if indices: - self._joint_mapping[device_key] = indices - self._binary_control_keys.add(device_key) - logger.log_info( - f" Gripper mapping: '{device_key}' -> {joint_names} (indices: {indices})" - ) - - self._mapping_initialized = True - - total_mappings = sum(len(v) for v in self._joint_mapping.values()) - logger.log_info( - f"Joint mapping initialized: {len(self._joint_mapping)} device joints mapped to {total_mappings} robot joints" - ) - logger.log_info(f" Device joints: {list(self._joint_mapping.keys())}") - logger.log_info("=" * 60) - - def _enforce_joint_limits(self, action: torch.Tensor) -> torch.Tensor: - """Enforce robot joint limits on action. - - Args: - action: Action tensor [num_envs, num_joints]. - - Returns: - Clamped action tensor. - """ - qpos_limits = self.robot.body_data.qpos_limits[0] # [num_joints, 2] - return torch.clamp(action, qpos_limits[:, 0], qpos_limits[:, 1]) - def reset(self) -> None: """Reset controller state.""" # Reset all devices @@ -326,11 +221,6 @@ def reset(self) -> None: if hasattr(device, "reset"): device.reset() - # Reset mapping - self._joint_mapping = {} - self._mapping_initialized = False - self._binary_control_keys.clear() - logger.log_info("Device Controller reset") def get_device_info(self, device_name: Optional[str] = None) -> Dict[str, Any]: