diff --git a/docs/source/tasks/table_top_gripper/index.md b/docs/source/tasks/table_top_gripper/index.md
index 5a85418e8..f7246b995 100644
--- a/docs/source/tasks/table_top_gripper/index.md
+++ b/docs/source/tasks/table_top_gripper/index.md
@@ -279,6 +279,33 @@ A simple task where the objective is to grasp a red cube and move it to a target
+## StackPyramid-v1
+![sparse-reward][sparse-reward-badge]
+:::{dropdown} Task Card
+:icon: note
+:color: primary
+
+**Task Description:**
+The goal is to pick up the red cube, place it next to the green cube, and stack the blue cube on top of the red and green cubes without it falling off.
+
+**Supported Robots: PandaWristCam, Panda, Fetch**
+
+**Randomizations:**
+- all three cubes have their z-axis rotation randomized
+- all three cubes have their xy positions on top of the table scene randomized. The positions are sampled such that the cubes do not collide with each other
+
+**Success Conditions:**
+- the blue cube is static
+- the blue cube is on top of both the red and green cubes (to within half of the cube size)
+- none of the red, green, or blue cubes are grasped by the robot (the robot must let go of the cubes)
+:::
+
+
 ## PickSingleYCB-v1
 ![dense-reward][dense-reward-badge]
@@ -313,7 +340,25 @@ Pick up a random object sampled from the [YCB dataset](https://www.ycbbenchmarks
+## PickAndPlace-v1
+
+:::{dropdown} Task Card
+:icon: note
+:color: primary
+
+**Task Description:**
+Pick and place four cubes into their assigned container cells: the red cube in the upper-left cell, the green cube in the upper-right cell, the blue cube in the lower-left cell, and the yellow cube in the lower-right cell.
+
+**Supported Robots: PandaWristCam**
+
+**Randomizations:**
+- the four cubes' xy positions are randomized over x in [0.05, 0.09] and y in [-0.15, -0.1], and their z-axis rotations are randomized
+- sampled positions keep every pair of cubes at least 0.06 m apart so the cubes never overlap
+
+**Success Conditions:**
+- each cube is placed in its corresponding cell
+- none of the cubes are grasped by the robot and the robot is static
+:::
+
 ## PlaceSphere-v1
 ![dense-reward][dense-reward-badge]
 ![sparse-reward][sparse-reward-badge]
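The two task cards above correspond to the environments registered later in this patch. A minimal sketch of instantiating one of them (mirrors the pytest fixture added below; assumes ManiSkill is installed and that importing `mani_skill.envs` registers the IDs):

```python
# Minimal sketch (not part of the patch): creating and stepping a new task.
import gymnasium as gym
import mani_skill.envs  # noqa: F401  (importing triggers env registration)

env = gym.make("StackPyramid-v1", obs_mode="state")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```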
diff --git a/examples/baselines/diffusion_policy/examples.sh b/examples/baselines/diffusion_policy/examples.sh
new file mode 100644
index 000000000..84f10d770
--- /dev/null
+++ b/examples/baselines/diffusion_policy/examples.sh
@@ -0,0 +1,87 @@
+### Example scripts for training Diffusion Policy that have some results ###
+
+# Learning from motion planning generated demonstrations
+
+# PushCube-v1
+python -m mani_skill.trajectory.replay_trajectory \
+  --traj-path ~/.maniskill/demos/PushCube-v1/motionplanning/trajectory.h5 \
+  --use-first-env-state -c pd_ee_delta_pos -o state \
+  --save-traj --num-procs 10 -b cpu
+
+python train.py --env-id PushCube-v1 \
+  --demo-path ~/.maniskill/demos/PushCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cpu.h5 \
+  --control-mode "pd_ee_delta_pos" --sim-backend "cpu" --num-demos 100 --max_episode_steps 100 \
+  --total_iters 30000
+
+# PickCube-v1
+python -m mani_skill.trajectory.replay_trajectory \
+  --traj-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.h5 \
+  --use-first-env-state -c pd_ee_delta_pos -o state \
+  --save-traj --num-procs 10 -b cpu
+
+python train.py --env-id PickCube-v1 \
+  --demo-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cpu.h5 \
+  --control-mode "pd_ee_delta_pos" --sim-backend "cpu" --num-demos 100 --max_episode_steps 100 \
+  --total_iters 30000
+
+# StackCube-v1
+python -m mani_skill.trajectory.replay_trajectory \
+  --traj-path ~/.maniskill/demos/StackCube-v1/motionplanning/trajectory.h5 \
+  --use-first-env-state -c pd_ee_delta_pos -o state \
+  --save-traj --num-procs 10 -b cpu
+
+python train.py --env-id StackCube-v1 \
+  --demo-path ~/.maniskill/demos/StackCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cpu.h5 \
+  --control-mode "pd_ee_delta_pos" --sim-backend "cpu" --num-demos 100 --max_episode_steps 200 \
+  --total_iters 30000
+
+# PegInsertionSide-v1
+python -m mani_skill.trajectory.replay_trajectory \
+  --traj-path ~/.maniskill/demos/PegInsertionSide-v1/motionplanning/trajectory.h5 \
+  --use-first-env-state -c pd_ee_delta_pose -o state \
+  --save-traj --num-procs 10 -b cpu
+
+python train.py --env-id PegInsertionSide-v1 \
+  --demo-path ~/.maniskill/demos/PegInsertionSide-v1/motionplanning/trajectory.state.pd_ee_delta_pose.cpu.h5 \
+  --control-mode "pd_ee_delta_pose" --sim-backend "cpu" --num-demos 100 --max_episode_steps 300 \
+  --total_iters 300000
+
+# StackPyramid-v1
+stack_pyramid_traj_name="stack_pyramid_trajectory"
+python -m mani_skill.examples.motionplanning.panda.run -e "StackPyramid-v1" \
+  --num-procs 20 -n 100 --reward-mode="sparse" \
+  --sim-backend cpu --only-count-success \
+  --traj-name $stack_pyramid_traj_name
+
+python -m mani_skill.trajectory.replay_trajectory -b cpu \
+  --traj-path ./demos/StackPyramid-v1/motionplanning/$stack_pyramid_traj_name.h5 \
+  --use-first-env-state -c pd_joint_delta_pos -o state \
+  --save-traj --num-procs 20
+
+python train.py --env-id StackPyramid-v1 \
+  --demo-path ./demos/StackPyramid-v1/motionplanning/$stack_pyramid_traj_name.state.pd_joint_delta_pos.cpu.h5 \
+  --control-mode "pd_joint_delta_pos" --sim-backend "cpu" --num-demos 100 \
+  --max_episode_steps 300 --total_iters 30000 --num-eval-envs 20
+
+# PickAndPlace-v1
+pick_and_place_traj_name="pick_and_place_trajectory"
+python -m mani_skill.examples.motionplanning.panda.run -e "PickAndPlace-v1" \
+  --num-procs 20 -n 100 --reward-mode="sparse" \
+  --sim-backend cpu --only-count-success \
+  --traj-name $pick_and_place_traj_name
+
+python -m mani_skill.trajectory.replay_trajectory -b cpu \
+  --traj-path ./demos/PickAndPlace-v1/motionplanning/$pick_and_place_traj_name.h5 \
+  --use-first-env-state -c pd_joint_delta_pos -o state \
+  --save-traj --num-procs 20
+
+python train.py --env-id PickAndPlace-v1 \
+  --demo-path ./demos/PickAndPlace-v1/motionplanning/$pick_and_place_traj_name.state.pd_joint_delta_pos.cpu.h5 \
+  --control-mode "pd_joint_delta_pos" --sim-backend "cpu" --num-demos 100 \
+  --max_episode_steps 300 --total_iters 30000 --num-eval-envs 20
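Before pointing `train.py` at a replayed file, it can help to sanity-check it. A sketch (assumes `h5py` is installed; ManiSkill trajectory files typically store one `traj_<i>` group per episode, with datasets such as `actions`):

```python
# Sketch (not part of the patch): inspect a converted trajectory file.
import h5py

path = "demos/StackPyramid-v1/motionplanning/stack_pyramid_trajectory.state.pd_joint_delta_pos.cpu.h5"
with h5py.File(path, "r") as f:
    print(f"{len(f.keys())} episodes")
    traj = f["traj_0"]
    print("keys:", list(traj.keys()))              # e.g. actions, obs, ...
    print("actions shape:", traj["actions"].shape)
```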
diff --git a/figures/environment_demos/PickAndPlace-v1_rt.mp4 b/figures/environment_demos/PickAndPlace-v1_rt.mp4
new file mode 100644
index 000000000..17a5f6179
Binary files /dev/null and b/figures/environment_demos/PickAndPlace-v1_rt.mp4 differ
diff --git a/figures/environment_demos/StackPyramid-v1_rt.mp4 b/figures/environment_demos/StackPyramid-v1_rt.mp4
new file mode 100644
index 000000000..0c12c49f0
Binary files /dev/null and b/figures/environment_demos/StackPyramid-v1_rt.mp4 differ
diff --git a/mani_skill/agents/robots/panda/panda.py b/mani_skill/agents/robots/panda/panda.py
index 750d85550..b44eda580 100644
--- a/mani_skill/agents/robots/panda/panda.py
+++ b/mani_skill/agents/robots/panda/panda.py
@@ -16,7 +16,7 @@
 @register_agent()
 class Panda(BaseAgent):
     uid = "panda"
-    urdf_path = f"{PACKAGE_ASSET_DIR}/robots/panda/panda_v2.urdf"
+    urdf_path = f"{PACKAGE_ASSET_DIR}/robots/panda/panda_v3.urdf"
     urdf_config = dict(
         _materials=dict(
             gripper=dict(static_friction=2.0, dynamic_friction=2.0, restitution=0.0)
diff --git a/mani_skill/envs/tasks/tabletop/__init__.py b/mani_skill/envs/tasks/tabletop/__init__.py
index be5c15bbc..c2540dff6 100644
--- a/mani_skill/envs/tasks/tabletop/__init__.py
+++ b/mani_skill/envs/tasks/tabletop/__init__.py
@@ -15,4 +15,6 @@
 from .place_sphere import PlaceSphereEnv
 from .roll_ball import RollBallEnv
 from .push_t import PushTEnv
-from .pull_cube_tool import PullCubeToolEnv
\ No newline at end of file
+from .pull_cube_tool import PullCubeToolEnv
+from .stack_pyramid import StackPyramidEnv
+from .pick_and_place import PickAndPlaceEnv
diff --git a/mani_skill/envs/tasks/tabletop/pick_and_place.py b/mani_skill/envs/tasks/tabletop/pick_and_place.py
new file mode 100644
index 000000000..001a16001
--- /dev/null
+++ b/mani_skill/envs/tasks/tabletop/pick_and_place.py
@@ -0,0 +1,212 @@
+from typing import Dict
+
+import numpy as np
+import sapien
+import torch
+
+import mani_skill.envs.utils.randomization as randomization
+from mani_skill.agents.robots.panda.panda_wristcam import PandaWristCam
+from mani_skill.envs.sapien_env import BaseEnv
+from mani_skill.sensors.camera import CameraConfig
+from mani_skill.utils import sapien_utils
+from mani_skill.utils.building import actors
+from mani_skill.utils.building.actors.common import build_container_grid
+from mani_skill.utils.registration import register_env
+from mani_skill.utils.scene_builder.table import TableSceneBuilder
+from mani_skill.utils.structs.pose import Pose
+
+
+@register_env("PickAndPlace-v1", max_episode_steps=500)
+class PickAndPlaceEnv(BaseEnv):
+    """
+    **Task Description:**
+    - Pick and place the four cubes into distinct container cells.
+    - The red cube must be placed in the upper-left cell.
+    - The green cube must be placed in the upper-right cell.
+    - The blue cube must be placed in the lower-left cell.
+    - The yellow cube must be placed in the lower-right cell.
+
+    **Randomizations:**
+    - The four cubes' xy positions are randomized over the region x in [0.05, 0.09] and y in [-0.15, -0.1].
+    - The z-coordinate of the cubes is fixed at 0.04 to place them on the table surface.
+    - The cubes' z-axis rotations are randomized.
+    - Sampled positions are rejected until every pair of cubes is at least 0.06 m apart, so the cubes never overlap.
+
+    **Success Conditions:**
+    - Each cube is placed in its corresponding cell.
+    - None of the cubes are grasped by the robot.
+    - The robot is static (within a threshold of 0.2).
+
+    **Goal Specification:**
+    - 3D goal positions, one per cell (also visualized in human renders).
+    """
+
+    _sample_video_link = "https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/PickAndPlace-v1_rt.mp4"
+
+    SUPPORTED_ROBOTS = ["panda_wristcam"]
+    SUPPORTED_REWARD_MODES = ["none", "sparse"]
+    agent: PandaWristCam
+    cube_half_size = 0.02
+    goal_thresh = 0.05
+
+    def __init__(self, *args, robot_uids="panda_wristcam", robot_init_qpos_noise=0.02, **kwargs):
+        self.robot_init_qpos_noise = robot_init_qpos_noise
+        super().__init__(*args, robot_uids=robot_uids, **kwargs)
+
+    @property
+    def _default_sensor_configs(self):
+        pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
+        return [CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)]
+
+    @property
+    def _default_human_render_camera_configs(self):
+        pose = sapien_utils.look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.35])
+        return CameraConfig("render_camera", pose, 512, 512, 1, 0.01, 100)
+
+    def _load_agent(self, options: dict):
+        super()._load_agent(options, sapien.Pose(p=[-0.615, 0, 0]))
+
+    def _load_scene(self, options: dict):
+        self.table_scene = TableSceneBuilder(
+            self, robot_init_qpos_noise=self.robot_init_qpos_noise
+        )
+        self.table_scene.build()
+        self.container_grid, self.goal_sites = build_container_grid(
+            self.scene,
+            initial_pose=sapien.Pose(p=[0.0, 0.2, 0.04], q=[1, 0, 0, 0]),
+            size=0.25,
+            height=0.05,
+            thickness=0.01,
+            color=(0.8, 0.6, 0.4),
+            name="container_grid",
+            n=2,
+            m=2,
+            body_type="kinematic",
+        )
+        self.red_cube = actors.build_cube(
+            self.scene,
+            half_size=self.cube_half_size,
+            color=[1, 0, 0, 1],
+            name="red_cube",
+            initial_pose=sapien.Pose(p=[0, 0, 0.1]),
+        )
+        self.green_cube = actors.build_cube(
+            self.scene,
+            half_size=self.cube_half_size,
+            color=[0, 1, 0, 1],
+            name="green_cube",
+            initial_pose=sapien.Pose(p=[0, 0, 0.5]),
+        )
+        self.blue_cube = actors.build_cube(
+            self.scene,
+            half_size=self.cube_half_size,
+            color=[0, 0, 1, 1],
+            name="blue_cube",
+            initial_pose=sapien.Pose(p=[0, 0, 1.0]),
+        )
+        self.yellow_cube = actors.build_cube(
+            self.scene,
+            half_size=self.cube_half_size,
+            color=[1, 1, 0, 1],
+            name="yellow_cube",
+            initial_pose=sapien.Pose(p=[0, 0, 1.5]),
+        )
+
+        self.cubes = [self.red_cube, self.green_cube, self.blue_cube, self.yellow_cube]
+        self._hidden_objects.extend(self.goal_sites)
+
+    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
+        with torch.device(self.device):
+            env_count = len(env_idx)
+            self.table_scene.initialize(env_idx)
+
+            # The container grid keeps a fixed pose on the table
+            container_pose = sapien.Pose(p=[0.0, 0.2, 0.04], q=[1, 0, 0, 0])
+            self.container_grid.set_pose(container_pose)
+
+            # Spawn region documented in the class docstring:
+            # x in [0.05, 0.09], y in [-0.15, -0.1]
+            region = [[0.05, -0.15], [0.09, -0.1]]
+            radius = torch.linalg.norm(torch.tensor([0.02, 0.02])) + 0.001
+            min_distance = 0.06
+
+            # Rejection-sample cube poses until every pair of cubes is at least
+            # min_distance apart in every parallel environment
+            for i, cube in enumerate(self.cubes):
+                while True:
+                    xyz = torch.zeros((env_count, 3))
+                    sampler = randomization.UniformPlacementSampler(
+                        bounds=region, batch_size=env_count, device=self.device
+                    )
+                    xyz[:, :2] = sampler.sample(radius, 100)
+                    xyz[:, 2] = 0.04  # rest the cube on the table surface
+                    qs = randomization.random_quaternions(
+                        env_count,
+                        lock_x=True,
+                        lock_y=True,
+                        lock_z=False,
+                    )
+                    cube.set_pose(Pose.create_from_pq(p=xyz, q=qs))
+
+                    overlap = False
+                    for j in range(i):
+                        other_xyz = self.cubes[j].pose.p
+                        distance = torch.linalg.norm(xyz - other_xyz, axis=1)
+                        if torch.any(distance < min_distance):
+                            overlap = True
+                            break
+
+                    if not overlap:
+                        break
+    def _get_obs_extra(self, info: Dict):
+        # in reality some people hack is_grasped into observations by checking
+        # if the gripper can close fully or not
+        obs = {"tcp_pose": self.agent.tcp.pose.raw_pose}
+        for goal_site in self.goal_sites:
+            obs[f"{goal_site.name}_pose"] = goal_site.pose.p
+
+        for i, obj in enumerate(self.cubes):
+            obs[f"{obj.name}_pose"] = obj.pose.raw_pose
+            if "state" in self.obs_mode:
+                obs[f"{obj.name}_to_goal_pos"] = self.goal_sites[i].pose.p - obj.pose.p
+                obs[f"tcp_to_{obj.name}_pos"] = obj.pose.p - self.agent.tcp.pose.p
+
+        return obs
+
+    def evaluate(self):
+        results = dict()
+        all_placed = torch.tensor(True, device=self.device)
+        any_grasped = torch.tensor(False, device=self.device)
+
+        for i, goal_site in enumerate(self.goal_sites):
+            obj = self.cubes[i]
+
+            distance_to_goal = torch.linalg.norm(goal_site.pose.p - obj.pose.p, axis=1)
+            is_placed = distance_to_goal <= self.goal_thresh
+            is_grasped = self.agent.is_grasping(obj)
+
+            results[f"{obj.name}_distance_to_goal"] = distance_to_goal
+            results[f"is_{obj.name}_placed"] = is_placed
+            results[f"is_{obj.name}_grasped"] = is_grasped
+
+            all_placed = torch.logical_and(all_placed, is_placed)
+            any_grasped = torch.logical_or(any_grasped, is_grasped)
+
+        results["is_robot_static"] = self.agent.is_static(0.2)
+        # Success (matching the docstring): every cube is in its cell, nothing
+        # is grasped, and the robot has come to rest
+        results["success"] = all_placed & ~any_grasped & results["is_robot_static"]
+
+        return results
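`evaluate()` returns per-cube diagnostics alongside `success`, and ManiSkill surfaces this dict through `step()`'s `info`. A minimal sketch of reading them (key names follow the `obj.name` values above; assumes the env is registered as in this patch):

```python
# Sketch (not part of the patch): reading PickAndPlaceEnv.evaluate() output
# from the info dict returned by step().
import gymnasium as gym
import mani_skill.envs  # noqa: F401

env = gym.make("PickAndPlace-v1", obs_mode="state")
env.reset(seed=0)
_, _, _, _, info = env.step(env.action_space.sample())
print(info["success"], info["is_red_cube_placed"], info["red_cube_distance_to_goal"])
env.close()
```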
diff --git a/mani_skill/envs/tasks/tabletop/stack_pyramid.py b/mani_skill/envs/tasks/tabletop/stack_pyramid.py
new file mode 100644
index 000000000..5916ce0a8
--- /dev/null
+++ b/mani_skill/envs/tasks/tabletop/stack_pyramid.py
@@ -0,0 +1,171 @@
+from typing import Dict, Union
+
+import numpy as np
+import sapien
+import torch
+
+from mani_skill.agents.robots import Fetch, Panda
+from mani_skill.envs.sapien_env import BaseEnv
+from mani_skill.envs.utils import randomization
+from mani_skill.sensors.camera import CameraConfig
+from mani_skill.utils import common, sapien_utils
+from mani_skill.utils.building import actors
+from mani_skill.utils.registration import register_env
+from mani_skill.utils.scene_builder.table import TableSceneBuilder
+from mani_skill.utils.structs.pose import Pose
+
+
+@register_env("StackPyramid-v1", max_episode_steps=50)
+class StackPyramidEnv(BaseEnv):
+    """
+    **Task Description:**
+    - The goal is to pick up the red cube, place it next to the green cube, and stack the blue cube on top of the red and green cubes without it falling off.
+
+    **Randomizations:**
+    - all three cubes have their z-axis rotation randomized
+    - all three cubes have their xy positions on top of the table scene randomized. The positions are sampled such that the cubes do not collide with each other
+
+    **Success Conditions:**
+    - the blue cube is static
+    - the blue cube is on top of both the red and green cubes (to within half of the cube size)
+    - none of the red, green, or blue cubes are grasped by the robot (the robot must let go of the cubes)
+    """
+
+    _sample_video_link = "https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/StackPyramid-v1_rt.mp4"
+
+    SUPPORTED_ROBOTS = ["panda_wristcam", "panda", "fetch"]
+    SUPPORTED_REWARD_MODES = ["none", "sparse"]
+
+    agent: Union[Panda, Fetch]
+
+    def __init__(
+        self, *args, robot_uids="panda_wristcam", robot_init_qpos_noise=0.02, **kwargs
+    ):
+        self.robot_init_qpos_noise = robot_init_qpos_noise
+        super().__init__(*args, robot_uids=robot_uids, **kwargs)
+
+    @property
+    def _default_sensor_configs(self):
+        pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
+        return [CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)]
+
+    @property
+    def _default_human_render_camera_configs(self):
+        pose = sapien_utils.look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.35])
+        return CameraConfig("render_camera", pose, 512, 512, 1, 0.01, 100)
+
+    def _load_scene(self, options: dict):
+        self.cube_half_size = common.to_tensor([0.02] * 3)
+        self.table_scene = TableSceneBuilder(
+            env=self, robot_init_qpos_noise=self.robot_init_qpos_noise
+        )
+        self.table_scene.build()
+        self.cubeA = actors.build_cube(
+            self.scene, half_size=0.02, color=[1, 0, 0, 1], name="cubeA",
+            initial_pose=sapien.Pose(p=[0, 0, 0.1]),
+        )
+        self.cubeB = actors.build_cube(
+            self.scene, half_size=0.02, color=[0, 1, 0, 1], name="cubeB",
+            initial_pose=sapien.Pose(p=[1, 0, 0.1]),
+        )
+        self.cubeC = actors.build_cube(
+            self.scene, half_size=0.02, color=[0, 0, 1, 1], name="cubeC",
+            initial_pose=sapien.Pose(p=[-1, 0, 0.1]),
+        )
+
+    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
+        with torch.device(self.device):
+            b = len(env_idx)
+            self.table_scene.initialize(env_idx)
+
+            # Shared random offset plus per-cube non-colliding xy samples
+            xyz = torch.zeros((b, 3))
+            xyz[:, 2] = 0.02
+            xy = torch.rand((b, 2)) * 0.2 - 0.1
+            region = [[-0.1, -0.2], [0.1, 0.2]]
+            sampler = randomization.UniformPlacementSampler(bounds=region, batch_size=b)
+            radius = torch.linalg.norm(torch.tensor([0.02, 0.02])) + 0.001
+            cubeA_xy = xy + sampler.sample(radius, 100)
+            cubeB_xy = xy + sampler.sample(radius, 100)
+            cubeC_xy = xy + sampler.sample(radius, 100)
+
+            xyz[:, :2] = cubeA_xy
+            qs = randomization.random_quaternions(b, lock_x=True, lock_y=True, lock_z=False)
+            self.cubeA.set_pose(Pose.create_from_pq(p=xyz.clone(), q=qs))
+
+            xyz[:, :2] = cubeB_xy
+            qs = randomization.random_quaternions(b, lock_x=True, lock_y=True, lock_z=False)
+            self.cubeB.set_pose(Pose.create_from_pq(p=xyz.clone(), q=qs))
+
+            xyz[:, :2] = cubeC_xy
+            qs = randomization.random_quaternions(b, lock_x=True, lock_y=True, lock_z=False)
+            self.cubeC.set_pose(Pose.create_from_pq(p=xyz, q=qs))
+    def evaluate(self):
+        pos_A = self.cubeA.pose.p
+        pos_B = self.cubeB.pose.p
+        pos_C = self.cubeC.pose.p
+
+        offset_AB = pos_A - pos_B
+        offset_BC = pos_B - pos_C
+        offset_AC = pos_A - pos_C
+
+        def evaluate_cube_distance(offset, cube, top_or_next):
+            # "to within half of the cube size" (see docstring): a cube is
+            # 2 * half_size on a side, so the tolerance is one half_size
+            tolerance = self.cube_half_size[2]
+            if top_or_next == "top":
+                xy_offset = torch.linalg.norm(offset[..., :2], axis=-1) - torch.linalg.norm(self.cube_half_size[:2])
+                # sign-agnostic: the stacked cube sits one full cube height away
+                z_offset = torch.abs(torch.abs(offset[..., 2]) - 2 * self.cube_half_size[2])
+                xy_flag = xy_offset <= tolerance
+                z_flag = z_offset <= tolerance
+            else:  # "next_to": side by side at roughly the same height
+                xy_offset = torch.linalg.norm(offset[..., :2], axis=-1) - torch.linalg.norm(2 * self.cube_half_size[:2])
+                z_offset = torch.abs(offset[..., 2])
+                xy_flag = xy_offset <= 0.05
+                z_flag = z_offset <= 0.05
+            is_positioned = torch.logical_and(xy_flag, z_flag)
+            is_static = cube.is_static(lin_thresh=1e-2, ang_thresh=0.5)
+            is_grasped = self.agent.is_grasping(cube)
+
+            success = is_positioned & is_static & (~is_grasped)
+            return success.bool()
+
+        success_A_B = evaluate_cube_distance(offset_AB, self.cubeA, "next_to")
+        success_C_B = evaluate_cube_distance(offset_BC, self.cubeC, "top")
+        success_C_A = evaluate_cube_distance(offset_AC, self.cubeC, "top")
+
+        # elementwise AND over the batch (Python "and" would fail on tensors)
+        success = success_A_B & success_C_B & success_C_A
+
+        return {
+            "success": success,
+        }
+
+    def _get_obs_extra(self, info: Dict):
+        obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
+        if "state" in self.obs_mode:
+            obs.update(
+                cubeA_pose=self.cubeA.pose.raw_pose,
+                cubeB_pose=self.cubeB.pose.raw_pose,
+                cubeC_pose=self.cubeC.pose.raw_pose,
+                tcp_to_cubeA_pos=self.cubeA.pose.p - self.agent.tcp.pose.p,
+                tcp_to_cubeB_pos=self.cubeB.pose.p - self.agent.tcp.pose.p,
+                tcp_to_cubeC_pos=self.cubeC.pose.p - self.agent.tcp.pose.p,
+                cubeA_to_cubeB_pos=self.cubeB.pose.p - self.cubeA.pose.p,
+                cubeB_to_cubeC_pos=self.cubeC.pose.p - self.cubeB.pose.p,
+                cubeA_to_cubeC_pos=self.cubeC.pose.p - self.cubeA.pose.p,
+            )
+        return obs
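For concreteness, the "top" predicate above accepts the ideal pyramid layout. A standalone numeric sketch (plain torch, values mirroring the cube `half_size` of 0.02 used in this env):

```python
# Sketch (not part of the patch): the "on top" check from evaluate() applied
# to the ideal pyramid layout (cube C resting on the seam between A and B).
import torch

half = torch.tensor([0.02, 0.02, 0.02])
pos_B = torch.tensor([0.04, 0.0, 0.02])  # green cube on the table
pos_C = torch.tensor([0.02, 0.0, 0.06])  # blue cube one cube height up
offset = pos_B - pos_C

tol = half[2]
xy_ok = torch.linalg.norm(offset[:2]) - torch.linalg.norm(half[:2]) <= tol
z_ok = torch.abs(torch.abs(offset[2]) - 2 * half[2]) <= tol
print(bool(xy_ok), bool(z_ok))  # True True
```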
diff --git a/mani_skill/envs/tasks/tabletop/test_stack_pyramid.py b/mani_skill/envs/tasks/tabletop/test_stack_pyramid.py
new file mode 100644
index 000000000..1b937ab5d
--- /dev/null
+++ b/mani_skill/envs/tasks/tabletop/test_stack_pyramid.py
@@ -0,0 +1,69 @@
+import gymnasium as gym
+import numpy as np
+import pytest
+import torch
+
+import mani_skill.envs  # noqa: F401  (registers StackPyramid-v1)
+from mani_skill.utils.structs.pose import Pose
+
+
+@pytest.fixture
+def env():
+    return gym.make("StackPyramid-v1", obs_mode="state")
+
+
+def test_env_initialization(env):
+    assert env is not None
+
+
+def test_env_reset(env):
+    obs, info = env.reset()
+    assert isinstance(obs, torch.Tensor)
+    assert isinstance(info, dict)
+
+
+def test_observation_space(env):
+    obs, _ = env.reset()
+    assert env.observation_space is not None
+    assert obs.shape[-1] > 0
+
+
+def test_action_space(env):
+    action = env.action_space.sample()
+    assert isinstance(action, np.ndarray)
+
+
+def test_step(env):
+    env.reset()
+    action = env.action_space.sample()
+    obs, reward, terminated, truncated, info = env.step(action)
+    assert isinstance(obs, torch.Tensor)
+    assert isinstance(reward, torch.Tensor)
+    assert isinstance(terminated, torch.Tensor)
+    assert isinstance(truncated, torch.Tensor)
+    assert isinstance(info, dict)
+
+
+def test_success_condition(env):
+    env.reset()
+    # Manually set the cube poses to a successful pyramid configuration
+    env.cubeA.set_pose(Pose.create_from_pq(p=np.array([0.0, 0.0, 0.0])))
+    env.cubeB.set_pose(Pose.create_from_pq(p=np.array([0.04, 0.0, 0.0])))
+    env.cubeC.set_pose(Pose.create_from_pq(p=np.array([0.02, 0.0, 0.04])))
+    # Stub out the physics and grasp checks so only the geometry is tested
+    env.cubeA.is_static = lambda lin_thresh, ang_thresh: torch.tensor([True])
+    env.cubeB.is_static = lambda lin_thresh, ang_thresh: torch.tensor([True])
+    env.cubeC.is_static = lambda lin_thresh, ang_thresh: torch.tensor([True])
+    env.agent.is_grasping = lambda cube: torch.tensor([False])
+
+    success_info = env.evaluate()
+    assert success_info["success"], "Success condition failed."
diff --git a/mani_skill/examples/motionplanning/panda/run.py b/mani_skill/examples/motionplanning/panda/run.py
index 6fd1621c3..2f6932863 100644
--- a/mani_skill/examples/motionplanning/panda/run.py
+++ b/mani_skill/examples/motionplanning/panda/run.py
@@ -9,7 +9,8 @@ import os.path as osp
 from mani_skill.utils.wrappers.record import RecordEpisode
 from mani_skill.trajectory.merge_trajectory import merge_trajectories
-from mani_skill.examples.motionplanning.panda.solutions import solvePushCube, solvePickCube, solveStackCube, solvePegInsertionSide, solvePlugCharger, solvePullCubeTool, solveLiftPegUpright, solvePullCube, solveDrawTriangle
+from mani_skill.examples.motionplanning.panda.solutions import solvePushCube, solvePickCube, solveStackCube, solvePegInsertionSide, solvePlugCharger, solvePullCubeTool, solveLiftPegUpright, solvePullCube, solveStackPyramid, solvePickAndPlace, solveDrawTriangle
+
 MP_SOLUTIONS = {
     "DrawTriangle-v1": solveDrawTriangle,
     "PickCube-v1": solvePickCube,
@@ -19,8 +20,9 @@
     "PushCube-v1": solvePushCube,
     "PullCubeTool-v1": solvePullCubeTool,
     "LiftPegUpright-v1": solveLiftPegUpright,
-    "PullCube-v1": solvePullCube
-
+    "PullCube-v1": solvePullCube,
+    "StackPyramid-v1": solveStackPyramid,
+    "PickAndPlace-v1": solvePickAndPlace,
 }
 def parse_args(args=None):
     parser = argparse.ArgumentParser()
@@ -86,7 +88,6 @@ def _main(args, proc_id: int = 0, start_seed: int = 0) -> str:
         except Exception as e:
             print(f"Cannot find valid solution because of an error in motion planning solution: {e}")
             res = -1
-
     if res == -1:
         success = False
         failed_motion_plans += 1
diff --git a/mani_skill/examples/motionplanning/panda/solutions/__init__.py b/mani_skill/examples/motionplanning/panda/solutions/__init__.py
index f1b2510d3..0e12f21be 100644
--- a/mani_skill/examples/motionplanning/panda/solutions/__init__.py
+++ b/mani_skill/examples/motionplanning/panda/solutions/__init__.py
@@ -5,4 +5,7 @@
 from .push_cube import solve as solvePushCube
 from .pull_cube_tool import solve as solvePullCubeTool
 from .lift_peg_upright import solve as solveLiftPegUpright
-from .pull_cube import solve as solvePullCube
\ No newline at end of file
+from .pull_cube import solve as solvePullCube
+from .stack_pyramid import solve as solveStackPyramid
+from .pick_and_place import solve as solvePickAndPlace
+from .draw_triangle import solve as solveDrawTriangle
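The solvers in `MP_SOLUTIONS` can also be invoked directly, bypassing the `run` CLI. A minimal sketch (the StackPyramid solver asserts a `pd_joint_pos`-family control mode, per the assert in its `solve()`; `run.py` treats a return of -1 as failure):

```python
# Sketch (not part of the patch): calling a motion-planning solver directly.
import gymnasium as gym
import mani_skill.envs  # noqa: F401
from mani_skill.examples.motionplanning.panda.solutions import solveStackPyramid

env = gym.make("StackPyramid-v1", control_mode="pd_joint_pos", obs_mode="none")
res = solveStackPyramid(env, seed=0, debug=False, vis=False)  # -1 on planner failure
env.close()
```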
diff --git a/mani_skill/examples/motionplanning/panda/solutions/pick_and_place.py b/mani_skill/examples/motionplanning/panda/solutions/pick_and_place.py
new file mode 100644
index 000000000..aad95a72e
--- /dev/null
+++ b/mani_skill/examples/motionplanning/panda/solutions/pick_and_place.py
@@ -0,0 +1,85 @@
+import numpy as np
+import sapien
+
+from mani_skill.envs.tasks import PickAndPlaceEnv
+from mani_skill.examples.motionplanning.panda.motionplanner import PandaArmMotionPlanningSolver
+from mani_skill.examples.motionplanning.panda.utils import compute_grasp_info_by_obb, get_actor_obb
+
+
+def solve(env: PickAndPlaceEnv, seed=None, debug=False, vis=False):
+    env.reset(seed=seed)
+    planner = PandaArmMotionPlanningSolver(
+        env,
+        debug=debug,
+        vis=vis,
+        base_pose=env.unwrapped.agent.robot.pose,
+        visualize_target_grasp_pose=vis,
+        print_env_info=False,
+    )
+    # Register the container as a collision obstacle so plans route around it
+    base_half_size = 0.25
+    height = 0.045
+    planner.add_box_collision(
+        extents=np.array([base_half_size, base_half_size, height]),
+        pose=env.unwrapped.container_grid.pose.sp,
+    )
+
+    FINGER_LENGTH = 0.03
+
+    env = env.unwrapped
+    rng = np.random.default_rng(seed)
+
+    # Move the cubes to their goal cells in a random order
+    order_of_cubes = list(range(len(env.cubes)))
+    rng.shuffle(order_of_cubes)
+    res = -1
+    for i in order_of_cubes:
+        cube = env.cubes[i]
+        goal_site = env.goal_sites[i]
+        # Retrieve the object-oriented bounding box (trimesh box object)
+        obb = get_actor_obb(cube)
+
+        approaching = np.array([0, 0, -1])
+        # TCP pose transforms are batched torch tensors; take env 0 on the CPU
+        target_closing = env.agent.tcp.pose.to_transformation_matrix()[0, :3, 1].cpu().numpy()
+        # Build a simple top-down grasp pose for the Panda
+        grasp_info = compute_grasp_info_by_obb(
+            obb,
+            approaching=approaching,
+            target_closing=target_closing,
+            depth=FINGER_LENGTH,
+        )
+        closing = grasp_info["closing"]
+        grasp_pose = env.agent.build_grasp_pose(approaching, closing, cube.pose.sp.p)
+
+        # -------------------------------------------------------------------- #
+        # Reach: hover above the grasp pose
+        # -------------------------------------------------------------------- #
+        reach_pose = grasp_pose * sapien.Pose([0, 0, -0.05])
+        planner.move_to_pose_with_screw(reach_pose)
+
+        # -------------------------------------------------------------------- #
+        # Grasp
+        # -------------------------------------------------------------------- #
+        planner.move_to_pose_with_screw(grasp_pose)
+        planner.close_gripper()
+
+        # -------------------------------------------------------------------- #
+        # Move above the goal cell and release
+        # -------------------------------------------------------------------- #
+        goal_pose_offset = np.array([0.0, 0.0, 0.14])
+        goal_pose = sapien.Pose(goal_site.pose.sp.p + goal_pose_offset, grasp_pose.q)
+        planner.move_to_pose_with_screw(goal_pose)
+        res = planner.open_gripper()
+
+    planner.close()
+    return res
diff --git a/mani_skill/examples/motionplanning/panda/solutions/stack_pyramid.py b/mani_skill/examples/motionplanning/panda/solutions/stack_pyramid.py
new file mode 100644
index 000000000..1be90afdb
--- /dev/null
+++ b/mani_skill/examples/motionplanning/panda/solutions/stack_pyramid.py
@@ -0,0 +1,130 @@
+import numpy as np
+import sapien
+from transforms3d.euler import euler2quat
+
+from mani_skill.envs.tasks import StackPyramidEnv
+from mani_skill.examples.motionplanning.panda.motionplanner import \
+    PandaArmMotionPlanningSolver
+from mani_skill.examples.motionplanning.panda.utils import (
+    compute_grasp_info_by_obb, get_actor_obb)
+
+
+def solve(env: StackPyramidEnv, seed=None, debug=False, vis=False):
+    env.reset(seed=seed)
+    assert env.unwrapped.control_mode in [
+        "pd_joint_pos",
+        "pd_joint_pos_vel",
+    ], env.unwrapped.control_mode
+    planner = PandaArmMotionPlanningSolver(
+        env,
+        debug=debug,
+        vis=vis,
+        base_pose=env.unwrapped.agent.robot.pose,
+        visualize_target_grasp_pose=vis,
+        print_env_info=False,
+    )
+    FINGER_LENGTH = 0.025
+    env = env.unwrapped
+
+    # -------------------------------------------------------------------------- #
+    # Move Cube A next to Cube B
+    # -------------------------------------------------------------------------- #
+    obb = get_actor_obb(env.cubeA)
+    approaching = np.array([0, 0, -1])
+    target_closing = env.agent.tcp.pose.to_transformation_matrix()[0, :3, 1].cpu().numpy()
+    grasp_info = compute_grasp_info_by_obb(
+        obb,
+        approaching=approaching,
+        target_closing=target_closing,
+        depth=FINGER_LENGTH,
+    )
+    closing, center = grasp_info["closing"], grasp_info["center"]
+    # Only reposition cube A if it is not already next to cube B
+    distance = np.linalg.norm(env.cubeA.pose.sp.p - env.cubeB.pose.sp.p)
+    if distance > 0.009:
+        planner.close_gripper()
+        grasp_pose = env.agent.build_grasp_pose(approaching, closing, env.cubeA.pose.sp.p)
+
+        # Reach
+        reach_pose = grasp_pose * sapien.Pose([0, 0, -0.05])
+        planner.move_to_pose_with_screw(reach_pose)
+
+        # Grasp
+        planner.move_to_pose_with_screw(grasp_pose)
+        planner.close_gripper()
+
+        # Move to goal pose (heuristic: shrink cube B's position toward the
+        # robot base so the cubes end up side by side rather than overlapping)
+        goal_pose = sapien.Pose(env.cubeB.pose.sp.p * 0.8, grasp_pose.q)
+        planner.move_to_pose_with_screw(goal_pose)
+        res = planner.open_gripper()
+
+    # -------------------------------------------------------------------------- #
+    # Stack Cube C onto Cubes A and B
+    # -------------------------------------------------------------------------- #
+    obb = get_actor_obb(env.cubeC)
+    target_closing = env.agent.tcp.pose.to_transformation_matrix()[0, :3, 1].cpu().numpy()
+    grasp_info = compute_grasp_info_by_obb(
+        obb,
+        approaching=approaching,
+        target_closing=target_closing,
+        depth=FINGER_LENGTH,
+    )
+    closing, center = grasp_info["closing"], grasp_info["center"]
+    grasp_pose = env.agent.build_grasp_pose(approaching, closing, center)
+
+    # Search for a collision-free grasp orientation about the z-axis
+    angles = np.arange(0, np.pi * 2 / 3, np.pi / 2)
+    angles = np.repeat(angles, 2)
+    angles[1::2] *= -1
+    for angle in angles:
+        delta_pose = sapien.Pose(q=euler2quat(0, 0, angle))
+        grasp_pose2 = grasp_pose * delta_pose
+        res = planner.move_to_pose_with_screw(grasp_pose2, dry_run=True)
+        if res == -1:
+            continue
+        grasp_pose = grasp_pose2
+        break
+    else:
+        print("Failed to find a valid grasp pose")
+
+    # -------------------------------------------------------------------------- #
+    # Reach
+    # -------------------------------------------------------------------------- #
+    reach_pose = grasp_pose * sapien.Pose([0, 0, -0.05])
+    planner.move_to_pose_with_screw(reach_pose)
+
+    # -------------------------------------------------------------------------- #
+    # Grasp
+    # -------------------------------------------------------------------------- #
+    planner.move_to_pose_with_screw(grasp_pose)
+    planner.close_gripper()
+    # -------------------------------------------------------------------------- #
+    # Lift
+    # -------------------------------------------------------------------------- #
+    lift_pose = sapien.Pose([0, 0, 0.1]) * grasp_pose
+    planner.move_to_pose_with_screw(lift_pose)
+
+    # -------------------------------------------------------------------------- #
+    # Stack
+    # -------------------------------------------------------------------------- #
+    # Target: one full cube height above the midpoint between cubes A and B
+    goal_pose_A = env.cubeA.pose * sapien.Pose([0, 0, env.cube_half_size[2] * 2])
+    goal_pose_B = env.cubeB.pose * sapien.Pose([0, 0, env.cube_half_size[2] * 2])
+    goal_pose_p = (goal_pose_A.p + goal_pose_B.p) / 2
+    # all data in ManiSkill is batched torch tensors; take env 0 as numpy
+    offset = (goal_pose_p - env.cubeC.pose.p).numpy()[0]
+    align_pose = sapien.Pose(lift_pose.p + offset, lift_pose.q)
+    planner.move_to_pose_with_screw(align_pose)
+
+    res = planner.open_gripper()
+    planner.close()
+    return res
diff --git a/mani_skill/utils/building/actors/common.py b/mani_skill/utils/building/actors/common.py
index 07f5b695e..099595a6b 100644
--- a/mani_skill/utils/building/actors/common.py
+++ b/mani_skill/utils/building/actors/common.py
@@ -84,6 +84,82 @@ def build_box(
     )
     return _build_by_type(builder, name, body_type, scene_idxs, initial_pose)
 
+
+def build_container_grid(
+    scene: ManiSkillScene,
+    size: float,
+    height: float,
+    thickness: float,
+    color,
+    name: str,
+    n: int,
+    m: int,
+    initial_pose: sapien.Pose,
+    body_type: str = "static",
+    scene_idxs: Optional[Array] = None,
+):
+    builder = scene.create_actor_builder()
+
+    # Container base, placed so the base's top face sits at z = 0
+    base_pose = sapien.Pose([0.0, 0.0, -thickness / 2])
+    base_half_size = [size / 2, size / 2, thickness / 2]
+    builder.add_box_collision(pose=base_pose, half_size=base_half_size)
+    builder.add_box_visual(pose=base_pose, half_size=base_half_size)
+
+    # Four container side walls
+    for i in [-1, 1]:
+        for axis in ["x", "y"]:
+            side_pose = sapien.Pose(
+                [
+                    i * (size - thickness) / 2 if axis == "x" else 0,
+                    i * (size - thickness) / 2 if axis == "y" else 0,
+                    height / 2,
+                ]
+            )
+            side_half_size = (
+                [thickness / 2, size / 2, height / 2]
+                if axis == "x"
+                else [size / 2, thickness / 2, height / 2]
+            )
+            builder.add_box_collision(pose=side_pose, half_size=side_half_size)
+            builder.add_box_visual(pose=side_pose, half_size=side_half_size)
+
+    # Interior dividers forming an n x m grid of cells
+    internal_size = size - 2 * thickness
+    cell_width = internal_size / n
+    cell_height = internal_size / m
+
+    for i in range(1, n):
+        # Dividers running along y
+        divider_pose = sapien.Pose([i * cell_width - internal_size / 2, 0, height / 2])
+        divider_half_size = [thickness / 2, size / 2, height / 2]
+        builder.add_box_collision(pose=divider_pose, half_size=divider_half_size)
+        builder.add_box_visual(pose=divider_pose, half_size=divider_half_size)
+
+    for j in range(1, m):
+        # Dividers running along x
+        divider_pose = sapien.Pose([0, j * cell_height - internal_size / 2, height / 2])
+        divider_half_size = [size / 2, thickness / 2, height / 2]
+        builder.add_box_collision(pose=divider_pose, half_size=divider_half_size)
+        builder.add_box_visual(pose=divider_pose, half_size=divider_half_size)
+
+    container = _build_by_type(builder, name, body_type, scene_idxs, initial_pose)
+
+    # Create a goal site at the center of each cell
+    goal_sites = []
+    goal_z_offset = 0.02
+    goal_color = [0, 1, 0, 1]
+    for i in range(n):
+        for j in range(m):
+            goal_x = (i + 0.5) * cell_width - internal_size / 2
+            goal_y = (j + 0.5) * cell_height - internal_size / 2
+            goal_local_pose = sapien.Pose([goal_x, goal_y, thickness + goal_z_offset])
+            goal_world_pose = initial_pose * goal_local_pose  # transform to world frame
+            goal_site = build_box(
+                scene,
+                half_sizes=[cell_width / 2, cell_height / 2, thickness / 2],
+                color=goal_color,
+                name=f"goal_site_{i}_{j}",
+                body_type="kinematic",
+                add_collision=False,
+                initial_pose=goal_world_pose,
+            )
+            goal_sites.append(goal_site)
+
+    return container, goal_sites
 
 def build_cylinder(
     scene: ManiSkillScene,
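For reference, the cell-center geometry above works out as follows for the defaults used by `PickAndPlaceEnv` (size=0.25, thickness=0.01, n=m=2); a standalone arithmetic sketch:

```python
# Sketch (not part of the patch): cell centers produced by build_container_grid
# for the PickAndPlace defaults.
size, thickness, n, m = 0.25, 0.01, 2, 2
internal = size - 2 * thickness          # 0.23
cell_w, cell_h = internal / n, internal / m
centers = [
    ((i + 0.5) * cell_w - internal / 2, (j + 0.5) * cell_h - internal / 2)
    for i in range(n)
    for j in range(m)
]
print(centers)  # [(-0.0575, -0.0575), (-0.0575, 0.0575), (0.0575, -0.0575), (0.0575, 0.0575)]
```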