Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions embodichain/lab/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
# ----------------------------------------------------------------------------

import abc # for abstract base class definitions
from typing import TYPE_CHECKING, Optional, Dict
import torch

if TYPE_CHECKING:
from embodichain.lab.sim.objects import Robot


class Device(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -42,3 +47,17 @@ def stop_control(self):
def get_controller_state(self):
"""Returns the current state of the device, a dictionary of pos, orn, grasp, and reset."""
raise NotImplementedError

def map_to_robot(
self, robot: "Robot", device_data: Dict[str, float]
) -> Optional[torch.Tensor]:
"""Map device input to robot action (optional, device-specific).

Args:
robot: Robot instance to control.
device_data: Device input data.

Returns:
Robot action tensor [num_envs, num_joints], or None if not implemented.
"""
return None # Default: no custom mapping
101 changes: 32 additions & 69 deletions embodichain/lab/devices/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,19 @@ def __init__(
if device is not None:
self.add_device(device, device_name, set_active=True)

# Joint mapping state: maps device joint name -> robot joint index
self._joint_mapping: Dict[str, int] = {}
self._mapping_initialized = False

logger.log_info(f"Device Controller initialized for robot: {robot.uid}")
logger.log_info(f" Robot has {len(robot.joint_names)} joints")

def _get_gripper_joints(self) -> List[str]:
"""Get gripper joint names from control_parts."""
if self.robot.control_parts:
gripper_joints = []
for part_name, joint_names in self.robot.control_parts.items():
if "eef" in part_name.lower():
gripper_joints.extend(joint_names)
return gripper_joints
return []

def add_device(
self, device: Device, device_name: str, set_active: bool = False
) -> None:
Expand Down Expand Up @@ -147,47 +153,43 @@ def get_action(
if device_data is None:
return None

# Map device data to robot action
return self._map_device_to_robot(device_data, as_dict=as_dict)
# Priority 1: Use device-specific mapping if available
action = device.map_to_robot(self.robot, device_data)

def _map_device_to_robot(
self, device_data: Dict[str, float], as_dict: bool = False
) -> Union[torch.Tensor, Dict[str, float], None]:
"""Map device input to robot action.
# Priority 2: Use generic mapping as fallback
if action is None:
action = self._generic_mapping(device_data)

return action

def _generic_mapping(self, device_data: Dict[str, float]) -> Optional[torch.Tensor]:
"""Generic device-to-robot mapping (simple direct name matching).

This is a fallback for devices that don't implement custom mapping.
Only supports direct joint name matching (e.g., keyboard, simple controllers).

Args:
device_data: Device joint data (joint_name -> value).
as_dict: Whether to return as dict instead of tensor.

Returns:
Robot action or None if mapping failed.
Robot action tensor or None if no joints matched.
"""
try:
robot_joint_names = self.robot.joint_names
qpos_limits = self.robot.body_data.qpos_limits[0]

# Build joint mapping on first call
if not self._mapping_initialized:
self._build_joint_mapping(device_data, robot_joint_names)

# Extract joint values based on mapping
joint_values = []
joint_indices = []
mapped_joints = {}

for device_joint, robot_idx in self._joint_mapping.items():
if device_joint in device_data:
value = device_data[device_joint]
joint_values.append(value)
# Direct name matching only
for robot_idx, robot_joint in enumerate(robot_joint_names):
if robot_joint in device_data:
joint_values.append(device_data[robot_joint])
joint_indices.append(robot_idx)
mapped_joints[robot_joint_names[robot_idx]] = value

if len(joint_indices) == 0:
return None

# Return as dict if requested
if as_dict:
return mapped_joints

# Convert to tensor and create full action
device_tensor = torch.tensor(
joint_values, dtype=torch.float32, device=self.robot.device
Expand All @@ -197,67 +199,28 @@ def _map_device_to_robot(
)

# Get current robot qpos
current_qpos = self.robot.get_qpos() # [num_envs, num_joints]
current_qpos = self.robot.get_qpos()

# Create action by updating controlled joints
action = current_qpos.clone()
action[:, indices_tensor] = device_tensor.unsqueeze(0)

# Enforce joint limits
action = self._enforce_joint_limits(action)
action = torch.clamp(action, qpos_limits[:, 0], qpos_limits[:, 1])

return action

except Exception as e:
logger.log_error(f"Error mapping device data to robot action: {e}")
logger.log_error(f"Error in generic device mapping: {e}")
return None

def _build_joint_mapping(
self, device_data: Dict[str, float], robot_joint_names: List[str]
) -> None:
"""Build mapping from device joints to robot joints.

Args:
device_data: Device joint data to determine available joints.
robot_joint_names: Robot joint names.
"""
self._joint_mapping = {}

for robot_idx, robot_joint in enumerate(robot_joint_names):
# Direct name matching (can be extended with more sophisticated mapping)
if robot_joint in device_data:
self._joint_mapping[robot_joint] = robot_idx

self._mapping_initialized = True

logger.log_info(
f"Joint mapping initialized: {len(self._joint_mapping)} joints mapped"
)
logger.log_info(f" Mapped joints: {list(self._joint_mapping.keys())}")

def _enforce_joint_limits(self, action: torch.Tensor) -> torch.Tensor:
"""Enforce robot joint limits on action.

Args:
action: Action tensor [num_envs, num_joints].

Returns:
Clamped action tensor.
"""
qpos_limits = self.robot.body_data.qpos_limits[0] # [num_joints, 2]
return torch.clamp(action, qpos_limits[:, 0], qpos_limits[:, 1])

def reset(self) -> None:
"""Reset controller state."""
# Reset all devices
for device in self._devices.values():
if hasattr(device, "reset"):
device.reset()

# Reset mapping
self._joint_mapping = {}
self._mapping_initialized = False

logger.log_info("Device Controller reset")

def get_device_info(self, device_name: Optional[str] = None) -> Dict[str, Any]:
Expand Down
Loading