[Question] Implementation of Multiple Contact Sensors on a Single Body and Identification of Contacting Bodies #2026
-
Thank you for posting this. I will move this to our Discussions for the team to follow up.
-
Quick update: I'm trying to implement the […]. So far, I have modified the contact_sensor.py, contact_sensor_data.py, and contact_sensor_cfg.py scripts as follows:

contact_sensor_cfg.py

# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab.markers import VisualizationMarkersCfg
from isaaclab.markers.config import CONTACT_SENSOR_MARKER_CFG
from isaaclab.utils import configclass

from ..sensor_base_cfg import SensorBaseCfg
from .contact_sensor import ContactSensor


@configclass
class ContactSensorCfg(SensorBaseCfg):
    """Configuration for the contact sensor."""

    # Existing code...

    max_contact_data_count: int = 100
    """The maximum number of contact data entries to store globally for detailed contact reporting.

    This determines the size of the buffers for forces, points, normals, and separation distances.
    Set it to 0 to disable detailed contact reporting. Defaults to 100.
    """

contact_sensor_data.py

# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

# needed to import for allowing type-hinting: torch.Tensor | None
from __future__ import annotations

import torch
from dataclasses import dataclass


@dataclass
class ContactSensorData:
    # Existing code...

    contact_forces_buffer: torch.Tensor | None = None
    """Buffer storing the contact force magnitude of each contact point.

    Shape is (max_contact_data_count, 1), where each entry is the force magnitude at a contact point.
    The buffer is shared by all environments: use :attr:`contact_start_indices_buffer` and
    :attr:`contact_count_buffer` to locate the entries of a given sensor-filter pair.

    Note:
        If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None.
        If the :attr:`ContactSensorCfg.max_contact_data_count` is 0, then this quantity is None.
    """

    contact_points_buffer: torch.Tensor | None = None
    """Buffer storing contact points in the world frame.

    Shape is (max_contact_data_count, 3), where each entry is the (x, y, z) position of a contact point.

    Note:
        If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None.
        If the :attr:`ContactSensorCfg.max_contact_data_count` is 0, then this quantity is None.
    """

    contact_normals_buffer: torch.Tensor | None = None
    """Buffer storing contact normals in the world frame.

    Shape is (max_contact_data_count, 3), where each entry is the (x, y, z) normal vector at a contact point.

    Note:
        If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None.
        If the :attr:`ContactSensorCfg.max_contact_data_count` is 0, then this quantity is None.
    """

    contact_separation_distances_buffer: torch.Tensor | None = None
    """Buffer storing separation distances at contact points.

    Shape is (max_contact_data_count, 1), where each entry is the separation distance.

    Note:
        If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None.
        If the :attr:`ContactSensorCfg.max_contact_data_count` is 0, then this quantity is None.
    """

    contact_count_buffer: torch.Tensor | None = None
    """Number of active contacts per sensor-filter pair.

    Shape is (num_envs, num_bodies, num_filters), where each entry is the count of contacts.

    Note:
        If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None.
        If the :attr:`ContactSensorCfg.max_contact_data_count` is 0, then this quantity is None.
    """

    contact_start_indices_buffer: torch.Tensor | None = None
    """Start indices in the flat data buffers for each sensor-filter pair.

    Shape is (num_envs, num_bodies, num_filters), where each entry points to the start index in the data buffers.

    Note:
        If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None.
        If the :attr:`ContactSensorCfg.max_contact_data_count` is 0, then this quantity is None.
    """

contact_sensor.py

# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

# Ignore optional memory usage warning globally
# pyright: reportOptionalSubscript=false

from __future__ import annotations

import torch
from collections.abc import Sequence
from typing import TYPE_CHECKING

import omni.physics.tensors.impl.api as physx
from pxr import PhysxSchema

import isaaclab.sim as sim_utils
import isaaclab.utils.string as string_utils
from isaaclab.markers import VisualizationMarkers
from isaaclab.utils.math import convert_quat

from ..sensor_base import SensorBase
from .contact_sensor_data import ContactSensorData

if TYPE_CHECKING:
    from .contact_sensor_cfg import ContactSensorCfg


class ContactSensor(SensorBase):
    """A contact reporting sensor.

    The contact sensor reports the normal contact forces on a rigid body in the world frame.
    It relies on the `PhysX ContactReporter`_ API to be activated on the rigid bodies.

    To enable the contact reporter on a rigid body, please make sure to enable the
    :attr:`isaaclab.sim.spawner.RigidObjectSpawnerCfg.activate_contact_sensors` on your
    asset spawner configuration. This will enable the contact reporter on all the rigid bodies
    in the asset.

    The sensor can be configured to report the contact forces on a set of bodies with a given
    filter pattern using the :attr:`ContactSensorCfg.filter_prim_paths_expr`. This is useful
    when you want to report the contact forces between the sensor bodies and a specific set of
    bodies in the scene. The data can be accessed using the :attr:`ContactSensorData.force_matrix_w`.
    Please check the documentation on `RigidContact`_ for more details.

    The reporting of the filtered contact forces is only possible as one-to-many. This means that only one
    sensor body in an environment can be filtered against multiple bodies in that environment. If you need to
    filter multiple sensor bodies against multiple bodies, you need to create separate sensors for each sensor
    body.

    As an example, suppose you want to report the contact forces for all the feet of a robot against an object
    exclusively. In that case, setting the :attr:`ContactSensorCfg.prim_path` and
    :attr:`ContactSensorCfg.filter_prim_paths_expr` with ``{ENV_REGEX_NS}/Robot/.*_FOOT`` and ``{ENV_REGEX_NS}/Object``
    respectively will not work. Instead, you need to create a separate sensor for each foot and filter
    it against the object.

    .. _PhysX ContactReporter: https://docs.omniverse.nvidia.com/kit/docs/omni_usd_schema_physics/104.2/class_physx_schema_physx_contact_report_a_p_i.html
    .. _RigidContact: https://docs.omniverse.nvidia.com/py/isaacsim/source/isaacsim.core/docs/index.html#isaacsim.core.prims.RigidContact
    """

    cfg: ContactSensorCfg
    """The configuration parameters."""

    def __init__(self, cfg: ContactSensorCfg):
        """Initializes the contact sensor object.

        Args:
            cfg: The configuration parameters.
        """
        # initialize base class
        super().__init__(cfg)
        # Create empty variables for storing output data
        self._data: ContactSensorData = ContactSensorData()
        # initialize self._body_physx_view for running in extension mode
        self._body_physx_view = None

    def __str__(self) -> str:
        """Returns: A string containing information about the instance."""
        return (
            f"Contact sensor @ '{self.cfg.prim_path}': \n"
            f"\tview type : {self.body_physx_view.__class__}\n"
            f"\tupdate period (s) : {self.cfg.update_period}\n"
            f"\tnumber of bodies : {self.num_bodies}\n"
            f"\tbody names : {self.body_names}\n"
        )

    """
    Properties
    """

    @property
    def num_instances(self) -> int:
        return self.body_physx_view.count

    @property
    def data(self) -> ContactSensorData:
        # update sensors if needed
        self._update_outdated_buffers()
        # return the data
        return self._data

    @property
    def num_bodies(self) -> int:
        """Number of bodies with contact sensors attached."""
        return self._num_bodies

    @property
    def body_names(self) -> list[str]:
        """Ordered names of bodies with contact sensors attached."""
        prim_paths = self.body_physx_view.prim_paths[: self.num_bodies]
        return [path.split("/")[-1] for path in prim_paths]

    @property
    def body_physx_view(self) -> physx.RigidBodyView:
        """View for the rigid bodies captured (PhysX).

        Note:
            Use this view with caution. It requires handling of tensors in a specific way.
        """
        return self._body_physx_view

    @property
    def contact_physx_view(self) -> physx.RigidContactView:
        """Contact reporter view for the bodies (PhysX).

        Note:
            Use this view with caution. It requires handling of tensors in a specific way.
        """
        return self._contact_physx_view

    """
    Operations
    """
    def reset(self, env_ids: Sequence[int] | None = None):
        # reset the timers and counters
        super().reset(env_ids)
        # resolve None
        if env_ids is None:
            env_ids = slice(None)
        # reset accumulative data buffers
        self._data.net_forces_w[env_ids] = 0.0
        self._data.net_forces_w_history[env_ids] = 0.0
        if self.cfg.history_length > 0:
            self._data.net_forces_w_history[env_ids] = 0.0
        # reset force matrix
        if len(self.cfg.filter_prim_paths_expr) != 0:
            self._data.force_matrix_w[env_ids] = 0.0
            # reset contact data buffers: the flat buffers are shared by all environments and are
            # overwritten on the next update, so only the per-environment counts and indices are cleared
            if self.cfg.max_contact_data_count > 0:
                self._data.contact_count_buffer[env_ids] = 0
                self._data.contact_start_indices_buffer[env_ids] = 0
        # reset the current air time
        if self.cfg.track_air_time:
            self._data.current_air_time[env_ids] = 0.0
            self._data.last_air_time[env_ids] = 0.0
            self._data.current_contact_time[env_ids] = 0.0
            self._data.last_contact_time[env_ids] = 0.0

    def find_bodies(self, name_keys: str | Sequence[str], preserve_order: bool = False) -> tuple[list[int], list[str]]:
        """Find bodies in the articulation based on the name keys.

        Args:
            name_keys: A regular expression or a list of regular expressions to match the body names.
            preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False.

        Returns:
            A tuple of lists containing the body indices and names.
        """
        return string_utils.resolve_matching_names(name_keys, self.body_names, preserve_order)

    def compute_first_contact(self, dt: float, abs_tol: float = 1.0e-8) -> torch.Tensor:
        """Checks which bodies have established contact within the last :attr:`dt` seconds.

        This function checks if the bodies have established contact within the last :attr:`dt` seconds
        by comparing the current contact time with the given time period. If the contact time is less
        than the given time period, then the bodies are considered to be in contact.

        Note:
            The function assumes that :attr:`dt` is a multiple of the sensor update time-step. In other
            words :math:`dt / dt_sensor = n`, where :math:`n` is a natural number. This is always true
            if the sensor is updated by the physics or the environment stepping time-step and the sensor
            is read by the environment stepping time-step.

        Args:
            dt: The time period since the contact was established.
            abs_tol: The absolute tolerance for the comparison.

        Returns:
            A boolean tensor indicating the bodies that have established contact within the last
            :attr:`dt` seconds. Shape is (N, B), where N is the number of sensors and B is the
            number of bodies in each sensor.

        Raises:
            RuntimeError: If the sensor is not configured to track contact time.
        """
        # check if the sensor is configured to track contact time
        if not self.cfg.track_air_time:
            raise RuntimeError(
                "The contact sensor is not configured to track contact time. "
                "Please enable the 'track_air_time' in the sensor configuration."
            )
        # check if the bodies are in contact
        currently_in_contact = self.data.current_contact_time > 0.0
        less_than_dt_in_contact = self.data.current_contact_time < (dt + abs_tol)
        return currently_in_contact * less_than_dt_in_contact

    def compute_first_air(self, dt: float, abs_tol: float = 1.0e-8) -> torch.Tensor:
        """Checks which bodies have broken contact within the last :attr:`dt` seconds.

        This function checks if the bodies have broken contact within the last :attr:`dt` seconds
        by comparing the current air time with the given time period. If the air time is less
        than the given time period, then the bodies are considered to not be in contact.

        Note:
            It assumes that :attr:`dt` is a multiple of the sensor update time-step. In other words,
            :math:`dt / dt_sensor = n`, where :math:`n` is a natural number. This is always true if
            the sensor is updated by the physics or the environment stepping time-step and the sensor
            is read by the environment stepping time-step.

        Args:
            dt: The time period since the contact is broken.
            abs_tol: The absolute tolerance for the comparison.

        Returns:
            A boolean tensor indicating the bodies that have broken contact within the last :attr:`dt` seconds.
            Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.

        Raises:
            RuntimeError: If the sensor is not configured to track contact time.
        """
        # check if the sensor is configured to track contact time
        if not self.cfg.track_air_time:
            raise RuntimeError(
                "The contact sensor is not configured to track contact time. "
                "Please enable the 'track_air_time' in the sensor configuration."
            )
        # check if the bodies have detached recently
        currently_detached = self.data.current_air_time > 0.0
        less_than_dt_detached = self.data.current_air_time < (dt + abs_tol)
        return currently_detached * less_than_dt_detached

    """
    Implementation.
    """
    def _initialize_impl(self):
        super()._initialize_impl()
        # create simulation view
        self._physics_sim_view = physx.create_simulation_view(self._backend)
        self._physics_sim_view.set_subspace_roots("/")
        # check that only rigid bodies are selected
        leaf_pattern = self.cfg.prim_path.rsplit("/", 1)[-1]
        template_prim_path = self._parent_prims[0].GetPath().pathString
        body_names = list()
        for prim in sim_utils.find_matching_prims(template_prim_path + "/" + leaf_pattern):
            # check if prim has contact reporter API
            if prim.HasAPI(PhysxSchema.PhysxContactReportAPI):
                prim_path = prim.GetPath().pathString
                body_names.append(prim_path.rsplit("/", 1)[-1])
        # check that there is at least one body with contact reporter API
        if not body_names:
            raise RuntimeError(
                f"Sensor at path '{self.cfg.prim_path}' could not find any bodies with contact reporter API."
                "\nHINT: Make sure to enable 'activate_contact_sensors' in the corresponding asset spawn configuration."
            )

        # construct regex expression for the body names
        body_names_regex = r"(" + "|".join(body_names) + r")"
        body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}"
        # convert regex expressions to glob expressions for PhysX
        body_names_glob = body_names_regex.replace(".*", "*")
        filter_prim_paths_glob = [expr.replace(".*", "*") for expr in self.cfg.filter_prim_paths_expr]

        # create a rigid prim view for the sensor
        self._body_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_glob)
        self._contact_physx_view = self._physics_sim_view.create_rigid_contact_view(
            body_names_glob,
            filter_patterns=filter_prim_paths_glob,
            max_contact_data_count=self.cfg.max_contact_data_count,
        )
        # resolve the true count of bodies
        self._num_bodies = self.body_physx_view.count // self._num_envs
        # check that contact reporter succeeded
        if self._num_bodies != len(body_names):
            raise RuntimeError(
                "Failed to initialize contact reporter for specified bodies."
                f"\n\tInput prim path : {self.cfg.prim_path}"
                f"\n\tResolved prim paths: {body_names_regex}"
            )

        # prepare data buffers
        self._data.net_forces_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
        # optional buffers
        # -- history of net forces
        if self.cfg.history_length > 0:
            self._data.net_forces_w_history = torch.zeros(
                self._num_envs, self.cfg.history_length, self._num_bodies, 3, device=self._device
            )
        else:
            self._data.net_forces_w_history = self._data.net_forces_w.unsqueeze(1)
        # -- pose of sensor origins
        if self.cfg.track_pose:
            self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
            self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device)
        # -- air/contact time between contacts
        if self.cfg.track_air_time:
            self._data.last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
            self._data.current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
            self._data.last_contact_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
            self._data.current_contact_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
        # force matrix: (num_envs, num_bodies, num_filter_shapes, 3)
        if len(self.cfg.filter_prim_paths_expr) != 0:
            num_filters = self.contact_physx_view.filter_count
            self._data.force_matrix_w = torch.zeros(self._num_envs, self._num_bodies, num_filters, 3, device=self._device)
            # contact data buffers: they need num_filters, so they live inside the filter branch
            # (otherwise num_filters is undefined when no filters are set). TODO: check if this works
            if self.cfg.max_contact_data_count > 0:
                # the flat buffers are assigned directly from PhysX in _update_buffers_impl
                self._data.contact_forces_buffer = None
                self._data.contact_points_buffer = None
                self._data.contact_normals_buffer = None
                self._data.contact_separation_distances_buffer = None
                # per-(environment, body, filter) indexing into the flat buffers
                self._data.contact_start_indices_buffer = torch.zeros(
                    self._num_envs, self._num_bodies, num_filters, dtype=torch.int32, device=self._device
                )
                self._data.contact_count_buffer = torch.zeros(
                    self._num_envs, self._num_bodies, num_filters, dtype=torch.int32, device=self._device
                )
    def _update_buffers_impl(self, env_ids: Sequence[int]):
        """Fills the buffers of the sensor data."""
        # default to all sensors
        if len(env_ids) == self._num_envs:
            env_ids = slice(None)

        # obtain the contact forces
        # TODO: We are handling the indexing ourself because of the shape; (N, B) vs expected (N * B).
        #   This isn't the most efficient way to do this, but it's the easiest to implement.
        net_forces_w = self.contact_physx_view.get_net_contact_forces(dt=self._sim_physics_dt)
        self._data.net_forces_w[env_ids, :, :] = net_forces_w.view(-1, self._num_bodies, 3)[env_ids]
        # update contact force history
        if self.cfg.history_length > 0:
            self._data.net_forces_w_history[env_ids, 1:] = self._data.net_forces_w_history[env_ids, :-1].clone()
            self._data.net_forces_w_history[env_ids, 0] = self._data.net_forces_w[env_ids]

        # obtain the contact force matrix
        if len(self.cfg.filter_prim_paths_expr) != 0:
            # shape of the filtering matrix: (num_envs, num_bodies, num_filter_shapes, 3)
            num_filters = self.contact_physx_view.filter_count
            # acquire and shape the force matrix
            force_matrix_w = self.contact_physx_view.get_contact_force_matrix(dt=self._sim_physics_dt)
            force_matrix_w = force_matrix_w.view(-1, self._num_bodies, num_filters, 3)
            self._data.force_matrix_w[env_ids] = force_matrix_w[env_ids]

            # acquire and shape the contact data (only available with filters) TODO: check if this works
            if self.cfg.max_contact_data_count > 0:
                (
                    contact_forces_buffer,
                    contact_points_buffer,
                    contact_normals_buffer,
                    contact_separation_distances_buffer,
                    contact_count_buffer,
                    start_indices_buffer,
                ) = self.contact_physx_view.get_contact_data(dt=self._sim_physics_dt)
                # update the global (flat) buffers, shared by all environments
                self._data.contact_forces_buffer = contact_forces_buffer
                self._data.contact_points_buffer = contact_points_buffer
                self._data.contact_normals_buffer = contact_normals_buffer
                self._data.contact_separation_distances_buffer = contact_separation_distances_buffer
                # reshape count and indices to (num_envs, num_bodies, num_filters)
                contact_count = contact_count_buffer.view(self._num_envs, self._num_bodies, -1)
                start_indices = start_indices_buffer.view(self._num_envs, self._num_bodies, -1)
                self._data.contact_count_buffer[env_ids] = contact_count[env_ids]
                self._data.contact_start_indices_buffer[env_ids] = start_indices[env_ids]

        # obtain the pose of the sensor origin
        if self.cfg.track_pose:
            pose = self.body_physx_view.get_transforms().view(-1, self._num_bodies, 7)[env_ids]
            pose[..., 3:] = convert_quat(pose[..., 3:], to="wxyz")
            self._data.pos_w[env_ids], self._data.quat_w[env_ids] = pose.split([3, 4], dim=-1)

        # obtain the air time
        if self.cfg.track_air_time:
            # -- time elapsed since last update
            # since this function is called every frame, we can use the difference to get the elapsed time
            elapsed_time = self._timestamp[env_ids] - self._timestamp_last_update[env_ids]
            # -- check contact state of bodies
            is_contact = torch.norm(self._data.net_forces_w[env_ids, :, :], dim=-1) > self.cfg.force_threshold
            is_first_contact = (self._data.current_air_time[env_ids] > 0) * is_contact
            is_first_detached = (self._data.current_contact_time[env_ids] > 0) * ~is_contact
            # -- update the last contact time if body has just become in contact
            self._data.last_air_time[env_ids] = torch.where(
                is_first_contact,
                self._data.current_air_time[env_ids] + elapsed_time.unsqueeze(-1),
                self._data.last_air_time[env_ids],
            )
            # -- increment time for bodies that are not in contact
            self._data.current_air_time[env_ids] = torch.where(
                ~is_contact, self._data.current_air_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0
            )
            # -- update the last contact time if body has just detached
            self._data.last_contact_time[env_ids] = torch.where(
                is_first_detached,
                self._data.current_contact_time[env_ids] + elapsed_time.unsqueeze(-1),
                self._data.last_contact_time[env_ids],
            )
            # -- increment time for bodies that are in contact
            self._data.current_contact_time[env_ids] = torch.where(
                is_contact, self._data.current_contact_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0
            )
    def _set_debug_vis_impl(self, debug_vis: bool):
        # set visibility of markers
        # note: parent only deals with callbacks. not their visibility
        if debug_vis:
            # create markers if necessary for the first time
            if not hasattr(self, "contact_visualizer"):
                self.contact_visualizer = VisualizationMarkers(self.cfg.visualizer_cfg)
            # set their visibility to true
            self.contact_visualizer.set_visibility(True)
        else:
            if hasattr(self, "contact_visualizer"):
                self.contact_visualizer.set_visibility(False)

    def _debug_vis_callback(self, event):
        # safely return if view becomes invalid
        # note: this invalidity happens because of isaac sim view callbacks
        if self.body_physx_view is None:
            return
        # marker indices
        # 0: contact, 1: no contact
        net_contact_force_w = torch.norm(self._data.net_forces_w, dim=-1)
        marker_indices = torch.where(net_contact_force_w > self.cfg.force_threshold, 0, 1)
        # check if prim is visualized
        if self.cfg.track_pose:
            frame_origins: torch.Tensor = self._data.pos_w
        else:
            pose = self.body_physx_view.get_transforms()
            frame_origins = pose.view(-1, self._num_bodies, 7)[:, :, :3]
        # visualize
        self.contact_visualizer.visualize(frame_origins.view(-1, 3), marker_indices=marker_indices.view(-1))

    """
    Internal simulation callbacks.
    """

    def _invalidate_initialize_callback(self, event):
        """Invalidates the scene elements."""
        # call parent
        super()._invalidate_initialize_callback(event)
        # set all existing views to None to invalidate them
        self._physics_sim_view = None
        self._body_physx_view = None
        self._contact_physx_view = None

Also, I created a simple simulation with cubes to test that a single sensor captures collisions with multiple objects, along with their points of contact:

Test code:

# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Launch Isaac Sim Simulator first."""

import argparse

from isaaclab.app import AppLauncher

# add argparse arguments
parser = argparse.ArgumentParser(description="Example of contact sensing between multiple cubes.")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args()

# launch omniverse app
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app

"""Rest everything follows."""

import torch

import isaaclab.sim as sim_utils
from isaaclab.assets import AssetBaseCfg, RigidObjectCfg
from isaaclab.scene import InteractiveScene, InteractiveSceneCfg
from isaaclab.sensors import ContactSensorCfg
from isaaclab.utils import configclass


@configclass
class ContactSensorSceneCfg(InteractiveSceneCfg):
    """Design the scene with a large cube and multiple smaller cubes."""

    # ground plane
    ground = AssetBaseCfg(prim_path="/World/defaultGroundPlane", spawn=sim_utils.GroundPlaneCfg())

    # lights
    dome_light = AssetBaseCfg(
        prim_path="/World/Light", spawn=sim_utils.DomeLightCfg(intensity=3000.0, color=(0.75, 0.75, 0.75))
    )

    # Big cube
    big_cube = RigidObjectCfg(
        prim_path="{ENV_REGEX_NS}/BigCube",
        spawn=sim_utils.CuboidCfg(
            size=(1.0, 1.0, 1.0),
            rigid_props=sim_utils.RigidBodyPropertiesCfg(),
            # note: USD physics ignores a mass of 0.0 (the mass is then computed from the collision
            # shapes), so this cube is dynamic, not static
            mass_props=sim_utils.MassPropertiesCfg(mass=0.0),
            collision_props=sim_utils.CollisionPropertiesCfg(),
            physics_material=sim_utils.RigidBodyMaterialCfg(static_friction=0.8),
            visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(1.0, 0.0, 0.0)),
            activate_contact_sensors=True,
        ),
        init_state=RigidObjectCfg.InitialStateCfg(pos=(0.0, 0.0, 0.8)),
    )

    # Small cubes (dynamic)
    cube1 = RigidObjectCfg(
        prim_path="{ENV_REGEX_NS}/Cube1",
        spawn=sim_utils.CuboidCfg(
            size=(0.2, 0.2, 0.2),
            rigid_props=sim_utils.RigidBodyPropertiesCfg(),
            mass_props=sim_utils.MassPropertiesCfg(mass=10.0),
            collision_props=sim_utils.CollisionPropertiesCfg(),
            physics_material=sim_utils.RigidBodyMaterialCfg(),
            visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.0, 0.0, 1.0)),
            activate_contact_sensors=True,
        ),
        init_state=RigidObjectCfg.InitialStateCfg(pos=(0.0, 0.0, 2.0)),
    )

    cube2 = RigidObjectCfg(
        prim_path="{ENV_REGEX_NS}/Cube2",
        spawn=sim_utils.CuboidCfg(
            size=(0.5, 1.0, 0.1),
            rigid_props=sim_utils.RigidBodyPropertiesCfg(),
            mass_props=sim_utils.MassPropertiesCfg(mass=10.0),
            collision_props=sim_utils.CollisionPropertiesCfg(),
            physics_material=sim_utils.RigidBodyMaterialCfg(),
            visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.0, 1.0, 0.0)),
            activate_contact_sensors=True,
        ),
        init_state=RigidObjectCfg.InitialStateCfg(pos=(0.4, 0.0, 0.05)),
    )

    cube3 = RigidObjectCfg(
        prim_path="{ENV_REGEX_NS}/Cube3",
        spawn=sim_utils.CuboidCfg(
            size=(0.5, 1.0, 0.1),
            rigid_props=sim_utils.RigidBodyPropertiesCfg(),
            mass_props=sim_utils.MassPropertiesCfg(mass=10.0),
            collision_props=sim_utils.CollisionPropertiesCfg(),
            physics_material=sim_utils.RigidBodyMaterialCfg(),
            visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(1.0, 1.0, 0.0)),
            activate_contact_sensors=True,
        ),
        init_state=RigidObjectCfg.InitialStateCfg(pos=(-0.4, 0.0, 0.05)),
    )

    # Contact sensor on big cube
    big_cube_contact = ContactSensorCfg(
        prim_path="{ENV_REGEX_NS}/BigCube",
        update_period=0.0,
        history_length=6,
        debug_vis=True,
        filter_prim_paths_expr=["{ENV_REGEX_NS}/Cube1", "{ENV_REGEX_NS}/Cube2", "{ENV_REGEX_NS}/Cube3"],
        max_contact_data_count=8000,
    )
def run_simulator(sim: sim_utils.SimulationContext, scene: InteractiveScene):
    """Run the simulator."""
    sim_dt = sim.get_physics_dt()
    sim_time = 0.0
    count = 0

    # Simulate physics
    while simulation_app.is_running():
        if count % 2000 == 0:
            count = 0
            # Reset small cubes to initial positions
            for cube_name in ["cube1", "cube2", "cube3"]:
                cube = scene[cube_name]
                root_state = cube.data.default_root_state.clone()
                root_state[:, :3] += scene.env_origins
                cube.write_root_pose_to_sim(root_state[:, :7])
                cube.write_root_velocity_to_sim(root_state[:, 7:])
            scene.reset()
            print("[INFO]: Resetting cubes...")

        # Perform simulation step
        sim.step()
        sim_time += sim_dt
        count += 1
        scene.update(sim_dt)

        # Print contact information
        print("\n-------------------------------")
        print("Big Cube Contact Sensor Data:")
        print_sensor_contact_details(scene["big_cube_contact"])


def print_sensor_contact_details(sensor):
    """Improved debug output with multiple filter support."""
    print(f"\n--- Sensor: {sensor.cfg.prim_path} ---")
    if sensor.data.contact_forces_buffer is None:
        print("Buffer Status:")
        if len(sensor.cfg.filter_prim_paths_expr) == 0:
            print(" - No filters configured!")
        if sensor.cfg.max_contact_data_count <= 0:
            print(" - max_contact_data_count not set!")
        return

    # Iterate over all environments
    for env_id in range(sensor.data.contact_start_indices_buffer.shape[0]):
        print(f"\nEnvironment ID: {env_id}")
        body_id = 0  # Single body sensor
        # Check each filter
        for filter_idx, filter_prim in enumerate(sensor.cfg.filter_prim_paths_expr):
            print(f"\nFilter: {filter_prim}")
            start_idx = int(sensor.data.contact_start_indices_buffer[env_id, body_id, filter_idx].item())
            contact_count = int(sensor.data.contact_count_buffer[env_id, body_id, filter_idx].item())
            if contact_count > 0:
                print(f"Found {contact_count} contact points:")
                end_idx = min(start_idx + contact_count, len(sensor.data.contact_forces_buffer))
                for i in range(start_idx, end_idx):
                    print(f" Point {i - start_idx + 1}:")
                    print(f" Force: {sensor.data.contact_forces_buffer[i].cpu().numpy()} N")
                    print(f" Position: {sensor.data.contact_points_buffer[i].cpu().numpy()} m")
                    print(f" Normal: {sensor.data.contact_normals_buffer[i].cpu().numpy()}")
                    print(f" Separation: {sensor.data.contact_separation_distances_buffer[i].cpu().numpy()} m")
            else:
                print("No contacts detected.")


def main():
    """Main function."""
    sim_cfg = sim_utils.SimulationCfg(dt=0.005, device=args_cli.device)
    sim = sim_utils.SimulationContext(sim_cfg)
    sim.set_camera_view(eye=[5.0, 5.0, 5.0], target=[0.0, 0.0, 0.0])
    scene_cfg = ContactSensorSceneCfg(num_envs=args_cli.num_envs, env_spacing=2.0)
    scene = InteractiveScene(scene_cfg)
    sim.reset()
    print("[INFO]: Setup complete...")
    run_simulator(sim, scene)


if __name__ == "__main__":
    main()
    simulation_app.close()

These are the reported results for a single environment:
I have also tested with multiple environments, but I still need to check whether the results make sense. As expected, however, simulation performance drops considerably as I scale up the number of environments. I would appreciate any recommendations or comments on this implementation. I think with the […] Note: This might be useful for #1925. Thank you in advance for any help optimizing this implementation, or any ideas on how to solve the questions posted in my first post.
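On the performance side, one likely contributor in the test script above is `print_sensor_contact_details`, which calls `.item()` and `.cpu()` for every contact point; on a GPU each of those calls forces a device-to-host synchronization. A minimal sketch of a vectorized alternative, assuming the flat buffer layout described in `ContactSensorData` (a global `(max_contact_data_count, D)` buffer plus per-pair start indices and counts of shape `(num_envs, num_bodies, num_filters)`), gathers every pair's entries into a padded tensor with one indexing operation. The name `gather_contact_points` and the `max_points` padding size are hypothetical, not part of Isaac Lab. Since the buffer is global, `max_contact_data_count` probably also needs to grow with the number of environments.

```python
import torch


def gather_contact_points(
    points: torch.Tensor,
    start_indices: torch.Tensor,
    counts: torch.Tensor,
    max_points: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Gather flat contact data into a padded per-pair tensor without Python loops.

    Args:
        points: Flat buffer of shape (max_contact_data_count, D), e.g. contact_points_buffer (D = 3).
        start_indices: Start index of each sensor-filter pair, shape (num_envs, num_bodies, num_filters).
        counts: Number of contacts of each pair, same shape as ``start_indices``.
        max_points: Padding size; pairs with more contacts are truncated.

    Returns:
        padded: Shape (num_envs, num_bodies, num_filters, max_points, D), zeros in unused slots.
        mask: Shape (num_envs, num_bodies, num_filters, max_points), True for valid slots.
    """
    offsets = torch.arange(max_points, device=points.device)
    # absolute position of every slot in the flat buffer
    idx = start_indices.long().unsqueeze(-1) + offsets
    # a slot is valid if it is within the pair's count and inside the buffer
    mask = (offsets < counts.long().unsqueeze(-1)) & (idx < points.shape[0])
    # replace invalid indices by 0 so the gather stays in bounds, then zero those slots
    safe_idx = torch.where(mask, idx, torch.zeros_like(idx))
    padded = points[safe_idx] * mask.unsqueeze(-1)
    return padded, mask


# Example: 4 flat contact points, 1 env, 1 sensor body, 2 filters
points = torch.arange(12, dtype=torch.float32).view(4, 3)
start = torch.tensor([[[0, 2]]], dtype=torch.int32)
count = torch.tensor([[[2, 1]]], dtype=torch.int32)
padded, mask = gather_contact_points(points, start, count, max_points=2)
```

The force and separation buffers have shape `(max_contact_data_count, 1)` and can be gathered the same way with `D = 1`. The padded tensors stay on the device, so downstream logic, or a single `.cpu()` call for printing, avoids one synchronization per contact point.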
-
I do not know your application exactly, but here is a suggestion: in many of these cases, if you want to determine where a contact is coming from within a body, you can take the force applied between the objects from the force matrix, apply a quaternion rotation to transform it into the body's local frame, get its magnitude, and find the angle between the force vector and each local coordinate axis to determine (roughly) the direction of impact.
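A minimal sketch of that suggestion in PyTorch, assuming the body orientation is a unit quaternion in `(w, x, y, z)` order (the convention `ContactSensorData.quat_w` uses after `convert_quat(..., to="wxyz")` in the sensor code earlier in this thread, which requires `track_pose=True`) and that the force is one filter slice of `force_matrix_w`. The helper name `world_force_to_local_direction` is hypothetical; the inverse rotation is written out explicitly rather than relying on a particular `isaaclab.utils.math` function.

```python
import torch


def world_force_to_local_direction(force_w: torch.Tensor, quat_w: torch.Tensor, eps: float = 1.0e-9):
    """Express world-frame forces in the body frame and describe their direction.

    Args:
        force_w: Forces in the world frame, shape (..., 3).
        quat_w: Body orientations as unit quaternions (w, x, y, z), shape (..., 4).
        eps: Guard against division by zero for vanishing forces.

    Returns:
        force_b: Forces in the body frame, shape (..., 3).
        magnitude: Force norms, shape (...).
        angles: Angle in radians between each force and the body's x, y, z axes, shape (..., 3).
    """
    w = quat_w[..., :1]
    xyz = quat_w[..., 1:]
    # rotate by the inverse (conjugate) quaternion, i.e. apply R^T to the force
    a = force_w * (2.0 * w**2 - 1.0)
    b = torch.cross(xyz, force_w, dim=-1) * w * 2.0
    c = xyz * (xyz * force_w).sum(dim=-1, keepdim=True) * 2.0
    force_b = a - b + c
    magnitude = force_b.norm(dim=-1)
    # unit direction; a zero force maps to the zero vector, i.e. 90 degrees to every axis
    direction = force_b / magnitude.clamp_min(eps).unsqueeze(-1)
    angles = torch.acos(direction.clamp(-1.0, 1.0))
    return force_b, magnitude, angles
```

For filter `j` of the first sensor body, it could be called as `world_force_to_local_direction(sensor.data.force_matrix_w[:, 0, j], sensor.data.quat_w[:, 0])`; the axis with the smallest angle gives a rough direction of impact in the body frame.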
-
Hi, thanks for the great post. This helps a lot! In my case, I need contact data with the "ground" specifically. If I modify it like this:
the contact force / friction force are just 0. Could you please check whether this is the case for you as well? Thank you
-
Question
Hi everyone,
I’m exploring contact sensing capabilities in Isaac Lab and would like guidance on two related scenarios critical to my project.
1. Localized Contact Sensors on a Single Body
Objective: Attach multiple contact sensors to distinct regions of a single rigid body (e.g., an object or robot link) to detect localized contacts (e.g., collisions at specific points on a gripper or object surface).
Current Understanding:
Questions:
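One possible workaround for localized sensing, sketched here under assumptions: instead of attaching several sensors to one body, take the contact points reported by a detailed-contact implementation such as the one in this thread (world-frame points plus the body pose from `track_pose=True`), transform them into the body frame, and assign each point to a user-defined region. The regions below are hypothetical axis-aligned boxes in the body frame, and `assign_contact_regions` is an illustrative name, not an Isaac Lab API.

```python
import torch


def quat_rotate_inverse_wxyz(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Rotate vectors by the inverse of unit quaternions given as (w, x, y, z)."""
    w = quat[..., :1]
    xyz = quat[..., 1:]
    a = vec * (2.0 * w**2 - 1.0)
    b = torch.cross(xyz, vec, dim=-1) * w * 2.0
    c = xyz * (xyz * vec).sum(dim=-1, keepdim=True) * 2.0
    return a - b + c


def assign_contact_regions(
    points_w: torch.Tensor,
    body_pos_w: torch.Tensor,
    body_quat_w: torch.Tensor,
    region_min: torch.Tensor,
    region_max: torch.Tensor,
) -> torch.Tensor:
    """Map world-frame contact points of one body to body-frame regions.

    Args:
        points_w: Contact points in the world frame, shape (P, 3).
        body_pos_w: Body position in the world frame, shape (3,).
        body_quat_w: Body orientation (w, x, y, z), shape (4,).
        region_min: Lower corners of the regions in the body frame, shape (R, 3).
        region_max: Upper corners of the regions in the body frame, shape (R, 3).

    Returns:
        Region index per point, shape (P,), or -1 if the point lies in no region.
    """
    # express the contact points in the body frame: p_b = R^T (p_w - t)
    quat = body_quat_w.expand(points_w.shape[0], 4)
    points_b = quat_rotate_inverse_wxyz(quat, points_w - body_pos_w)
    # containment test of every point against every region: (P, R)
    inside = (points_b[:, None, :] >= region_min[None]) & (points_b[:, None, :] <= region_max[None])
    inside = inside.all(dim=-1)
    # first matching region per point (argmax returns the first maximal index), -1 if none
    first = inside.float().argmax(dim=-1)
    return torch.where(inside.any(dim=-1), first, torch.full_like(first, -1))


# Hypothetical regions on a 1 m cube: index 0 = top-face slab, index 1 = +x-face slab
region_min = torch.tensor([[-0.5, -0.5, 0.45], [0.45, -0.5, -0.5]])
region_max = torch.tensor([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]])
```

Counting points (or summing the matching force entries) per region then behaves like several small sensors on the same body, at the cost of depending on the detailed contact buffers.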
2. Identifying Specific Contacting Bodies in One-to-Many Interactions
Objective: When a sensor-equipped body interacts with multiple bodies/prims (e.g., a robot gripper colliding with objects B, C, D), determine which exact body/prim triggered the contact.
Use Case Example:
Body A (with sensors) collides with Body B (target object) and Body C (obstacle). Downstream logic requires applying distinct behaviors based on whether contact occurred with B or C.
Questions:
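For the one-to-many case, `ContactSensorData.force_matrix_w` (shape `(num_envs, num_bodies, num_filters, 3)`) already separates the force per filtered body, so its per-filter magnitude indicates which body is in contact; per the `ContactSensor` docstring, only one sensor body per environment can be filtered against several bodies. A minimal sketch, assuming the filter index follows the order of `filter_prim_paths_expr` (as the test script earlier in this thread also assumes) and that 1 N is a reasonable contact threshold; `contact_mask` and `dominant_contact` are hypothetical helpers.

```python
import torch


def contact_mask(force_matrix_w: torch.Tensor, threshold: float = 1.0) -> torch.Tensor:
    """Return a (num_envs, num_bodies, num_filters) mask of filtered bodies in contact."""
    return force_matrix_w.norm(dim=-1) > threshold


def dominant_contact(force_matrix_w: torch.Tensor, threshold: float = 1.0) -> torch.Tensor:
    """Return, per (env, body), the filter index with the largest contact force, or -1 if none."""
    magnitudes = force_matrix_w.norm(dim=-1)  # (num_envs, num_bodies, num_filters)
    max_mag, idx = magnitudes.max(dim=-1)  # both (num_envs, num_bodies)
    return torch.where(max_mag > threshold, idx, torch.full_like(idx, -1))


# Example: 3 environments, 1 sensor body (A), 2 filters (index 0 = object B, index 1 = obstacle C)
force_matrix_w = torch.zeros(3, 1, 2, 3)
force_matrix_w[0, 0, 0] = torch.tensor([0.0, 0.0, 10.0])  # env 0: A touches B
force_matrix_w[1, 0, 1] = torch.tensor([3.0, 4.0, 0.0])  # env 1: A touches C with |F| = 5 N
touched = dominant_contact(force_matrix_w)
touching_target = contact_mask(force_matrix_w)[..., 0]  # e.g. trigger grasp logic
touching_obstacle = contact_mask(force_matrix_w)[..., 1]  # e.g. trigger avoidance logic
```

Keeping both the mask and the dominant index lets downstream logic react to simultaneous contacts with B and C, while still having a single label when only one is needed.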
Thank you for your time and expertise!