Skip to content

Commit bbb9363

Browse files
committed
feat(PlannerConfig)!: Implement PlannerConfig to carry planner specific params
* In this case it is to provide a deterministic seed for the rng * Also wrap this inside PlannerConfig
1 parent d427514 commit bbb9363

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+723
-256
lines changed

Cargo.lock

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

oxmpl-js/src/lib.rs

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use js_sys::Float64Array;
88
use oxmpl::base::{
99
error::StateSamplingError,
1010
goal::{Goal, GoalRegion, GoalSampleableRegion},
11-
planner::{Path, Planner},
11+
planner::{Path, Planner, PlannerConfig},
1212
problem_definition::ProblemDefinition,
1313
space::{RealVectorStateSpace, StateSpace},
1414
state::RealVectorState,
@@ -305,17 +305,35 @@ impl JsPath {
305305
}
306306
}
307307

308+
#[wasm_bindgen(js_name = PlannerConfig)]
309+
pub struct JsPlannerConfig {
310+
seed: Option<u64>,
311+
}
312+
313+
impl From<&JsPlannerConfig> for PlannerConfig {
314+
fn from(js_planner_config: &JsPlannerConfig) -> Self {
315+
PlannerConfig {
316+
seed: (js_planner_config.seed),
317+
}
318+
}
319+
}
320+
308321
#[wasm_bindgen(js_name = RRT)]
309322
pub struct JsRRT {
310323
planner: RRT<RealVectorState, RealVectorStateSpace, JsGoal>,
311324
}
312325

313326
#[wasm_bindgen(js_class = RRT)]
314327
impl JsRRT {
328+
/**
329+
* @param {number} max_distance
330+
* @param {number} goal_bias
331+
* @param {JsPlannerConfig} config
332+
*/
315333
#[wasm_bindgen(constructor)]
316-
pub fn new(max_distance: f32, goal_bias: f32) -> Self {
334+
pub fn new(max_distance: f32, goal_bias: f32, config: &JsPlannerConfig) -> Self {
317335
Self {
318-
planner: RRT::new(max_distance as f64, goal_bias as f64),
336+
planner: RRT::new(max_distance as f64, goal_bias as f64, &config.into()),
319337
}
320338
}
321339

@@ -345,10 +363,15 @@ pub struct JsRRTConnect {
345363

346364
#[wasm_bindgen(js_class = RRTConnect)]
347365
impl JsRRTConnect {
366+
/**
367+
* @param {number} max_distance
368+
* @param {number} goal_bias
369+
* @param {JsPlannerConfig} config
370+
*/
348371
#[wasm_bindgen(constructor)]
349-
pub fn new(max_distance: f32, goal_bias: f32) -> Self {
372+
pub fn new(max_distance: f32, goal_bias: f32, config: &JsPlannerConfig) -> Self {
350373
Self {
351-
planner: RRTConnect::new(max_distance as f64, goal_bias as f64),
374+
planner: RRTConnect::new(max_distance as f64, goal_bias as f64, &config.into()),
352375
}
353376
}
354377

@@ -378,10 +401,26 @@ pub struct JsRRTStar {
378401

379402
#[wasm_bindgen(js_class = RRTStar)]
380403
impl JsRRTStar {
404+
/**
405+
* @param {number} max_distance
406+
* @param {number} goal_bias
407+
* @param {number} search_radius
408+
* @param {JsPlannerConfig} config
409+
*/
381410
#[wasm_bindgen(constructor)]
382-
pub fn new(max_distance: f32, goal_bias: f32, search_radius: f32) -> Self {
411+
pub fn new(
412+
max_distance: f32,
413+
goal_bias: f32,
414+
search_radius: f32,
415+
config: &JsPlannerConfig,
416+
) -> Self {
383417
Self {
384-
planner: RRTStar::new(max_distance as f64, goal_bias as f64, search_radius as f64),
418+
planner: RRTStar::new(
419+
max_distance as f64,
420+
goal_bias as f64,
421+
search_radius as f64,
422+
&config.into(),
423+
),
385424
}
386425
}
387426

@@ -411,10 +450,19 @@ pub struct JsPRM {
411450

412451
#[wasm_bindgen(js_class = PRM)]
413452
impl JsPRM {
453+
/**
454+
* @param {number} timeout_secs
455+
* @param {number} connection_radius
456+
* @param {JsPlannerConfig} config
457+
*/
414458
#[wasm_bindgen(constructor)]
415-
pub fn new(timeout_secs: f32, connection_radius: f32) -> Self {
459+
pub fn new(timeout_secs: f32, connection_radius: f32, config: &JsPlannerConfig) -> Self {
416460
Self {
417-
planner: PRM::new(timeout_secs.into(), connection_radius as f64),
461+
planner: PRM::new(
462+
timeout_secs.into(),
463+
connection_radius as f64,
464+
&config.into(),
465+
),
418466
}
419467
}
420468

oxmpl-py/src/base/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod compound_state;
88
mod compound_state_space;
99
mod goal;
1010
mod path;
11+
mod planner;
1112
mod problem_definition;
1213
mod py_state_convert;
1314
mod real_vector_state;
@@ -26,6 +27,7 @@ pub use compound_state::PyCompoundState;
2627
pub use compound_state_space::PyCompoundStateSpace;
2728
pub use goal::PyGoal;
2829
pub use path::PyPath;
30+
pub use planner::PyPlannerConfig;
2931
pub use problem_definition::ProblemDefinitionVariant;
3032
pub use problem_definition::PyProblemDefinition;
3133
pub use real_vector_state::PyRealVectorState;
@@ -55,6 +57,7 @@ pub fn create_module(_py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
5557
base_module.add_class::<PySO3State>()?;
5658
base_module.add_class::<PySO3StateSpace>()?;
5759
base_module.add_class::<PyPath>()?;
60+
base_module.add_class::<PyPlannerConfig>()?;
5861
base_module.add_class::<PyProblemDefinition>()?;
5962
Ok(base_module)
6063
}

oxmpl-py/src/base/planner.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use std::rc::Rc;
2+
3+
use oxmpl::base::planner::PlannerConfig as OxmplPlannerConfig;
4+
use pyo3::prelude::*;
5+
6+
#[pyclass(name = "PlannerConfig", unsendable)]
7+
pub struct PyPlannerConfig(pub Rc<OxmplPlannerConfig>);
8+
9+
#[pymethods]
10+
impl PyPlannerConfig {
11+
#[new]
12+
#[pyo3(signature = (seed=None))]
13+
fn new(seed: Option<u64>) -> Self {
14+
let planner_config = OxmplPlannerConfig { seed };
15+
Self(Rc::new(planner_config))
16+
}
17+
18+
#[getter]
19+
fn get_seed(&self) -> Option<u64> {
20+
self.0.seed
21+
}
22+
23+
fn __repr__(&self) -> String {
24+
format!("<PlannerConfig seed={:?}>", self.0.seed)
25+
}
26+
}

oxmpl-py/src/geometric/prm.rs

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use pyo3::prelude::*;
66
use std::{cell::RefCell, rc::Rc, sync::Arc, time::Duration};
77

88
use crate::base::{
9-
ProblemDefinitionVariant, PyGoal, PyPath, PyProblemDefinition, PyStateValidityChecker,
9+
ProblemDefinitionVariant, PyGoal, PyPath, PyPlannerConfig, PyProblemDefinition,
10+
PyStateValidityChecker,
1011
};
1112
use oxmpl::{
1213
base::{
@@ -46,52 +47,66 @@ pub struct PyPrm {
4647
impl PyPrm {
4748
/// Creates a new PRM planner instance.
4849
///
50+
/// Args:
51+
/// timeout (float): The time in seconds to spend building the roadmap.
52+
/// connection_radius (float): The radius for connecting new nodes to the roadmap.
53+
/// problem_definition (ProblemDefinition): The problem definition.
54+
/// planner_config (PlannerConfig): The planner configuration with planner specific
55+
/// parameters.
56+
///
4957
/// The constructor inspects the `problem_definition` to determine which
5058
/// underlying state space to use (e.g., RealVectorStateSpace, SO2StateSpace).
5159
#[new]
5260
fn new(
5361
timeout: f64,
5462
connection_radius: f64,
5563
problem_definition: &PyProblemDefinition,
64+
planner_config: &PyPlannerConfig,
5665
) -> PyResult<Self> {
5766
let (planner, pd) = match &problem_definition.0 {
5867
ProblemDefinitionVariant::RealVector(pd) => {
59-
let planner_instance = PrmForRealVector::new(timeout, connection_radius);
68+
let planner_instance =
69+
PrmForRealVector::new(timeout, connection_radius, &planner_config.0);
6070
(
6171
PlannerVariant::RealVector(Rc::new(RefCell::new(planner_instance))),
6272
ProblemDefinitionVariant::RealVector(pd.clone()),
6373
)
6474
}
6575
ProblemDefinitionVariant::SO2(pd) => {
66-
let planner_instance = PrmForSO2::new(timeout, connection_radius);
76+
let planner_instance =
77+
PrmForSO2::new(timeout, connection_radius, &planner_config.0);
6778
(
6879
PlannerVariant::SO2(Rc::new(RefCell::new(planner_instance))),
6980
ProblemDefinitionVariant::SO2(pd.clone()),
7081
)
7182
}
7283
ProblemDefinitionVariant::SO3(pd) => {
73-
let planner_instance = PrmForSO3::new(timeout, connection_radius);
84+
let planner_instance =
85+
PrmForSO3::new(timeout, connection_radius, &planner_config.0);
7486
(
7587
PlannerVariant::SO3(Rc::new(RefCell::new(planner_instance))),
7688
ProblemDefinitionVariant::SO3(pd.clone()),
7789
)
7890
}
7991
ProblemDefinitionVariant::Compound(pd) => {
80-
let planner_instance = PrmForCompound::new(timeout, connection_radius);
92+
let planner_instance =
93+
PrmForCompound::new(timeout, connection_radius, &planner_config.0);
8194
(
8295
PlannerVariant::Compound(Rc::new(RefCell::new(planner_instance))),
8396
ProblemDefinitionVariant::Compound(pd.clone()),
8497
)
8598
}
8699
ProblemDefinitionVariant::SE2(pd) => {
87-
let planner_instance = PrmForSE2::new(timeout, connection_radius);
100+
let planner_instance =
101+
PrmForSE2::new(timeout, connection_radius, &planner_config.0);
88102
(
89103
PlannerVariant::SE2(Rc::new(RefCell::new(planner_instance))),
90104
ProblemDefinitionVariant::SE2(pd.clone()),
91105
)
92106
}
93107
ProblemDefinitionVariant::SE3(pd) => {
94-
let planner_instance = PrmForSE3::new(timeout, connection_radius);
108+
let planner_instance =
109+
PrmForSE3::new(timeout, connection_radius, &planner_config.0);
95110
(
96111
PlannerVariant::SE3(Rc::new(RefCell::new(planner_instance))),
97112
ProblemDefinitionVariant::SE3(pd.clone()),

oxmpl-py/src/geometric/rrt.rs

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use pyo3::prelude::*;
66
use std::{cell::RefCell, rc::Rc, sync::Arc, time::Duration};
77

88
use crate::base::{
9-
ProblemDefinitionVariant, PyGoal, PyPath, PyProblemDefinition, PyStateValidityChecker,
9+
ProblemDefinitionVariant, PyGoal, PyPath, PyPlannerConfig, PyProblemDefinition,
10+
PyStateValidityChecker,
1011
};
1112
use oxmpl::{
1213
base::{
@@ -46,52 +47,62 @@ pub struct PyRrt {
4647
impl PyRrt {
4748
/// Creates a new RRT planner instance.
4849
///
50+
/// Args:
51+
/// max_distance (float): The maximum length of a single branch in the tree.
52+
/// goal_bias (float): The probability (0.0 to 1.0) of sampling the goal.
53+
/// problem_definition (ProblemDefinition): The problem definition.
54+
/// planner_config (PlannerConfig): The planner configuration with planner specific
55+
/// parameters.
56+
///
4957
/// The constructor inspects the `problem_definition` to determine which
5058
/// underlying state space to use (e.g., RealVectorStateSpace, SO2StateSpace).
5159
#[new]
5260
fn new(
5361
max_distance: f64,
5462
goal_bias: f64,
5563
problem_definition: &PyProblemDefinition,
64+
planner_config: &PyPlannerConfig,
5665
) -> PyResult<Self> {
5766
let (planner, pd) = match &problem_definition.0 {
5867
ProblemDefinitionVariant::RealVector(pd) => {
59-
let planner_instance = RrtForRealVector::new(max_distance, goal_bias);
68+
let planner_instance =
69+
RrtForRealVector::new(max_distance, goal_bias, &planner_config.0);
6070
(
6171
PlannerVariant::RealVector(Rc::new(RefCell::new(planner_instance))),
6272
ProblemDefinitionVariant::RealVector(pd.clone()),
6373
)
6474
}
6575
ProblemDefinitionVariant::SO2(pd) => {
66-
let planner_instance = RrtForSO2::new(max_distance, goal_bias);
76+
let planner_instance = RrtForSO2::new(max_distance, goal_bias, &planner_config.0);
6777
(
6878
PlannerVariant::SO2(Rc::new(RefCell::new(planner_instance))),
6979
ProblemDefinitionVariant::SO2(pd.clone()),
7080
)
7181
}
7282
ProblemDefinitionVariant::SO3(pd) => {
73-
let planner_instance = RrtForSO3::new(max_distance, goal_bias);
83+
let planner_instance = RrtForSO3::new(max_distance, goal_bias, &planner_config.0);
7484
(
7585
PlannerVariant::SO3(Rc::new(RefCell::new(planner_instance))),
7686
ProblemDefinitionVariant::SO3(pd.clone()),
7787
)
7888
}
7989
ProblemDefinitionVariant::Compound(pd) => {
80-
let planner_instance = RrtForCompound::new(max_distance, goal_bias);
90+
let planner_instance =
91+
RrtForCompound::new(max_distance, goal_bias, &planner_config.0);
8192
(
8293
PlannerVariant::Compound(Rc::new(RefCell::new(planner_instance))),
8394
ProblemDefinitionVariant::Compound(pd.clone()),
8495
)
8596
}
8697
ProblemDefinitionVariant::SE2(pd) => {
87-
let planner_instance = RrtForSE2::new(max_distance, goal_bias);
98+
let planner_instance = RrtForSE2::new(max_distance, goal_bias, &planner_config.0);
8899
(
89100
PlannerVariant::SE2(Rc::new(RefCell::new(planner_instance))),
90101
ProblemDefinitionVariant::SE2(pd.clone()),
91102
)
92103
}
93104
ProblemDefinitionVariant::SE3(pd) => {
94-
let planner_instance = RrtForSE3::new(max_distance, goal_bias);
105+
let planner_instance = RrtForSE3::new(max_distance, goal_bias, &planner_config.0);
95106
(
96107
PlannerVariant::SE3(Rc::new(RefCell::new(planner_instance))),
97108
ProblemDefinitionVariant::SE3(pd.clone()),

0 commit comments

Comments
 (0)