Skip to content

Commit

Permalink
fixed issue with rock scattering
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoineRichard committed Sep 10, 2024
1 parent 0b4b779 commit bf9ef0b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
30 changes: 16 additions & 14 deletions src/terrain_management/large_scale_terrain/rock_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def build_samplers(self) -> None:
if self.scale_sampler.seed is None:
self.scale_sampler.set_seed(self.settings.seed + 3)

def run(self, region: BoundingBox, map_coordinates: Tuple[float, float]) -> RockBlockData:
def run(self, region: BoundingBox, map_coordinates: Tuple[float, float]) -> Tuple[RockBlockData, bool]:
"""
Runs the rock distribution.
Expand All @@ -468,7 +468,7 @@ def run(self, region: BoundingBox, map_coordinates: Tuple[float, float]) -> Rock
map_coordinates (Tuple[float, float]): coordinates of the map.
Returns:
RockBlockData: block of rocks.
Tuple[RockBlockData, bool]: block of rocks, not empty. Returns False if the block is empty.
"""

xy_position, num_points = self.position_sampler(region=region)
Expand All @@ -477,7 +477,7 @@ def run(self, region: BoundingBox, map_coordinates: Tuple[float, float]) -> Rock
z_position, quat = self.sampling_func(xy_position[:, 0], xy_position[:, 1], map_coordinates, self.settings.seed)
xyz_position = np.stack([xy_position[:, 0], xy_position[:, 1], z_position]).T
block = RockBlockData(xyz_position, quat, scale, ids)
return block
return block, num_points > 0

def get_xy_coordinates_from_block(self, block: RockBlockData) -> np.ndarray:
"""
Expand Down Expand Up @@ -622,8 +622,9 @@ def sample_rocks_by_block(self, block_coordinates: Tuple[int, int], map_coordina
block_coordinates[1],
block_coordinates[1] + self.settings.block_size,
)
block = self.rock_dist_gen.run(bb, map_coordinates)
self.rock_db.add_block_data(block, block_coordinates)
block, not_empty = self.rock_dist_gen.run(bb, map_coordinates)
if not_empty:
self.rock_db.add_block_data(block, block_coordinates)

def dissect_region_blocks(
self, block: RockBlockData, region: BoundingBox
Expand Down Expand Up @@ -725,18 +726,19 @@ def sample_rocks_by_region(self, region: BoundingBox, map_coordinates: Tuple[flo

# Samples rocks in the region
with ScopedTimer("Sampling rocks in region", active=self.profiling):
new_block = self.rock_dist_gen.run(new_region, map_coordinates)
new_block, not_empty = self.rock_dist_gen.run(new_region, map_coordinates)

with ScopedTimer("Getting xy coordinates from block", active=self.profiling):
coords = self.rock_dist_gen.get_xy_coordinates_from_block(new_block)
if not_empty:
with ScopedTimer("Getting xy coordinates from block", active=self.profiling):
coords = self.rock_dist_gen.get_xy_coordinates_from_block(new_block)

# Dissects the region into blocks and adds the rocks to the database
with ScopedTimer("Dissecting region into blocks", active=self.profiling):
new_blocks_list, block_coordinates_list = self.dissect_region_blocks(new_block, new_region)
# Dissects the region into blocks and adds the rocks to the database
with ScopedTimer("Dissecting region into blocks", active=self.profiling):
new_blocks_list, block_coordinates_list = self.dissect_region_blocks(new_block, new_region)

with ScopedTimer("Adding block data to the database", active=self.profiling):
for block_data, block_coordinates in zip(new_blocks_list, block_coordinates_list):
self.rock_db.add_block_data(block_data, block_coordinates)
with ScopedTimer("Adding block data to the database", active=self.profiling):
for block_data, block_coordinates in zip(new_blocks_list, block_coordinates_list):
self.rock_db.add_block_data(block_data, block_coordinates)

# If the largest rectangle is smaller or equal to 1 block, we sample rocks
# on a per block basis.
Expand Down
4 changes: 3 additions & 1 deletion src/terrain_management/large_scale_terrain/rock_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def get_asset_list(self):
"""

self.file_paths = [
os.path.join(self.assets_path, file) for file in os.listdir(self.assets_path) if file.endswith(".usd")
os.path.join(self.assets_path, file)
for file in os.listdir(self.assets_path)
if (file.endswith(".usd") or file.endswith(".usda") or file.endswith(".usdz"))
]

def load_material(self):
Expand Down

0 comments on commit bf9ef0b

Please sign in to comment.