Skip to content

Commit a4165da

Browse files
authored
Merge pull request #173 from nasa/feature/direct_access_params_alg
Add direct access
2 parents 8c3d540 + 3a3c461 commit a4165da

File tree

4 files changed

+59
-1
lines changed

4 files changed

+59
-1
lines changed

src/progpy/predictors/predictor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,9 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs)
7474
* time_of_event (UncertainData): Distribution of predicted Time of Event (ToE) for each predicted event, represented by some subclass of UncertaintData (e.g., MultivariateNormalDist)
7575
"""
7676
pass
77+
78+
def __getitem__(self, arg):
79+
return self.parameters[arg]
80+
81+
def __setitem__(self, key, value):
82+
self.parameters[key] = value

src/progpy/state_estimators/state_estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,9 @@ def x(self) -> UncertainData:
111111
-------
112112
state = filt.x
113113
"""
114+
115+
def __getitem__(self, arg):
116+
return self.parameters[arg]
117+
118+
def __setitem__(self, key, value):
119+
self.parameters[key] = value

tests/test_predictors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ def test_UTP_ThrownObject(self):
8282
self.assertAlmostEqual(results.time_of_event.mean['falling'], 4.15, 0)
8383
# self.assertAlmostEqual(mc_results.times[-1], 9, 1) # Saving every second, last time should be around the 1s after impact event (because one of the sigma points fails afterwards)
8484

85+
# Test setting dt at class level (otherwise default of 1 will be used and this wont work)
86+
pred['dt'] = 0.01
87+
results = pred.predict(samples, save_freq=1)
88+
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.21, 0)
89+
self.assertAlmostEqual(results.time_of_event.mean['falling'], 4.15, 0)
90+
8591
# Setting event manually
8692
results = pred.predict(samples, dt=0.01, save_freq=1, events=['falling'])
8793
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5)

tests/test_state_estimators.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,48 @@ def __test_state_est(self, filt, m):
115115
# should be close to right
116116
self.assertAlmostEqual(x_est[key], x[key], delta=0.4)
117117

118+
def __test_state_est_no_dt(self, filt, m):
119+
x = m.initialize()
120+
filt['dt'] = 0.2
121+
122+
self.assertTrue(all(key in filt.x.mean for key in m.states))
123+
124+
# run for a while
125+
dt = 0.2
126+
u = m.InputContainer({})
127+
last_time = 0
128+
for i in range(500):
129+
# Get simulated output (would be measured in a real application)
130+
x = m.next_state(x, u, dt)
131+
z = m.output(x)
132+
133+
# Estimate New State every few steps
134+
if i % 8 == 0:
135+
# This is to test dt setting at the estimator lvl
136+
# Without dt, this would fail
137+
last_time = (i+1)*dt
138+
filt.estimate((i+1)*dt, u, z)
139+
140+
if last_time != (i+1)*dt:
141+
# Final estimate
142+
filt.estimate((i+1)*dt, u, z)
143+
144+
# Check results - make sure it converged
145+
x_est = filt.x.mean
146+
for key in m.states:
147+
# should be close to right
148+
self.assertAlmostEqual(x_est[key], x[key], delta=0.4)
149+
118150
def test_UKF(self):
119151
m = ThrownObject(process_noise=5e-2, measurement_noise=5e-2)
120152
x_guess = {'x': 1.75, 'v': 35} # Guess of initial state, actual is {'x': 1.83, 'v': 40}
121153

122154
filt = UnscentedKalmanFilter(m, x_guess)
123155
self.__test_state_est(filt, m)
124156

157+
filt = UnscentedKalmanFilter(m, x_guess)
158+
self.__test_state_est_no_dt(filt, m)
159+
125160
m = ThrownObject(process_noise=5e-2, measurement_noise=5e-2)
126161

127162
# Test UnscentedKalmanFilter ScalarData
@@ -322,6 +357,9 @@ def test_PF(self):
322357
filt = ParticleFilter(m, x_guess, num_particles = 1000, measurement_noise = {'x': 1})
323358
self.__test_state_est(filt, m)
324359

360+
filt = ParticleFilter(m, x_guess, num_particles = 1000, measurement_noise = {'x': 1})
361+
self.__test_state_est_no_dt(filt, m)
362+
325363
# Test ParticleFilter ScalarData
326364
x_scalar = ScalarData({'x': 1.75, 'v': 38.5})
327365
filt_scalar = ParticleFilter(m, x_scalar, num_particles = 20) # Sample count does not affect ScalarData testing
@@ -438,9 +476,11 @@ def event_state(self, x):
438476
x_guess = {'x': 1.75, 'v': 35} # Guess of initial state, actual is {'x': 1.83, 'v': 40}
439477

440478
filt = KalmanFilter(m, x_guess)
441-
442479
self.__test_state_est(filt, m)
443480

481+
filt = KalmanFilter(m, x_guess)
482+
self.__test_state_est_no_dt(filt, m)
483+
444484
m = ThrownObject(process_noise=5e-2, measurement_noise=5e-2)
445485

446486
# Test KalmanFilter ScalarData

0 commit comments

Comments
 (0)