diff --git a/embodichain/lab/devices/device.py b/embodichain/lab/devices/device.py index 4e75cd9c..3aba0197 100644 --- a/embodichain/lab/devices/device.py +++ b/embodichain/lab/devices/device.py @@ -15,6 +15,11 @@ # ---------------------------------------------------------------------------- import abc # for abstract base class definitions +from typing import TYPE_CHECKING, Optional, Dict +import torch + +if TYPE_CHECKING: + from embodichain.lab.sim.objects import Robot class Device(metaclass=abc.ABCMeta): @@ -42,3 +47,17 @@ def stop_control(self): def get_controller_state(self): """Returns the current state of the device, a dictionary of pos, orn, grasp, and reset.""" raise NotImplementedError + + def map_to_robot( + self, robot: "Robot", device_data: Dict[str, float] + ) -> Optional[torch.Tensor]: + """Map device input to robot action (optional, device-specific). + + Args: + robot: Robot instance to control. + device_data: Device input data. + + Returns: + Robot action tensor [num_envs, num_joints], or None if not implemented. + """ + return None # Default: no custom mapping diff --git a/embodichain/lab/devices/device_controller.py b/embodichain/lab/devices/device_controller.py index 2ae48acc..d950b861 100644 --- a/embodichain/lab/devices/device_controller.py +++ b/embodichain/lab/devices/device_controller.py @@ -68,13 +68,19 @@ def __init__( if device is not None: self.add_device(device, device_name, set_active=True) - # Joint mapping state: maps device joint name -> robot joint index - self._joint_mapping: Dict[str, int] = {} - self._mapping_initialized = False - logger.log_info(f"Device Controller initialized for robot: {robot.uid}") logger.log_info(f" Robot has {len(robot.joint_names)} joints") + def _get_gripper_joints(self) -> List[str]: + """Get gripper joint names from control_parts.""" + if self.robot.control_parts: + gripper_joints = [] + for part_name, joint_names in self.robot.control_parts.items(): + if "eef" in part_name.lower(): + gripper_joints.extend(joint_names) + return gripper_joints + return [] + def add_device( self, device: Device, device_name: str, set_active: bool = False ) -> None: @@ -147,47 +153,43 @@ def get_action( if device_data is None: return None - # Map device data to robot action - return self._map_device_to_robot(device_data, as_dict=as_dict) + # Priority 1: Use device-specific mapping if available + action = device.map_to_robot(self.robot, device_data) - def _map_device_to_robot( - self, device_data: Dict[str, float], as_dict: bool = False - ) -> Union[torch.Tensor, Dict[str, float], None]: - """Map device input to robot action. + # Priority 2: Use generic mapping as fallback + if action is None: + action = self._generic_mapping(device_data) + + return action + + def _generic_mapping(self, device_data: Dict[str, float]) -> Optional[torch.Tensor]: + """Generic device-to-robot mapping (simple direct name matching). + + This is a fallback for devices that don't implement custom mapping. + Only supports direct joint name matching (e.g., keyboard, simple controllers). Args: device_data: Device joint data (joint_name -> value). - as_dict: Whether to return as dict instead of tensor. Returns: - Robot action or None if mapping failed. + Robot action tensor or None if no joints matched. """ try: robot_joint_names = self.robot.joint_names + qpos_limits = self.robot.body_data.qpos_limits[0] - # Build joint mapping on first call - if not self._mapping_initialized: - self._build_joint_mapping(device_data, robot_joint_names) - - # Extract joint values based on mapping joint_values = [] joint_indices = [] - mapped_joints = {} - for device_joint, robot_idx in self._joint_mapping.items(): - if device_joint in device_data: - value = device_data[device_joint] - joint_values.append(value) + # Direct name matching only + for robot_idx, robot_joint in enumerate(robot_joint_names): + if robot_joint in device_data: + joint_values.append(device_data[robot_joint]) joint_indices.append(robot_idx) - mapped_joints[robot_joint_names[robot_idx]] = value if len(joint_indices) == 0: return None - # Return as dict if requested - if as_dict: - return mapped_joints - # Convert to tensor and create full action device_tensor = torch.tensor( joint_values, dtype=torch.float32, device=self.robot.device @@ -197,56 +199,21 @@ def _map_device_to_robot( ) # Get current robot qpos - current_qpos = self.robot.get_qpos() # [num_envs, num_joints] + current_qpos = self.robot.get_qpos() # Create action by updating controlled joints action = current_qpos.clone() action[:, indices_tensor] = device_tensor.unsqueeze(0) # Enforce joint limits - action = self._enforce_joint_limits(action) + action = torch.clamp(action, qpos_limits[:, 0], qpos_limits[:, 1]) return action except Exception as e: - logger.log_error(f"Error mapping device data to robot action: {e}") + logger.log_error(f"Error in generic device mapping: {e}") return None - def _build_joint_mapping( - self, device_data: Dict[str, float], robot_joint_names: List[str] - ) -> None: - """Build mapping from device joints to robot joints. - - Args: - device_data: Device joint data to determine available joints. - robot_joint_names: Robot joint names. - """ - self._joint_mapping = {} - - for robot_idx, robot_joint in enumerate(robot_joint_names): - # Direct name matching (can be extended with more sophisticated mapping) - if robot_joint in device_data: - self._joint_mapping[robot_joint] = robot_idx - - self._mapping_initialized = True - - logger.log_info( - f"Joint mapping initialized: {len(self._joint_mapping)} joints mapped" - ) - logger.log_info(f" Mapped joints: {list(self._joint_mapping.keys())}") - - def _enforce_joint_limits(self, action: torch.Tensor) -> torch.Tensor: - """Enforce robot joint limits on action. - - Args: - action: Action tensor [num_envs, num_joints]. - - Returns: - Clamped action tensor. - """ - qpos_limits = self.robot.body_data.qpos_limits[0] # [num_joints, 2] - return torch.clamp(action, qpos_limits[:, 0], qpos_limits[:, 1]) - def reset(self) -> None: """Reset controller state.""" # Reset all devices @@ -254,10 +221,6 @@ def reset(self) -> None: if hasattr(device, "reset"): device.reset() - # Reset mapping - self._joint_mapping = {} - self._mapping_initialized = False - logger.log_info("Device Controller reset") def get_device_info(self, device_name: Optional[str] = None) -> Dict[str, Any]: