diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 52e723135..d83eb9496 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -934,7 +934,7 @@ class HwLinkMetadata(JaxsimDataclass): L_H_pre: The homogeneous transforms for child joints. """ - shape: jtp.Vector + _shape: Static[tuple[int]] dims: jtp.Vector density: jtp.Float L_H_G: jtp.Matrix @@ -942,8 +942,16 @@ class HwLinkMetadata(JaxsimDataclass): L_H_pre_mask: jtp.Vector L_H_pre: jtp.Matrix + @property + def shape(self) -> int: + """ + Return the shape of the link. + """ + return np.array(self._shape) + @staticmethod def compute_mass_and_inertia( + shape_types: jtp.Array, hw_link_metadata: HwLinkMetadata, ) -> tuple[jtp.Float, jtp.Matrix]: """ @@ -954,6 +962,7 @@ def compute_mass_and_inertia( by using shape-specific methods. Args: + shape_types: The shape types of the link (e.g., box, sphere, cylinder). hw_link_metadata: Metadata describing the hardware link, including its shape, dimensions, and density. @@ -963,68 +972,128 @@ def compute_mass_and_inertia( - inertia: The computed inertia tensor of the hardware link. """ - mass, inertia = jax.lax.switch( - hw_link_metadata.shape, - [ - HwLinkMetadata._box, - HwLinkMetadata._cylinder, - HwLinkMetadata._sphere, - ], + def box(dims, density) -> tuple[jtp.Float, jtp.Matrix]: + lx, ly, lz = dims + + mass = density * lx * ly * lz + + inertia = jnp.array( + [ + [mass * (ly**2 + lz**2) / 12, 0, 0], + [0, mass * (lx**2 + lz**2) / 12, 0], + [0, 0, mass * (lx**2 + ly**2) / 12], + ] + ) + return mass, inertia + + def cylinder(dims, density) -> tuple[jtp.Float, jtp.Matrix]: + r, l, _ = dims + + mass = density * (jnp.pi * r**2 * l) + + inertia = jnp.array( + [ + [mass * (3 * r**2 + l**2) / 12, 0, 0], + [0, mass * (3 * r**2 + l**2) / 12, 0], + [0, 0, mass * (r**2) / 2], + ] + ) + + return mass, inertia + + def sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]: + r = dims[0] + + mass = density * (4 / 3 * jnp.pi * r**3) + + inertia = jnp.eye(3) * (2 / 5 * mass * r**2) + + return mass, inertia + + def compute_mass_inertia(shape_idx, dims, density): + return jax.lax.switch(shape_idx, (box, cylinder, sphere), dims, density) + + mass, inertia = jax.vmap(compute_mass_inertia)( + jnp.array(shape_types), hw_link_metadata.dims, hw_link_metadata.density, ) + return mass, inertia @staticmethod - def _box(dims, density) -> tuple[jtp.Float, jtp.Matrix]: - lx, ly, lz = dims - - mass = density * lx * ly * lz + def compute_contact_points( + original_contact_params: jtp.Vector, + shape_types: jtp.Vector, + original_com_positions: jtp.Vector, + updated_com_positions: jtp.Vector, + scaling_factors: ScalingFactors, + ) -> jtp.Matrix: + """ + Compute the new contact points based on the original contact parameters and + the scaling factors. - inertia = jnp.array( - [ - [mass * (ly**2 + lz**2) / 12, 0, 0], - [0, mass * (lx**2 + lz**2) / 12, 0], - [0, 0, mass * (lx**2 + ly**2) / 12], - ] - ) - return mass, inertia + Args: + original_contact_params: The original contact parameters. + shape_types: The shape types of the links (e.g., box, sphere, cylinder). + original_com_positions: The original center of mass positions of the links. + updated_com_positions: The updated center of mass positions of the links. + scaling_factors: The scaling factors for the link dimensions. - @staticmethod - def _cylinder(dims, density) -> tuple[jtp.Float, jtp.Matrix]: - r, l, _ = dims + Returns: + The new contact points positions in the parent link frame. + """ - mass = density * (jnp.pi * r**2 * l) + parent_link_indices = np.array(original_contact_params.body) - inertia = jnp.array( - [ - [mass * (3 * r**2 + l**2) / 12, 0, 0], - [0, mass * (3 * r**2 + l**2) / 12, 0], - [0, 0, mass * (r**2) / 2], - ] + # Translate the original contact point positions in the origin, so + # that we can apply the scaling factors. + L_p_Ci = ( + original_contact_params.point - original_com_positions[parent_link_indices] ) - return mass, inertia + # Extract the shape types of the parent links. + parent_shape_types = jnp.array(shape_types[parent_link_indices]) - @staticmethod - def _sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]: - r = dims[0] + def sphere(parent_idx, L_p_C): + r = scaling_factors.dims[parent_idx][0] + return L_p_C * r - mass = density * (4 / 3 * jnp.pi * r**3) + def cylinder(parent_idx, L_p_C): + # Cylinder collisions are not supported in JaxSim. + return L_p_C + + def box(parent_idx, L_p_C): + lx, ly, lz = scaling_factors.dims[parent_idx] + return jnp.hstack( + [ + L_p_C[0] * lx, + L_p_C[1] * ly, + L_p_C[2] * lz, + ] + ) - inertia = jnp.eye(3) * (2 / 5 * mass * r**2) + new_positions = jax.vmap( + lambda shape_idx, parent_idx, L_p_C: jax.lax.switch( + shape_idx, (box, cylinder, sphere), parent_idx, L_p_C + ) + )( + parent_shape_types, + parent_link_indices, + L_p_Ci, + ) - return mass, inertia + return new_positions + updated_com_positions[parent_link_indices] @staticmethod def _convert_scaling_to_3d_vector( - shape: jtp.Int, scaling_factors: jtp.Vector + shape_types: jtp.Int, scaling_factors: jtp.Vector ) -> jtp.Vector: """ Convert scaling factors for specific shape dimensions into a 3D scaling vector. Args: - shape: The shape of the link (e.g., box, sphere, cylinder). + shape_types: The shape_types of the link (e.g., box, sphere, cylinder). scaling_factors: The scaling factors for the shape dimensions. Returns: @@ -1036,38 +1105,24 @@ def _convert_scaling_to_3d_vector( - Cylinder: [r, r, l] - Sphere: [r, r, r] """ - return jax.lax.switch( - shape, - branches=[ - # Box - lambda: jnp.array( - [ - scaling_factors[0], - scaling_factors[1], - scaling_factors[2], - ] - ), - # Cylinder - lambda: jnp.array( - [ - scaling_factors[0], - scaling_factors[0], - scaling_factors[1], - ] - ), - # Sphere - lambda: jnp.array( - [ - scaling_factors[0], - scaling_factors[0], - scaling_factors[0], - ] - ), - ], + + # Index mapping for each shape type (shape_type x 3 dims) + shape_indices = jnp.array( + [ + [0, 1, 2], # Box + [0, 0, 1], # Cylinder + [0, 0, 0], # Sphere + ] ) + # For each link, get the index vector for its shape + per_link_indices = shape_indices[shape_types] + + # Gather dims per link according to per_link_indices + return jnp.take_along_axis(scaling_factors.dims, per_link_indices, axis=1) + @staticmethod - def compute_inertia_link(I_com, mass, L_H_G) -> jtp.Matrix: + def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix: """ Compute the inertia tensor of the link based on its shape and mass. """ @@ -1077,12 +1132,17 @@ def compute_inertia_link(I_com, mass, L_H_G) -> jtp.Matrix: @staticmethod def apply_scaling( - hw_metadata: HwLinkMetadata, scaling_factors: ScalingFactors + has_joints: bool, + scale_vector: jtp.Vector, + hw_metadata: HwLinkMetadata, + scaling_factors: ScalingFactors, ) -> HwLinkMetadata: """ Apply scaling to the hardware parameters and return a new HwLinkMetadata object. Args: + has_joints: A boolean indicating if the model has joints. + scale_vector: The scaling vector to apply. hw_metadata: the original HwLinkMetadata object. scaling_factors: the scaling factors to apply. @@ -1090,83 +1150,73 @@ def apply_scaling( A new HwLinkMetadata object with updated parameters. """ - # ================================== - # Handle unsupported links - # ================================== - def unsupported_case(hw_metadata, scaling_factors): - # Return the metadata unchanged for unsupported links - return hw_metadata - - def supported_case(hw_metadata, scaling_factors): - # ================================== - # Update the kinematics of the link - # ================================== - - # Get the nominal transforms - L_H_G = hw_metadata.L_H_G - L_H_vis = hw_metadata.L_H_vis - L_H_pre_array = hw_metadata.L_H_pre - L_H_pre_mask = hw_metadata.L_H_pre_mask - - # Compute the 3D scaling vector - scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( - hw_metadata.shape, scaling_factors.dims - ) + # ================================= + # Update the kinematics of the link + # ================================= - # Express the transforms in the G frame - G_H_L = jaxsim.math.Transform.inverse(L_H_G) - G_H_vis = G_H_L @ L_H_vis - G_H_pre_array = jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array) - - # Apply the scaling to the position vectors - G_H̅_L = G_H_L.at[:3, 3].set(scale_vector * G_H_L[:3, 3]) - G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3]) - # Apply scaling to the position vectors in G_H_pre_array based on the mask - G_H̅_pre_array = jax.vmap( - lambda G_H_pre, mask: jnp.where( - # Expand mask for broadcasting - mask[..., None, None], - # Apply scaling - G_H_pre.at[:3, 3].set(scale_vector * G_H_pre[:3, 3]), - # Keep unchanged if mask is False - G_H_pre, - ) - )(G_H_pre_array, L_H_pre_mask) + # Get the nominal transforms + L_H_G = hw_metadata.L_H_G + L_H_vis = hw_metadata.L_H_vis + L_H_pre_array = hw_metadata.L_H_pre + L_H_pre_mask = hw_metadata.L_H_pre_mask - # Get back to the link frame - L_H̅_G = jaxsim.math.Transform.inverse(G_H̅_L) - L_H̅_vis = L_H̅_G @ G_H̅_vis - L_H̅_pre_array = jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array) + # Express the transforms in the G frame + G_H_L = jaxsim.math.Transform.inverse(L_H_G) + G_H_vis = G_H_L @ L_H_vis - # ============================ - # Update the shape parameters - # ============================ + G_H_pre_array = ( + jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array) + if has_joints + else L_H_pre_array + ) - updated_dims = hw_metadata.dims * scaling_factors.dims + # Apply the scaling to the position vectors + G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3]) - # ============================== - # Scale the density of the link - # ============================== + # Apply scaling to the position vectors in G_H_pre_array based on the mask + G_H̅_pre_array = ( + G_H_pre_array.at[:, :3, 3].set( + jnp.where( + L_H_pre_mask[:, None], + scale_vector[None, :] * G_H_pre_array[:, :3, 3], + G_H_pre_array[:, :3, 3], + ) + ) + if has_joints + else G_H_pre_array + ) - updated_density = hw_metadata.density * scaling_factors.density + # Get back to the link frame + L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3]) + L_H̅_vis = L_H̅_G @ G_H̅_vis + L_H̅_pre_array = ( + jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array) + if has_joints + else G_H̅_pre_array + ) - # ============================ - # Return updated HwLinkMetadata - # ============================ + # =========================== + # Update the shape parameters + # =========================== - return hw_metadata.replace( - dims=updated_dims, - density=updated_density, - L_H_G=L_H̅_G, - L_H_vis=L_H̅_vis, - L_H_pre=L_H̅_pre_array, - ) + updated_dims = hw_metadata.dims * scaling_factors.dims + + # ============================= + # Scale the density of the link + # ============================= + + updated_density = hw_metadata.density * scaling_factors.density + + # ============================= + # Return updated HwLinkMetadata + # ============================= - # Use jax.lax.cond to handle unsupported links - return jax.lax.cond( - hw_metadata.shape == LinkParametrizableShape.Unsupported, - lambda: unsupported_case(hw_metadata, scaling_factors), - lambda: supported_case(hw_metadata, scaling_factors), + return hw_metadata.replace( + dims=updated_dims, + density=updated_density, + L_H_G=L_H̅_G, + L_H_vis=L_H̅_vis, + L_H_pre=L_H̅_pre_array, ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index ea513fdbe..c5dec0197 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -10,13 +10,12 @@ import jax import jax.numpy as jnp import jax_dataclasses +import numpy as np import rod -import rod.urdf from jax_dataclasses import Static from rod.urdf.exporter import UrdfExporter import jaxsim.api as js -import jaxsim.exceptions import jaxsim.terrain import jaxsim.typing as jtp from jaxsim import logging @@ -472,7 +471,7 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: L_H_pre_masks.append( [ int(joint_index in child_joints_indices) - for joint_index in range(0, self.number_of_joints()) + for joint_index in range(self.number_of_joints()) ] ) L_H_pre.append( @@ -482,13 +481,13 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: if joint_index in child_joints_indices else jnp.eye(4) ) - for joint_index in range(0, self.number_of_joints()) + for joint_index in range(self.number_of_joints()) ] ) # Stack collected data into JAX arrays return HwLinkMetadata( - shape=jnp.array(shapes, dtype=int), + _shape=shapes, dims=jnp.array(dims, dtype=float), density=jnp.array(densities, dtype=float), L_H_G=jnp.array(L_H_Gs, dtype=float), @@ -508,8 +507,6 @@ def export_updated_model(self) -> str: This method is not meant to be used in JIT-compiled functions. """ - import numpy as np - if isinstance(jnp.zeros(0), jax.core.Tracer): raise RuntimeError("This method cannot be used in JIT-compiled functions") @@ -565,11 +562,16 @@ def export_updated_model(self) -> str: dims = hw_metadata.dims[link_index] if shape == LinkParametrizableShape.Box: links_dict[link_name].visual.geometry.box.size = dims.tolist() + links_dict[link_name].collision.geometry.box.size = dims.tolist() elif shape == LinkParametrizableShape.Sphere: links_dict[link_name].visual.geometry.sphere.radius = float(dims[0]) + links_dict[link_name].collision.geometry.sphere.radius = float(dims[0]) elif shape == LinkParametrizableShape.Cylinder: links_dict[link_name].visual.geometry.cylinder.radius = float(dims[0]) links_dict[link_name].visual.geometry.cylinder.length = float(dims[1]) + links_dict[link_name].collision.geometry.cylinder.radius = float( + dims[0] + ) else: logging.debug(f"Skipping unsupported shape for link '{link_name}'") continue @@ -2356,14 +2358,48 @@ def update_hw_parameters( link_parameters: LinkParameters = kin_dyn_params.link_parameters hw_link_metadata: HwLinkMetadata = kin_dyn_params.hw_link_metadata + has_joints = model.number_of_joints() > 0 + + supported_mask = hw_link_metadata.shape != LinkParametrizableShape.Unsupported + + supported_metadata = jax.tree.map(lambda l: l[supported_mask], hw_link_metadata) + + supported_scaling_factors = jax.tree.map( + lambda l: l[supported_mask], scaling_factors + ) + + scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( + supported_metadata.shape, supported_scaling_factors + ) + # Apply scaling to hw_link_metadata using vmap - updated_hw_link_metadata = jax.vmap(HwLinkMetadata.apply_scaling)( - hw_link_metadata, scaling_factors + scaled_hw_link_metadata_supported = jax.vmap( + HwLinkMetadata.apply_scaling, in_axes=(None,) + )( + has_joints, + scale_vector=scale_vector, + hw_metadata=supported_metadata, + scaling_factors=scaling_factors, + ) + + # Helper function to merge pytrees leaf-wise with boolean mask + def merge_pytree_by_mask(scaled_pytree, original_pytree, mask): + + def merge_leaf(scaled_leaf, original_leaf): + mask_shape = (mask.shape[0],) + (1,) * (scaled_leaf.ndim - 1) + mask_broadcasted = mask.reshape(mask_shape) + + return jnp.where(mask_broadcasted, scaled_leaf, original_leaf) + + return jax.tree.map(merge_leaf, scaled_pytree, original_pytree) + + updated_hw_link_metadata = merge_pytree_by_mask( + scaled_hw_link_metadata_supported, hw_link_metadata, supported_mask ) # Compute mass and inertia once and unpack the results - m_updated, I_com_updated = jax.vmap(HwLinkMetadata.compute_mass_and_inertia)( - updated_hw_link_metadata + m_updated, I_com_updated = HwLinkMetadata.compute_mass_and_inertia( + hw_link_metadata.shape, updated_hw_link_metadata ) # Rotate the inertia tensor at CoM with the link orientation, and store @@ -2383,46 +2419,61 @@ def update_hw_parameters( ), ) + # Compute the contact parameters + points = HwLinkMetadata.compute_contact_points( + original_contact_params=kin_dyn_params.contact_parameters, + shape_types=updated_hw_link_metadata.shape, + original_com_positions=link_parameters.center_of_mass, + updated_com_positions=updated_link_parameters.center_of_mass, + scaling_factors=scaling_factors, + ) + + # Update contact parameters + updated_contact_parameters = kin_dyn_params.contact_parameters.replace(point=points) + # Update joint model transforms (λ_H_pre) def update_λ_H_pre(joint_index): # Extract the transforms and masks for the current joint index across all links L_H_pre_for_joint = updated_hw_link_metadata.L_H_pre[:, joint_index] L_H_pre_mask_for_joint = updated_hw_link_metadata.L_H_pre_mask[:, joint_index] - # Use the mask to select the first valid transform or fall back to the original - valid_transforms = jnp.where( - L_H_pre_mask_for_joint[:, None, None], # Expand mask for broadcasting - L_H_pre_for_joint, # Use the transform if the mask is True - jnp.zeros_like(L_H_pre_for_joint), # Otherwise, use a zero matrix - ) + # Select the first valid transform (if any) using the mask + first_valid_index = jnp.argmax(L_H_pre_mask_for_joint) + selected_transform = L_H_pre_for_joint[first_valid_index] + + # Check if any valid transform exists + has_valid_transform = L_H_pre_mask_for_joint.any() + + # Fallback to the original λ_H_pre if no valid transform exists + fallback_transform = kin_dyn_params.joint_model.λ_H_pre[joint_index + 1] - # Sum the valid transforms (only one will be non-zero due to the mask) - selected_transform = jnp.sum(valid_transforms, axis=0) + # Return the selected transform or fallback + return jnp.where(has_valid_transform, selected_transform, fallback_transform) - # If no valid transform exists, fall back to the original λ_H_pre - return jax.lax.cond( - jnp.any(L_H_pre_mask_for_joint), - lambda: selected_transform, - lambda: kin_dyn_params.joint_model.λ_H_pre[joint_index + 1], + if has_joints: + # Apply the update function to all joint indices + updated_λ_H_pre = jax.vmap(update_λ_H_pre)( + jnp.arange(kin_dyn_params.number_of_joints()) ) - # Apply the update function to all joint indices - updated_λ_H_pre = jax.vmap(update_λ_H_pre)( - jnp.arange(kin_dyn_params.number_of_joints()) - ) - # NOTE: λ_H_pre should be of len (1+n_joints) with the 0-th element equal - # to identity to represent the world-to-base tree transform. See JointModel class - updated_λ_H_pre_with_base = jnp.concatenate( - (jnp.eye(4).reshape(1, 4, 4), updated_λ_H_pre), axis=0 - ) - # Replace the joint model with the updated transforms - updated_joint_model = kin_dyn_params.joint_model.replace( - λ_H_pre=updated_λ_H_pre_with_base - ) + # NOTE: λ_H_pre should be of len (1+n_joints) with the 0-th element equal + # to identity to represent the world-to-base tree transform. See JointModel class + updated_λ_H_pre_with_base = jnp.concatenate( + (jnp.eye(4).reshape(1, 4, 4), updated_λ_H_pre), axis=0 + ) + + # Replace the joint model with the updated transforms + updated_joint_model = kin_dyn_params.joint_model.replace( + λ_H_pre=updated_λ_H_pre_with_base + ) + else: + # If there are no joints, we can just use the identity transform + updated_joint_model = kin_dyn_params.joint_model # Replace the kin_dyn_parameters with updated values updated_kin_dyn_params = kin_dyn_params.replace( link_parameters=updated_link_parameters, + contact_parameters=updated_contact_parameters, hw_link_metadata=updated_hw_link_metadata, joint_model=updated_joint_model, ) diff --git a/src/jaxsim/mujoco/__init__.py b/src/jaxsim/mujoco/__init__.py index 6019903fc..36cbea1e7 100644 --- a/src/jaxsim/mujoco/__init__.py +++ b/src/jaxsim/mujoco/__init__.py @@ -1,4 +1,4 @@ from .loaders import ModelToMjcf, RodModelToMjcf, SdfToMjcf, UrdfToMjcf from .model import MujocoModelHelper -from .utils import mujoco_data_from_jaxsim +from .utils import MujocoCamera, mujoco_data_from_jaxsim from .visualizer import MujocoVideoRecorder, MujocoVisualizer diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 2e4b4eb9b..96eca65da 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -1,5 +1,5 @@ from collections.abc import Hashable -from typing import Any, TypeVar +from typing import Any, NewType, TypeVar import jax @@ -16,13 +16,14 @@ Bool = Scalar Float = Scalar -PyTree: object = ( +PyTree = NewType( + "PyTree", dict[Hashable, TypeVar("PyTree")] | list[TypeVar("PyTree")] | tuple[TypeVar("PyTree")] | jax.Array | Any - | None + | None, ) # ======================= diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py index 5cdb1243c..df1ea2cf3 100644 --- a/tests/test_api_model_hw_parametrization.py +++ b/tests/test_api_model_hw_parametrization.py @@ -95,38 +95,34 @@ def test_model_scaling_against_rod( ) # Compare hardware parameters of the scaled JaxSim model with the pre-scaled JaxSim model - for link_idx, link_name in enumerate(jaxsim_model_garpez.link_names()): - scaled_metadata = jax.tree_util.tree_map( - lambda x, link_idx=link_idx: x[link_idx], - updated_model.kin_dyn_parameters.hw_link_metadata, - ) - pre_scaled_metadata = jax.tree_util.tree_map( - lambda x, link_idx=link_idx: x[link_idx], - jaxsim_model_garpez_scaled.kin_dyn_parameters.hw_link_metadata, - ) + scaled_metadata = updated_model.kin_dyn_parameters.hw_link_metadata - # Compare shape dimensions - assert jnp.allclose(scaled_metadata.dims, pre_scaled_metadata.dims, atol=1e-6) + pre_scaled_metadata = jaxsim_model_garpez_scaled.kin_dyn_parameters.hw_link_metadata - # Compare mass - scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia(scaled_metadata) - pre_scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia( - pre_scaled_metadata - ) - assert scaled_mass == pytest.approx(pre_scaled_mass, abs=1e-6) + # Compare shape dimensions + assert jnp.allclose(scaled_metadata.dims, pre_scaled_metadata.dims, atol=1e-6) - # Compare inertia tensors - _, scaled_inertia = HwLinkMetadata.compute_mass_and_inertia(scaled_metadata) - _, pre_scaled_inertia = HwLinkMetadata.compute_mass_and_inertia( - pre_scaled_metadata - ) - assert jnp.allclose(scaled_inertia, pre_scaled_inertia, atol=1e-6) + # Compare mass + scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia( + scaled_metadata.shape, scaled_metadata + ) + pre_scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia( + pre_scaled_metadata.shape, pre_scaled_metadata + ) + assert scaled_mass == pytest.approx(pre_scaled_mass, abs=1e-6) - # Compare transformations - assert jnp.allclose(scaled_metadata.L_H_G, pre_scaled_metadata.L_H_G, atol=1e-6) - assert jnp.allclose( - scaled_metadata.L_H_vis, pre_scaled_metadata.L_H_vis, atol=1e-6 - ) + # Compare inertia tensors + _, scaled_inertia = HwLinkMetadata.compute_mass_and_inertia( + scaled_metadata.shape, scaled_metadata + ) + _, pre_scaled_inertia = HwLinkMetadata.compute_mass_and_inertia( + pre_scaled_metadata.shape, pre_scaled_metadata + ) + assert jnp.allclose(scaled_inertia, pre_scaled_inertia, atol=1e-6) + + # Compare transformations + assert jnp.allclose(scaled_metadata.L_H_G, pre_scaled_metadata.L_H_G, atol=1e-6) + assert jnp.allclose(scaled_metadata.L_H_vis, pre_scaled_metadata.L_H_vis, atol=1e-6) def test_update_hw_parameters_vmap( @@ -325,10 +321,7 @@ def test_hw_parameters_optimization(jaxsim_model_garpez: js.model.JaxSimModel): # Define the initial hardware parameters (scaling factors). initial_dims = jnp.ones( - ( - model.number_of_links(), - 3, - ) + (model.number_of_links(), 3) ) # Initial dimensions (1.0 for all links). initial_density = jnp.ones( (model.number_of_links(),) @@ -378,3 +371,85 @@ def loss(scaling_factors): # Assert that the final loss is close to zero. assert current_loss < 1e-3, "Optimization did not converge to the target height." + + +def test_hw_parameters_collision_scaling( + jaxsim_model_box: js.model.JaxSimModel, prng_key: jax.Array +): + """ + Test that the collision elements of the model are updated correctly during the scaling of the model hw parameters. + """ + + _, subkey = jax.random.split(prng_key, num=2) + + # TODO: the jaxsim_model_box has an additional frame, which is handled wrongly + # during the export of the updated model. For this reason, we recreate the model + # from scratch here. + del jaxsim_model_box + + import rod.builder.primitives + + # Create on-the-fly a ROD model of a box. + rod_model = ( + rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box") + .build_model() + .add_link(name="box_link") + .add_inertial() + .add_visual() + .add_collision() + .build() + ) + + model = js.model.JaxSimModel.build_from_model_description( + model_description=rod_model + ) + + # Define the scaling factor for the sphere's radius + scaling_factor = 5.0 + + # Define the nominal radius of the sphere + nominal_height = model.kin_dyn_parameters.hw_link_metadata.dims[0, 2] + + # Define scaling parameters + scaling_parameters = ScalingFactors( + dims=jnp.ones((model.number_of_links(), 3)) * scaling_factor, + density=jnp.array([1.0]), + ) + + # Update the model with the scaling parameters + updated_model = js.model.update_hw_parameters(model, scaling_parameters) + + # Simulate the box falling under gravity + data = js.data.JaxSimModelData.build( + model=updated_model, + # Set the initial position of the box's base to be slightly above the ground + # to allow it to settle at the expected height after scaling. + # The base position is set to the nominal height of the box scaled by the scaling factor, + # plus a small offset to avoid immediate collision with the ground. + # This ensures that the box has enough space to fall and settle at the expected height. + base_position=jnp.array( + [ + *jax.random.uniform(subkey, shape=(2,)), + nominal_height * scaling_factor + 0.01, + ] + ), + ) + + num_steps = 1000 # Number of simulation steps + + for _ in range(num_steps): + data = js.model.step( + model=updated_model, + data=data, + ) + + # Get the final height of the box's base + updated_base_height = data.base_position[2] + + # Compute the expected height (nominal radius * scaling factor) + expected_height = nominal_height * scaling_factor / 2 + + # Assert that the box settles at the expected height + assert jnp.isclose( + updated_base_height, expected_height, atol=1e-3 + ), f"model base height mismatch: expected {expected_height}, got {updated_base_height}"