Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions embodichain/lab/gym/envs/managers/randomization/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
from __future__ import annotations

import torch
from typing import TYPE_CHECKING, Union, List
from typing import TYPE_CHECKING

from embodichain.lab.sim.objects import RigidObject, Robot
from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg
from embodichain.utils.math import sample_uniform
from embodichain.utils import logger


if TYPE_CHECKING:
Expand All @@ -30,7 +31,7 @@

def randomize_rigid_object_mass(
env: EmbodiedEnv,
env_ids: Union[torch.Tensor, List[int]],
env_ids: torch.Tensor | list[int],
entity_cfg: SceneEntityCfg,
mass_range: tuple[float, float],
relative: bool = False,
Expand All @@ -39,7 +40,7 @@ def randomize_rigid_object_mass(

Args:
env (EmbodiedEnv): The environment instance.
env_ids (Union[torch.Tensor, List[int]]): The environment IDs to apply the randomization.
env_ids (torch.Tensor | list[int]): The environment IDs to apply the randomization.
entity_cfg (SceneEntityCfg): The configuration for the scene entity.
mass_range (tuple[float, float]): The range (min, max) to sample the mass from.
relative (bool): Whether to apply the mass change relative to the initial mass. Defaults to False.
Expand All @@ -61,3 +62,43 @@ def randomize_rigid_object_mass(
sampled_masses = init_mass + sampled_masses

rigid_object.set_mass(sampled_masses, env_ids=env_ids)


def randomize_rigid_object_center_of_mass(
env: EmbodiedEnv,
env_ids: torch.Tensor | list[int],
entity_cfg: SceneEntityCfg,
com_pos_offset_range: tuple[list[float], list[float]],
) -> None:
"""Randomize the center of mass of rigid objects in the environment.

Args:
env (EmbodiedEnv): The environment instance.
env_ids (torch.Tensor | list[int]): The environment IDs to apply the randomization.
entity_cfg (SceneEntityCfg): The configuration for the scene entity.
com_pos_offset_range (tuple[list[float], list[float]]): The range (min, max) to sample the center of mass offset from.
"""

if entity_cfg.uid not in env.sim.get_rigid_object_uid_list():
return

rigid_object: RigidObject = env.sim.get_rigid_object(entity_cfg.uid)
if rigid_object.is_non_dynamic:
logger.log_warning(
f"Cannot randomize center of mass for non-dynamic rigid object '{entity_cfg.uid}'."
)
return

num_instance = len(env_ids)

sampled_com_pos_offsets = sample_uniform(
lower=com_pos_offset_range[0],
upper=com_pos_offset_range[1],
size=(num_instance, 3),
)

com = rigid_object.body_data.default_com_pose[env_ids]
updated_com = com.clone()
updated_com[:, 0:3] += sampled_com_pos_offsets

rigid_object.set_com_pose(updated_com, env_ids=env_ids)
82 changes: 82 additions & 0 deletions embodichain/lab/sim/objects/rigid_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ def __init__(
self._ang_vel = torch.zeros(
(self.num_instances, 3), dtype=torch.float32, device=self.device
)
# center of mass pose in format (x, y, z, qw, qx, qy, qz)
self.default_com_pose = torch.zeros(
(self.num_instances, 7), dtype=torch.float32, device=self.device
)
self._com_pose = torch.zeros(
(self.num_instances, 7), dtype=torch.float32, device=self.device
)

@property
def pose(self) -> torch.Tensor:
Expand Down Expand Up @@ -154,6 +161,23 @@ def vel(self) -> torch.Tensor:
"""
return torch.cat((self.lin_vel, self.ang_vel), dim=-1)

@property
def com_pose(self) -> torch.Tensor:
"""Get the center of mass pose of the rigid bodies.

Returns:
torch.Tensor: The center of mass pose with shape (N, 7).
"""
for i, entity in enumerate(self.entities):
pos, quat = entity.get_physical_body().get_cmass_local_pose()
self._com_pose[i, :3] = torch.as_tensor(
pos, dtype=torch.float32, device=self.device
)
self._com_pose[i, 3:7] = torch.as_tensor(
quat, dtype=torch.float32, device=self.device
)
return self._com_pose


class RigidObject(BatchEntity):
"""RigidObject represents a batch of rigid body in the simulation.
Expand Down Expand Up @@ -198,6 +222,10 @@ def __init__(
# set default collision filter
self._set_default_collision_filter()

# update default center of mass pose (only for non-static bodies with body data).
if self.body_data is not None:
self.body_data.default_com_pose = self.body_data.com_pose.clone()

# TODO: Must be called after setting all attributes.
# May be improved in the future.
if cfg.attrs.enable_collision is False:
Expand Down Expand Up @@ -626,6 +654,60 @@ def set_body_scale(
else:
logger.log_error(f"Setting body scale on GPU is not supported yet.")

def set_com_pose(
self, com_pose: torch.Tensor, env_ids: Sequence[int] | None = None
) -> None:
"""Set the center of mass pose of the rigid body. The pose format is (x, y, z, qw, qx, qy, qz).

Args:
com_pose (torch.Tensor): The center of mass pose to set with shape (N, 7).
env_ids (Sequence[int] | None, optional): Environment indices. If None, then all indices are used.
"""
if self.is_non_dynamic:
logger.log_warning(
"Cannot set center of mass pose for non-dynamic rigid body."
)
return

local_env_ids = self._all_indices if env_ids is None else env_ids

if len(local_env_ids) != len(com_pose):
logger.log_error(
f"Length of env_ids {len(local_env_ids)} does not match com_pose length {len(com_pose)}."
)

com_pose = com_pose.cpu().numpy()
for i, env_idx in enumerate(local_env_ids):
pos = com_pose[i, :3]
quat = com_pose[i, 3:7]
self._entities[env_idx].get_physical_body().set_cmass_local_pose(pos, quat)

def set_body_type(self, body_type: str) -> None:
"""Set the body type of the rigid object.

Note:
Only 'dynamic' and 'kinematic' body types are supported and can be changed at runtime.

Args:
body_type (str): The body type to set. Must be one of 'dynamic', or 'kinematic'.
"""
from dexsim.types import ActorType

if body_type not in ("dynamic", "kinematic"):
logger.log_error(
f"Invalid body type {body_type}. Must be one of 'dynamic', or 'kinematic'."
)

if body_type == "dynamic":
actor_type = ActorType.DYNAMIC
else:
actor_type = ActorType.KINEMATIC

for entity in self._entities:
entity.set_actor_type(actor_type)

self.body_type = body_type

def get_vertices(self, env_ids: Sequence[int] | None = None) -> torch.Tensor:
"""
Retrieve the vertices of the rigid objects.
Expand Down
Loading