diff --git a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/lift_env_cfg.py b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/lift_env_cfg.py
index 6a3a078cb5..925c956b15 100644
--- a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/lift_env_cfg.py
+++ b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/lift_env_cfg.py
@@ -219,4 +219,4 @@ def __post_init__(self):
         self.sim.physx.bounce_threshold_velocity = 0.01
         self.sim.physx.gpu_found_lost_aggregate_pairs_capacity = 1024 * 1024 * 4
         self.sim.physx.gpu_total_aggregate_pairs_capacity = 16 * 1024
-        self.sim.physx.friction_correlation_distance = 0.00625
+        self.sim.physx.friction_correlation_distance = 0.00625
\ No newline at end of file
diff --git a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/lift_place_env_cfg.py b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/lift_place_env_cfg.py
new file mode 100644
index 0000000000..15f4acf6bf
--- /dev/null
+++ b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/lift_place_env_cfg.py
@@ -0,0 +1,263 @@
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from dataclasses import MISSING
+
+import omni.isaac.lab.sim as sim_utils
+from omni.isaac.lab.assets import ArticulationCfg, AssetBaseCfg, RigidObjectCfg
+from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
+from omni.isaac.lab.managers import CurriculumTermCfg as CurrTerm
+from omni.isaac.lab.managers import EventTermCfg as EventTerm
+from omni.isaac.lab.managers import ObservationGroupCfg as ObsGroup
+from omni.isaac.lab.managers import ObservationTermCfg as ObsTerm
+from omni.isaac.lab.managers import RewardTermCfg as RewTerm
+from omni.isaac.lab.managers import SceneEntityCfg
+from omni.isaac.lab.managers import TerminationTermCfg as DoneTerm
+from omni.isaac.lab.scene import InteractiveSceneCfg
+from omni.isaac.lab.sensors.frame_transformer.frame_transformer_cfg import FrameTransformerCfg
+from omni.isaac.lab.sim.spawners.from_files.from_files_cfg import GroundPlaneCfg, UsdFileCfg
+from omni.isaac.lab.utils import configclass
+from omni.isaac.lab.utils.assets import ISAAC_NUCLEUS_DIR
+
+from . import mdp
+
+##
+# Scene definition
+##
+
+
+@configclass
+class ObjectTableSceneCfg(InteractiveSceneCfg):
+    """Configuration for the lift scene with a robot and an object.
+    This is the abstract base implementation; the exact scene is defined in derived classes,
+    which need to set the target object, robot, and end-effector frames.
+    """
+
+    # robots: will be populated by agent env cfg
+    robot: ArticulationCfg = MISSING
+    # end-effector sensor: will be populated by agent env cfg
+    ee_frame: FrameTransformerCfg = MISSING
+    # target object: will be populated by agent env cfg
+    object: RigidObjectCfg = MISSING
+
+    # Table
+    table = AssetBaseCfg(
+        prim_path="{ENV_REGEX_NS}/Table",
+        init_state=AssetBaseCfg.InitialStateCfg(pos=[0.5, 0, 0], rot=[0.707, 0, 0, 0.707]),
+        spawn=UsdFileCfg(usd_path=f"{ISAAC_NUCLEUS_DIR}/Props/Mounts/SeattleLabTable/table_instanceable.usd"),
+    )
+
+    # plane
+    plane = AssetBaseCfg(
+        prim_path="/World/GroundPlane",
+        init_state=AssetBaseCfg.InitialStateCfg(pos=[0, 0, -1.05]),
+        spawn=GroundPlaneCfg(),
+    )
+
+    # lights
+    light = AssetBaseCfg(
+        prim_path="/World/light",
+        spawn=sim_utils.DomeLightCfg(color=(0.75, 0.75, 0.75), intensity=3000.0),
+    )
+
+
+##
+# MDP settings
+##
+
+
+@configclass
+class CommandsCfg:
+    """Command terms for the MDP."""
+
+    object_pose = mdp.UniformPoseCommandCfg(
+        asset_name="robot",
+        body_name="panda_hand",
+        resampling_time_range=(5.0, 5.0),
+        debug_vis=True,
+        ranges=mdp.UniformPoseCommandCfg.Ranges(
+            pos_x=(0.4, 0.6), pos_y=(-0.25, 0.25), pos_z=(0.25, 0.5), roll=(0.0, 0.0), pitch=(0.0, 0.0), yaw=(0.0, 0.0)
+        ),
+    )
+
+    place_pose = mdp.UniformPoseCommandCfg(
+        asset_name="robot",
+        body_name="panda_hand",
+        resampling_time_range=(5.0, 5.0),
+        debug_vis=True,
+        ranges=mdp.UniformPoseCommandCfg.Ranges(
+            pos_x=(0.4, 0.6), pos_y=(-0.25, 0.25), pos_z=(0.05, 0.05), roll=(0.0, 0.0), pitch=(0.0, 0.0), yaw=(0.0, 0.0)
+        ),
+    )
+
+
+@configclass
+class ActionsCfg:
+    """Action specifications for the MDP."""
+
+    # will be set by agent env cfg
+    arm_action: mdp.JointPositionActionCfg | mdp.DifferentialInverseKinematicsActionCfg = MISSING
+    gripper_action: mdp.BinaryJointPositionActionCfg = MISSING
+
+
+@configclass
+class ObservationsCfg:
+    """Observation specifications for the MDP."""
+
+    @configclass
+    class PolicyCfg(ObsGroup):
+        """Observations for policy group."""
+
+        joint_pos = ObsTerm(func=mdp.joint_pos_rel)
+        joint_vel = ObsTerm(func=mdp.joint_vel_rel)
+        object_position = ObsTerm(func=mdp.object_position_in_robot_root_frame)
+        target_object_position = ObsTerm(func=mdp.generated_commands, params={"command_name": "object_pose"})
+        target_place_position = ObsTerm(func=mdp.generated_commands, params={"command_name": "place_pose"})
+        actions = ObsTerm(func=mdp.last_action)
+
+        def __post_init__(self):
+            self.enable_corruption = True
+            self.concatenate_terms = True
+
+    # observation groups
+    policy: PolicyCfg = PolicyCfg()
+
+
+@configclass
+class EventCfg:
+    """Configuration for events."""
+
+    reset_all = EventTerm(func=mdp.reset_scene_to_default, mode="reset")
+
+    reset_object_position = EventTerm(
+        func=mdp.reset_root_state_uniform,
+        mode="reset",
+        params={
+            "pose_range": {"x": (-0.1, 0.1), "y": (-0.25, 0.25), "z": (0.0, 0.0)},
+            "velocity_range": {},
+            "asset_cfg": SceneEntityCfg("object", body_names="Object"),
+        },
+    )
+
+
+@configclass
+class RewardsCfg:
+    """Reward terms for the MDP."""
+
+    # Reward for reaching the object
+    reaching_object = RewTerm(
+        func=mdp.object_ee_distance,
+        params={"std": 0.1},
+        weight=1.0,
+    )
+
+    # Reward for lifting the object
+    lifting_object = RewTerm(
+        func=mdp.object_is_lifted,
+        params={"minimal_height": 0.04, "maximal_height": 0.5},
+        weight=15.0,
+    )
+
+    # Reward for moving the object to the lifting position
+    object_goal_tracking = RewTerm(
+        func=mdp.object_goal_distance,
+        params={
+            "std": 0.3,
+            "minimal_height": 0.04,
+            "maximal_height": 0.5,
+            "command_name": "object_pose",
+        },
+        weight=16.0,
+    )
+
+    # Reward for moving the object to the placing position
+    placing_tracking = RewTerm(
+        func=mdp.object_goal_distance,
+        params={
+            "std": 0.1,
+            "minimal_height": 0.04,
+            "maximal_height": 0.5,
+            "command_name": "place_pose",
+        },
+        weight=18.0,
+    )
+
+    # Sparse bonus once the object rests within the place-pose thresholds
+    object_placed = RewTerm(
+        func=mdp.object_is_placed,
+        params={
+            "distance_threshold": 0.02,
+            "height_threshold": 0.01,
+        },
+        weight=20.0,
+    )
+
+    # Action penalties; their weights are ramped up by the curriculum terms below
+    action_rate = RewTerm(func=mdp.action_rate_l2, weight=-1e-4)
+
+    joint_vel = RewTerm(
+        func=mdp.joint_vel_l2,
+        weight=-1e-4,
+        params={"asset_cfg": SceneEntityCfg("robot")},
+    )
+
+
+@configclass
+class TerminationsCfg:
+    """Termination terms for the MDP."""
+
+    time_out = DoneTerm(func=mdp.time_out, time_out=True)
+
+    object_dropping = DoneTerm(
+        func=mdp.root_height_below_minimum, params={"minimum_height": -0.05, "asset_cfg": SceneEntityCfg("object")}
+    )
+
+
+@configclass
+class CurriculumCfg:
+    """Curriculum terms for the MDP."""
+
+    action_rate = CurrTerm(
+        func=mdp.modify_reward_weight, params={"term_name": "action_rate", "weight": -1e-1, "num_steps": 10000}
+    )
+
+    joint_vel = CurrTerm(
+        func=mdp.modify_reward_weight, params={"term_name": "joint_vel", "weight": -1e-1, "num_steps": 10000}
+    )
+
+
+##
+# Environment configuration
+##
+
+
+@configclass
+class LiftEnvCfg(ManagerBasedRLEnvCfg):
+    """Configuration for the lift-and-place environment."""
+
+    # Scene settings
+    scene: ObjectTableSceneCfg = ObjectTableSceneCfg(num_envs=4096, env_spacing=2.5)
+    # Basic settings
+    observations: ObservationsCfg = ObservationsCfg()
+    actions: ActionsCfg = ActionsCfg()
+    commands: CommandsCfg = CommandsCfg()
+    # MDP settings
+    rewards: RewardsCfg = RewardsCfg()
+    terminations: TerminationsCfg = TerminationsCfg()
+    events: EventCfg = EventCfg()
+    curriculum: CurriculumCfg = CurriculumCfg()
+
+    def __post_init__(self):
+        """Post initialization."""
+        # general settings
+        self.decimation = 2
+        self.episode_length_s = 5.0
+        # simulation settings
+        self.sim.dt = 0.01  # 100Hz
+        self.sim.render_interval = self.decimation
+
+        self.sim.physx.bounce_threshold_velocity = 0.01
+        self.sim.physx.gpu_found_lost_aggregate_pairs_capacity = 1024 * 1024 * 4
+        self.sim.physx.gpu_total_aggregate_pairs_capacity = 16 * 1024
+        self.sim.physx.friction_correlation_distance = 0.00625
\ No newline at end of file
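Note: the new module reuses the class name `LiftEnvCfg`, and the state-machine script below reuses the existing `Isaac-Lift-Cube-Franka-IK-Abs-v0` task ID, so this diff adds no new gym registration. If a dedicated task ID were wanted, a registration along these lines could be added (a minimal sketch; the task ID below is an assumption, not part of this change):

import gymnasium as gym

gym.register(
    id="Isaac-Lift-Place-Cube-Franka-v0",  # hypothetical ID, not part of this diff
    entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv",
    disable_env_checker=True,
    kwargs={
        # points at the config module added by this diff
        "env_cfg_entry_point": "omni.isaac.lab_tasks.manager_based.manipulation.lift.lift_place_env_cfg:LiftEnvCfg",
    },
)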
diff --git a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/mdp/rewards.py b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/mdp/rewards.py
index 334df9ea50..2c90967718 100644
--- a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/mdp/rewards.py
+++ b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/manipulation/lift/mdp/rewards.py
@@ -18,11 +18,39 @@
 
 
 def object_is_lifted(
-    env: ManagerBasedRLEnv, minimal_height: float, object_cfg: SceneEntityCfg = SceneEntityCfg("object")
+    env: ManagerBasedRLEnv,
+    minimal_height: float,
+    maximal_height: float,
+    object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
 ) -> torch.Tensor:
-    """Reward the agent for lifting the object above the minimal height."""
+    """Reward the agent for lifting the object above the minimal height and below the maximal height."""
     object: RigidObject = env.scene[object_cfg.name]
-    return torch.where(object.data.root_pos_w[:, 2] > minimal_height, 1.0, 0.0)
+    return torch.where(
+        (object.data.root_pos_w[:, 2] > minimal_height) & (object.data.root_pos_w[:, 2] < maximal_height), 1.0, 0.0
+    )
+
+
+def object_is_placed(
+    env: ManagerBasedRLEnv,
+    distance_threshold: float,
+    height_threshold: float,
+    object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
+) -> torch.Tensor:
+    """Reward the agent for placing the object within thresholds of the commanded place pose."""
+    object: RigidObject = env.scene[object_cfg.name]
+    place_command = env.command_manager.get_command("place_pose")
+
+    object_pos = object.data.root_pos_w
+    target_pos = place_command[:, :3]
+
+    # check xy-distance to target
+    xy_distance = torch.norm(object_pos[:, :2] - target_pos[:, :2], dim=1)
+
+    # check height difference
+    height_diff = torch.abs(object_pos[:, 2] - target_pos[:, 2])
+
+    # return 1.0 if within thresholds, 0.0 otherwise
+    return torch.where((xy_distance < distance_threshold) & (height_diff < height_threshold), 1.0, 0.0)
 
 
 def object_ee_distance(
@@ -49,6 +77,7 @@ def object_goal_distance(
     env: ManagerBasedRLEnv,
     std: float,
     minimal_height: float,
+    maximal_height: float,
     command_name: str,
     robot_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
     object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
@@ -64,4 +93,15 @@ def object_goal_distance(
     # distance of the end-effector to the object: (num_envs,)
     distance = torch.norm(des_pos_w - object.data.root_pos_w[:, :3], dim=1)
-    # rewarded if the object is lifted above the threshold
-    return (object.data.root_pos_w[:, 2] > minimal_height) * (1 - torch.tanh(distance / std))
+
+    if command_name == "object_pose":
+        # for lifting: reward tracking only while the object is within the lift band
+        return ((object.data.root_pos_w[:, 2] > minimal_height) & (object.data.root_pos_w[:, 2] < maximal_height)) * (
+            1 - torch.tanh(distance / std)
+        )
+    elif command_name == "place_pose":
+        # for placing: reward getting close to the place position
+        target_height = des_pos_w[:, 2]  # height from the command
+        height_diff = torch.abs(object.data.root_pos_w[:, 2] - target_height)
+        return 1 - torch.tanh((distance + height_diff) / std)
+    else:
+        raise ValueError(f"Unknown command name: {command_name}")
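As a quick sanity check, the `object_is_placed` thresholds above can be exercised with plain tensors (a minimal sketch; the poses and thresholds are made-up toy values):

import torch

# Hand-picked poses: env 0 rests on the target, env 1 is 5 cm off in xy.
object_pos = torch.tensor([[0.5, 0.00, 0.05], [0.5, 0.05, 0.05]])
target_pos = torch.tensor([[0.5, 0.00, 0.05], [0.5, 0.00, 0.05]])

xy_distance = torch.norm(object_pos[:, :2] - target_pos[:, :2], dim=1)
height_diff = torch.abs(object_pos[:, 2] - target_pos[:, 2])
reward = torch.where((xy_distance < 0.02) & (height_diff < 0.01), 1.0, 0.0)
print(reward)  # tensor([1., 0.])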
diff --git a/source/standalone/environments/state_machine/lift_and_place_cube_sm.py b/source/standalone/environments/state_machine/lift_and_place_cube_sm.py
new file mode 100644
index 0000000000..215fd87462
--- /dev/null
+++ b/source/standalone/environments/state_machine/lift_and_place_cube_sm.py
@@ -0,0 +1,353 @@
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""
+Script to run an environment with a pick, lift and place state machine.
+
+The state machine is implemented in the kernel function `infer_state_machine`.
+It uses the `warp` library to run the state machine in parallel on the GPU.
+
+.. code-block:: bash
+
+    ./isaaclab.sh -p source/standalone/environments/state_machine/lift_and_place_cube_sm.py --num_envs 32
+
+"""
+
+"""Launch Omniverse Toolkit first."""
+
+import argparse
+
+from omni.isaac.lab.app import AppLauncher
+
+# add argparse arguments
+parser = argparse.ArgumentParser(description="Pick, lift and place state machine for lift environments.")
+parser.add_argument(
+    "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
+)
+parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
+# append AppLauncher cli args
+AppLauncher.add_app_launcher_args(parser)
+# parse the arguments
+args_cli = parser.parse_args()
+
+# launch omniverse app
+app_launcher = AppLauncher(headless=args_cli.headless)
+simulation_app = app_launcher.app
+
+"""Rest everything follows."""
+
+import gymnasium as gym
+import torch
+from collections.abc import Sequence
+
+import warp as wp
+
+from omni.isaac.lab.assets.rigid_object.rigid_object_data import RigidObjectData
+
+import omni.isaac.lab_tasks  # noqa: F401
+from omni.isaac.lab_tasks.manager_based.manipulation.lift.lift_place_env_cfg import LiftEnvCfg
+from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg
+
+# initialize warp
+wp.init()
+
+
+class GripperState:
+    """States for the gripper."""
+
+    OPEN = wp.constant(1.0)
+    CLOSE = wp.constant(-1.0)
+
+
+class PickSmState:
+    """States for the pick-and-place state machine."""
+
+    REST = wp.constant(0)
+    APPROACH_ABOVE_OBJECT = wp.constant(1)
+    APPROACH_OBJECT = wp.constant(2)
+    GRASP_OBJECT = wp.constant(3)
+    LIFT_OBJECT = wp.constant(4)
+    APPROACH_ABOVE_PLACE_POSITION = wp.constant(5)
+    APPROACH_PLACE_POSITION = wp.constant(6)
+    PLACE_OBJECT = wp.constant(7)
+
+
+class PickSmWaitTime:
+    """Additional wait times (in s) for states before switching."""
+
+    REST = wp.constant(0.2)
+    APPROACH_ABOVE_OBJECT = wp.constant(0.5)
+    APPROACH_OBJECT = wp.constant(0.6)
+    GRASP_OBJECT = wp.constant(0.3)
+    LIFT_OBJECT = wp.constant(1.0)
+    APPROACH_ABOVE_PLACE_POSITION = wp.constant(0.5)
+    APPROACH_PLACE_POSITION = wp.constant(0.6)
+    PLACE_OBJECT = wp.constant(0.3)
+
+
+@wp.kernel
+def infer_state_machine(
+    dt: wp.array(dtype=float),
+    sm_state: wp.array(dtype=int),
+    sm_wait_time: wp.array(dtype=float),
+    ee_pose: wp.array(dtype=wp.transform),
+    object_pose: wp.array(dtype=wp.transform),
+    des_object_pose: wp.array(dtype=wp.transform),
+    place_pose: wp.array(dtype=wp.transform),
+    des_place_pose: wp.array(dtype=wp.transform),
+    des_ee_pose: wp.array(dtype=wp.transform),
+    gripper_state: wp.array(dtype=float),
+    offset: wp.array(dtype=wp.transform),
+):
+    # retrieve thread id
+    tid = wp.tid()
+    # retrieve state machine state
+    state = sm_state[tid]
+    # decide next state
+    if state == PickSmState.REST:
+        des_ee_pose[tid] = ee_pose[tid]
+        gripper_state[tid] = GripperState.OPEN
+        # wait for a while
+        if sm_wait_time[tid] >= PickSmWaitTime.REST:
+            # move to next state and reset wait time
+            sm_state[tid] = PickSmState.APPROACH_ABOVE_OBJECT
+            sm_wait_time[tid] = 0.0
+    elif state == PickSmState.APPROACH_ABOVE_OBJECT:
+        des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
+        gripper_state[tid] = GripperState.OPEN
+        # TODO: error between current and desired ee pose below threshold
+        # wait for a while
+        if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_ABOVE_OBJECT:
+            # move to next state and reset wait time
+            sm_state[tid] = PickSmState.APPROACH_OBJECT
+            sm_wait_time[tid] = 0.0
+    elif state == PickSmState.APPROACH_OBJECT:
+        des_ee_pose[tid] = object_pose[tid]
+        gripper_state[tid] = GripperState.OPEN
+        # TODO: error between current and desired ee pose below threshold
+        # wait for a while
+        if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
+            # move to next state and reset wait time
+            sm_state[tid] = PickSmState.GRASP_OBJECT
+            sm_wait_time[tid] = 0.0
+    elif state == PickSmState.GRASP_OBJECT:
+        des_ee_pose[tid] = object_pose[tid]
+        gripper_state[tid] = GripperState.CLOSE
+        # wait for a while
+        if sm_wait_time[tid] >= PickSmWaitTime.GRASP_OBJECT:
+            # move to next state and reset wait time
+            sm_state[tid] = PickSmState.LIFT_OBJECT
+            sm_wait_time[tid] = 0.0
+    elif state == PickSmState.LIFT_OBJECT:
+        des_ee_pose[tid] = des_object_pose[tid]
+        gripper_state[tid] = GripperState.CLOSE
+        # TODO: error between current and desired ee pose below threshold
+        # wait for a while
+        if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
+            # move to next state and reset wait time
+            sm_state[tid] = PickSmState.APPROACH_ABOVE_PLACE_POSITION
+            sm_wait_time[tid] = 0.0
+    elif state == PickSmState.APPROACH_ABOVE_PLACE_POSITION:
+        des_ee_pose[tid] = wp.transform_multiply(offset[tid], place_pose[tid])
+        gripper_state[tid] = GripperState.CLOSE
+
+        if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_ABOVE_PLACE_POSITION:
+            # move to next state and reset wait time
+            sm_state[tid] = PickSmState.APPROACH_PLACE_POSITION
+            sm_wait_time[tid] = 0.0
+    elif state == PickSmState.APPROACH_PLACE_POSITION:
+        des_ee_pose[tid] = des_place_pose[tid]
+        gripper_state[tid] = GripperState.CLOSE
+        # TODO: error between current and desired ee pose below threshold
+        # wait for a while
+        if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_PLACE_POSITION:
+            # move to next state and reset wait time
+            sm_state[tid] = PickSmState.PLACE_OBJECT
+            sm_wait_time[tid] = 0.0
+    elif state == PickSmState.PLACE_OBJECT:
+        des_ee_pose[tid] = des_place_pose[tid]
+        gripper_state[tid] = GripperState.OPEN
+        # TODO: error between current and desired ee pose below threshold
+        # wait for a while
+        if sm_wait_time[tid] >= PickSmWaitTime.PLACE_OBJECT:
+            # terminal state: stay in PLACE_OBJECT with the gripper open
+            sm_state[tid] = PickSmState.PLACE_OBJECT
+            sm_wait_time[tid] = 0.0
+
+
+    # increment wait time
+    sm_wait_time[tid] = sm_wait_time[tid] + dt[tid]
+
+
+class PickAndLiftSm:
+    """A simple state machine in a robot's task space to pick, lift and place an object.
+
+    The state machine is implemented as a warp kernel. It takes in the current state of
+    the robot's end-effector and the object, and outputs the desired state of the robot's
+    end-effector and the gripper. The state machine is implemented as a finite state
+    machine with the following states:
+
+    1. REST: The robot is at rest.
+    2. APPROACH_ABOVE_OBJECT: The robot moves above the object.
+    3. APPROACH_OBJECT: The robot moves to the object.
+    4. GRASP_OBJECT: The robot grasps the object.
+    5. LIFT_OBJECT: The robot lifts the object to the desired pose.
+    6. APPROACH_ABOVE_PLACE_POSITION: The robot moves above the place position.
+    7. APPROACH_PLACE_POSITION: The robot moves to the place position.
+    8. PLACE_OBJECT: The robot opens the gripper and releases the object. This is the final state.
+    """
+
+    def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
+        """Initialize the state machine.
+
+        Args:
+            dt: The environment time step.
+            num_envs: The number of environments to simulate.
+            device: The device to run the state machine on.
+        """
+        # save parameters
+        self.dt = float(dt)
+        self.num_envs = num_envs
+        self.device = device
+        # initialize state machine
+        self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
+        self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
+        self.sm_wait_time = torch.zeros((self.num_envs,), device=self.device)
+
+        # desired state
+        self.des_ee_pose = torch.zeros((self.num_envs, 7), device=self.device)
+        self.des_gripper_state = torch.full((self.num_envs,), 0.0, device=self.device)
+
+        # approach above object offset
+        self.offset = torch.zeros((self.num_envs, 7), device=self.device)
+        self.offset[:, 2] = 0.1
+        self.offset[:, -1] = 1.0  # warp expects quaternion as (x, y, z, w)
+
+        # convert to warp
+        self.sm_dt_wp = wp.from_torch(self.sm_dt, wp.float32)
+        self.sm_state_wp = wp.from_torch(self.sm_state, wp.int32)
+        self.sm_wait_time_wp = wp.from_torch(self.sm_wait_time, wp.float32)
+        self.des_ee_pose_wp = wp.from_torch(self.des_ee_pose, wp.transform)
+        self.des_gripper_state_wp = wp.from_torch(self.des_gripper_state, wp.float32)
+        self.offset_wp = wp.from_torch(self.offset, wp.transform)
+
+    def reset_idx(self, env_ids: Sequence[int] | None = None):
+        """Reset the state machine for the given environment indices."""
+        if env_ids is None:
+            env_ids = slice(None)
+        self.sm_state[env_ids] = 0
+        self.sm_wait_time[env_ids] = 0.0
+
+    def compute(
+        self,
+        ee_pose: torch.Tensor,
+        object_pose: torch.Tensor,
+        des_object_pose: torch.Tensor,
+        place_pose: torch.Tensor,
+        des_place_pose: torch.Tensor,
+    ):
+        """Compute the desired state of the robot's end-effector and the gripper."""
+        # convert all transformations from (w, x, y, z) to (x, y, z, w)
+        ee_pose = ee_pose[:, [0, 1, 2, 4, 5, 6, 3]]
+        object_pose = object_pose[:, [0, 1, 2, 4, 5, 6, 3]]
+        des_object_pose = des_object_pose[:, [0, 1, 2, 4, 5, 6, 3]]
+        place_pose = place_pose[:, [0, 1, 2, 4, 5, 6, 3]]
+        des_place_pose = des_place_pose[:, [0, 1, 2, 4, 5, 6, 3]]
+
+        # convert to warp
+        ee_pose_wp = wp.from_torch(ee_pose.contiguous(), wp.transform)
+        object_pose_wp = wp.from_torch(object_pose.contiguous(), wp.transform)
+        des_object_pose_wp = wp.from_torch(des_object_pose.contiguous(), wp.transform)
+        place_pose_wp = wp.from_torch(place_pose.contiguous(), wp.transform)
+        des_place_pose_wp = wp.from_torch(des_place_pose.contiguous(), wp.transform)
+
+        # run state machine
+        wp.launch(
+            kernel=infer_state_machine,
+            dim=self.num_envs,
+            inputs=[
+                self.sm_dt_wp,
+                self.sm_state_wp,
+                self.sm_wait_time_wp,
+                ee_pose_wp,
+                object_pose_wp,
+                des_object_pose_wp,
+                place_pose_wp,
+                des_place_pose_wp,
+                self.des_ee_pose_wp,
+                self.des_gripper_state_wp,
+                self.offset_wp,
+            ],
+            device=self.device,
+        )
+
+        # convert transformations back to (w, x, y, z)
+        des_ee_pose = self.des_ee_pose[:, [0, 1, 2, 6, 3, 4, 5]]
+        # convert to torch
+        return torch.cat([des_ee_pose, self.des_gripper_state.unsqueeze(-1)], dim=-1)
+
+
+def main():
+    # parse configuration
+    env_cfg: LiftEnvCfg = parse_env_cfg(
+        "Isaac-Lift-Cube-Franka-IK-Abs-v0",
+        device=args_cli.device,
+        num_envs=args_cli.num_envs,
+        use_fabric=not args_cli.disable_fabric,
+    )
+    # create environment
+    env = gym.make("Isaac-Lift-Cube-Franka-IK-Abs-v0", cfg=env_cfg)
+    # reset environment at start
+    env.reset()
+
+    # create action buffers (position + quaternion)
+    actions = torch.zeros(env.unwrapped.action_space.shape, device=env.unwrapped.device)
+    actions[:, 4] = 1.0
+    # desired object orientation (we only do position control of object)
+    desired_orientation = torch.zeros((env.unwrapped.num_envs, 4),
device=env.unwrapped.device) + desired_orientation[:, 1] = 1.0 + # create state machine + pick_sm = PickAndLiftSm(env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device) + + while simulation_app.is_running(): + # run everything in inference mode + with torch.inference_mode(): + # step environment + dones = env.step(actions)[-2] + + # observations + # -- end-effector frame + ee_frame_sensor = env.unwrapped.scene["ee_frame"] + tcp_rest_position = ee_frame_sensor.data.target_pos_w[..., 0, :].clone() - env.unwrapped.scene.env_origins + tcp_rest_orientation = ee_frame_sensor.data.target_quat_w[..., 0, :].clone() + # -- object frame + object_data: RigidObjectData = env.unwrapped.scene["object"].data + object_position = object_data.root_pos_w - env.unwrapped.scene.env_origins + # -- target object frame + desired_position = env.unwrapped.command_manager.get_command("object_pose")[..., :3] + desired_place_position = env.unwrapped.command_manager.get_command("place_pose")[..., :3] + + # advance state machine + actions = pick_sm.compute( + torch.cat([tcp_rest_position, tcp_rest_orientation], dim=-1), + torch.cat([object_position, desired_orientation], dim=-1), + torch.cat([desired_position, desired_orientation], dim=-1), + torch.cat([desired_place_position, desired_orientation], dim=-1), + torch.cat([desired_place_position, desired_orientation], dim=-1), + ) + + # reset state machine + if dones.any(): + pick_sm.reset_idx(dones.nonzero(as_tuple=False).squeeze(-1)) + + # close the environment + env.close() + + +if __name__ == "__main__": + # run the main function + main() + # close sim app + simulation_app.close() \ No newline at end of file diff --git a/source/standalone/environments/state_machine/lift_cube_sm.py b/source/standalone/environments/state_machine/lift_cube_sm.py index bd14dcaf0d..f5fdb280c4 100644 --- a/source/standalone/environments/state_machine/lift_cube_sm.py +++ b/source/standalone/environments/state_machine/lift_cube_sm.py @@ -47,7 +47,7 @@ from omni.isaac.lab.assets.rigid_object.rigid_object_data import RigidObjectData import omni.isaac.lab_tasks # noqa: F401 -from omni.isaac.lab_tasks.manager_based.manipulation.lift.lift_env_cfg import LiftEnvCfg +from omni.isaac.lab_tasks.manager_based.manipulation.lift.lift_place_env_cfg import LiftEnvCfg from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg # initialize warp @@ -295,4 +295,4 @@ def main(): # run the main function main() # close sim app - simulation_app.close() + simulation_app.close() \ No newline at end of file
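One detail worth keeping in mind when reading `PickAndLiftSm.compute` above: Isaac Lab carries pose quaternions as (w, x, y, z), while warp transforms expect (x, y, z, w), which is exactly what the index shuffles implement. A standalone round-trip check of that reordering (the toy pose values are made up):

import torch

pose_wxyz = torch.tensor([[0.4, 0.0, 0.3, 1.0, 0.0, 0.0, 0.0]])  # position + identity quat (w first)
pose_xyzw = pose_wxyz[:, [0, 1, 2, 4, 5, 6, 3]]  # reorder to warp's (x, y, z, w)
back_wxyz = pose_xyzw[:, [0, 1, 2, 6, 3, 4, 5]]  # reorder back to (w, x, y, z)
assert torch.equal(back_wxyz, pose_wxyz)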