Skip to content

Commit 438c9a9

Browse files
committed
refactor!: Implementing RealVectorState in JS/WASM
1 parent 92f2c8a commit 438c9a9

9 files changed

+92
-35
lines changed

oxmpl-js/src/base/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ pub mod js_state_convert;
77
pub mod path;
88
pub mod planner;
99
pub mod problem_definition;
10+
pub mod real_vector_state;
1011
pub mod real_vector_state_space;
1112
pub mod state_validity_checker;
1213

1314
pub use goal::JsGoal;
1415
pub use path::JsPath;
1516
pub use planner::JsPlannerConfig;
1617
pub use problem_definition::JsProblemDefinition;
18+
pub use real_vector_state::JsRealVectorState;
1719
pub use real_vector_state_space::JsRealVectorStateSpace;
1820
pub use state_validity_checker::JsStateValidityChecker;

oxmpl-js/src/base/path.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

5-
use crate::base::js_state_convert::*;
6-
use js_sys::Float64Array;
5+
use js_sys::Array;
76
use oxmpl::base::{planner::Path, state::RealVectorState};
87
use wasm_bindgen::prelude::*;
98

@@ -16,11 +15,21 @@ pub struct JsPath {
1615
#[wasm_bindgen(js_class = Path)]
1716
impl JsPath {
1817
#[wasm_bindgen(js_name = getStates)]
19-
pub fn get_states(&self) -> Vec<Float64Array> {
20-
self.states.0.iter().map(state_to_js_array).collect()
18+
pub fn get_states(&self) -> Array {
19+
self.states
20+
.0
21+
.iter()
22+
.map(|s| {
23+
s.values
24+
.iter()
25+
.map(|&v| JsValue::from_f64(v))
26+
.collect::<Array>()
27+
})
28+
.collect::<Array>()
2129
}
2230

23-
pub fn length(&self) -> usize {
31+
#[wasm_bindgen(js_name = getLength)]
32+
pub fn get_length(&self) -> usize {
2433
self.states.0.len()
2534
}
2635
}

oxmpl-js/src/base/problem_definition.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ impl JsProblemDefinition {
2020
pub fn new(space: &JsRealVectorStateSpace, start: Vec<f64>, goal: JsGoal) -> Self {
2121
let start_state = RealVectorState::new(start);
2222
let problem_def = ProblemDefinition {
23-
space: space.inner.clone(),
23+
space: Arc::new(space.inner.lock().unwrap().clone()),
2424
start_states: vec![start_state],
2525
goal: Arc::new(goal),
2626
};
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (c) 2025 Junior Sundar
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
use std::sync::Arc;
6+
7+
use oxmpl::base::state::RealVectorState as OxmplRealVectorState;
8+
use wasm_bindgen::prelude::*;
9+
10+
#[wasm_bindgen(js_name = RealVectorState)]
11+
pub struct JsRealVectorState {
12+
#[wasm_bindgen(skip)]
13+
pub inner: Arc<OxmplRealVectorState>,
14+
}
15+
16+
#[wasm_bindgen(js_class = RealVectorState)]
17+
impl JsRealVectorState {
18+
#[wasm_bindgen(constructor)]
19+
pub fn new(values: Vec<f64>) -> Self {
20+
let state = OxmplRealVectorState::new(values);
21+
Self {
22+
inner: Arc::new(state),
23+
}
24+
}
25+
26+
#[wasm_bindgen(getter)]
27+
pub fn values(&self) -> Vec<f64> {
28+
self.inner.values.clone()
29+
}
30+
}

oxmpl-js/src/base/real_vector_state_space.rs

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
// SPDX-License-Identifier: BSD-3-Clause
44

55
use oxmpl::base::{
6-
space::{RealVectorStateSpace, StateSpace},
6+
space::{RealVectorStateSpace as OxmplRealVectorStateSpace, StateSpace},
77
state::RealVectorState,
88
};
99
use rand::rng;
10-
use std::sync::Arc;
10+
use std::sync::{Arc, Mutex};
1111
use wasm_bindgen::prelude::*;
1212

13+
use crate::base::JsRealVectorState;
14+
1315
#[wasm_bindgen(js_name = RealVectorStateSpace)]
1416
pub struct JsRealVectorStateSpace {
1517
#[wasm_bindgen(skip)]
16-
pub inner: Arc<RealVectorStateSpace>,
18+
pub inner: Arc<Mutex<OxmplRealVectorStateSpace>>,
1719
}
1820

1921
#[wasm_bindgen(js_class = RealVectorStateSpace)]
@@ -39,30 +41,40 @@ impl JsRealVectorStateSpace {
3941
None
4042
};
4143

42-
match RealVectorStateSpace::new(dimension, bounds_vec) {
44+
match OxmplRealVectorStateSpace::new(dimension, bounds_vec) {
4345
Ok(space) => Ok(Self {
44-
inner: Arc::new(space),
46+
inner: Arc::new(Mutex::new(space)),
4547
}),
4648
Err(e) => Err(e.to_string()),
4749
}
4850
}
4951

50-
pub fn sample(&self) -> Result<Vec<f64>, String> {
52+
#[wasm_bindgen(js_name = sample)]
53+
pub fn sample(&self) -> Result<JsRealVectorState, String> {
5154
let mut rng = rng();
52-
match self.inner.sample_uniform(&mut rng) {
53-
Ok(state) => Ok(state.values),
55+
match self.inner.lock().unwrap().sample_uniform(&mut rng) {
56+
Ok(state) => Ok(JsRealVectorState::new(state.values)),
5457
Err(e) => Err(e.to_string()),
5558
}
5659
}
5760

58-
pub fn distance(&self, state1: Vec<f64>, state2: Vec<f64>) -> f64 {
59-
let s1 = RealVectorState::new(state1);
60-
let s2 = RealVectorState::new(state2);
61-
self.inner.distance(&s1, &s2)
61+
#[wasm_bindgen(js_name = distance)]
62+
pub fn distance(&self, state1: &JsRealVectorState, state2: &JsRealVectorState) -> f64 {
63+
let s1 = RealVectorState::new(state1.inner.values.clone());
64+
let s2 = RealVectorState::new(state2.inner.values.clone());
65+
self.inner.lock().unwrap().distance(&s1, &s2)
6266
}
6367

6468
#[wasm_bindgen(js_name = getDimension)]
6569
pub fn get_dimension(&self) -> usize {
66-
self.inner.dimension
70+
self.inner.lock().unwrap().dimension
71+
}
72+
73+
#[wasm_bindgen(js_name = setLongestValidLineSegmentFraction)]
74+
pub fn set_longest_valid_segment_fraction(&mut self, fraction: f64) {
75+
self.inner
76+
.lock()
77+
.unwrap()
78+
.set_longest_valid_segment_fraction(fraction);
6779
}
6880
}

oxmpl-js/tests/test_prm_integration.test.js

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,22 @@ describe('PRM Integration Tests', () => {
8585
let path;
8686
try {
8787
path = planner.solve(timeoutSecs);
88-
console.log(`Solution found with ${path.length()} states.`);
88+
console.log(`Solution found with ${path.getLength()} states.`);
8989
} catch (error) {
9090
throw new Error(`Planner failed to find a solution when one should exist. Error: ${error}`);
9191
}
9292

9393
// VALIDATE THE SOLUTION PATH
9494
const states = path.getStates();
95-
const pathLength = path.length();
95+
const pathLength = path.getLength();
9696

9797
expect(pathLength).toBeGreaterThan(1);
9898
expect(states.length).toBe(pathLength);
9999

100100
// Check start position
101-
const pathStart = states[0];
102-
const startDistance = space.distance(pathStart, startState);
101+
const pathStart = new oxmpl.RealVectorState(states[0]);
102+
const startStateObj = new oxmpl.RealVectorState(startState);
103+
const startDistance = space.distance(pathStart, startStateObj);
103104
expect(startDistance).toBeLessThan(1e-9);
104105

105106
// Check goal is reached

oxmpl-js/tests/test_rrt_connect_integration.test.js

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,21 +84,22 @@ describe('RRT Connect Integration Tests', () => {
8484
let path;
8585
try {
8686
path = planner.solve(timeoutSecs);
87-
console.log(`Solution found with ${path.length()} states.`);
87+
console.log(`Solution found with ${path.getLength()} states.`);
8888
} catch (error) {
8989
throw new Error(`Planner failed to find a solution when one should exist. Error: ${error}`);
9090
}
9191

9292
// VALIDATE THE SOLUTION PATH
9393
const states = path.getStates();
94-
const pathLength = path.length();
94+
const pathLength = path.getLength();
9595

9696
expect(pathLength).toBeGreaterThan(1);
9797
expect(states.length).toBe(pathLength);
9898

9999
// Check start position
100-
const pathStart = states[0];
101-
const startDistance = space.distance(pathStart, startState);
100+
const pathStart = new oxmpl.RealVectorState(states[0]);
101+
const startStateObj = new oxmpl.RealVectorState(startState);
102+
const startDistance = space.distance(pathStart, startStateObj);
102103
expect(startDistance).toBeLessThan(1e-9);
103104

104105
// Check goal is reached

oxmpl-js/tests/test_rrt_integration.test.js

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,21 +84,22 @@ describe('RRT Integration Tests', () => {
8484
let path;
8585
try {
8686
path = planner.solve(timeoutSecs);
87-
console.log(`Solution found with ${path.length()} states.`);
87+
console.log(`Solution found with ${path.getLength()} states.`);
8888
} catch (error) {
8989
throw new Error(`Planner failed to find a solution when one should exist. Error: ${error}`);
9090
}
9191

9292
// VALIDATE THE SOLUTION PATH
9393
const states = path.getStates();
94-
const pathLength = path.length();
94+
const pathLength = path.getLength();
9595

9696
expect(pathLength).toBeGreaterThan(1);
9797
expect(states.length).toBe(pathLength);
9898

9999
// Check start position
100-
const pathStart = states[0];
101-
const startDistance = space.distance(pathStart, startState);
100+
const pathStart = new oxmpl.RealVectorState(states[0]);
101+
const startStateObj = new oxmpl.RealVectorState(startState);
102+
const startDistance = space.distance(pathStart, startStateObj);
102103
expect(startDistance).toBeLessThan(1e-9);
103104

104105
// Check goal is reached

oxmpl-js/tests/test_rrt_star_integration.test.js

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,22 @@ describe('RRT* Integration Tests', () => {
8585
let path;
8686
try {
8787
path = planner.solve(timeoutSecs);
88-
console.log(`Solution found with ${path.length()} states.`);
88+
console.log(`Solution found with ${path.getLength()} states.`);
8989
} catch (error) {
9090
throw new Error(`Planner failed to find a solution when one should exist. Error: ${error}`);
9191
}
9292

9393
// VALIDATE THE SOLUTION PATH
9494
const states = path.getStates();
95-
const pathLength = path.length();
95+
const pathLength = path.getLength();
9696

9797
expect(pathLength).toBeGreaterThan(1);
9898
expect(states.length).toBe(pathLength);
9999

100100
// Check start position
101-
const pathStart = states[0];
102-
const startDistance = space.distance(pathStart, startState);
101+
const pathStart = new oxmpl.RealVectorState(states[0]);
102+
const startStateObj = new oxmpl.RealVectorState(startState);
103+
const startDistance = space.distance(pathStart, startStateObj);
103104
expect(startDistance).toBeLessThan(1e-9);
104105

105106
// Check goal is reached

0 commit comments

Comments
 (0)