@@ -25,6 +25,36 @@ def observable_indices_in_tensor(
25
25
return sorted_obs_dict
26
26
27
27
28
+ def wing_qpos_to_conventional (model_wing_qpos : np .ndarray ,
29
+ body_pitch_angle : float = 47.5 ,
30
+ ) -> np .ndarray :
31
+ """Transform model wing joint qpos to conventional wing kinematics definition.
32
+
33
+ Args:
34
+ model_wing_qpos: Wing MjData.qpos in radians, shape (B, 6).
35
+ Order of joints: yaw, roll, pitch, yaw, roll, pitch.
36
+ Left-right order is arbitrary.
37
+ body_pitch_angle: Body pitch angle for initial flight pose, relative to
38
+ ground, degrees. 0: horizontal body position. Default value from
39
+ https://doi.org/10.1126/science.1248955
40
+
41
+ Returns:
42
+ Wing angles transformed to conventional representation.
43
+ """
44
+ if not isinstance (model_wing_qpos , np .ndarray ):
45
+ model_wing_qpos = np .array (model_wing_qpos )
46
+ conventional = np .zeros_like (model_wing_qpos )
47
+ body_pitch_angle = np .deg2rad (body_pitch_angle )
48
+ # Yaw, doesn't require transformation.
49
+ conventional [..., [0 , 3 ]] = model_wing_qpos [..., [0 , 3 ]].copy ()
50
+ # Roll.
51
+ conventional [..., [1 , 4 ]] = - model_wing_qpos [..., [1 , 4 ]]
52
+ # Pitch.
53
+ conventional [..., [2 , 5 ]] = (
54
+ np .pi / 2 - body_pitch_angle - model_wing_qpos [..., [2 , 5 ]])
55
+ return conventional
56
+
57
+
28
58
def get_random_policy (action_spec : 'dm_env.specs.BoundedArray' ,
29
59
minimum : float = - 0.2 ,
30
60
maximum : float = 0.2 ) -> Callable [[Any ], np .ndarray ]:
0 commit comments