Skip to content

Commit f1f9465

Browse files
committed
Refactor variable usage and statistics collection
Signed-off-by: Carles Pey <[email protected]>
1 parent 80247ca commit f1f9465

File tree

1 file changed

+64
-54
lines changed

1 file changed

+64
-54
lines changed

chipsec/modules/tools/smm/smm_ptr.py

+64-54
Original file line numberDiff line numberDiff line change
@@ -199,21 +199,54 @@ def get_info(self):
199199
return f'duration {self.duration} code {self.code:02X} data {self.data:02X} ({gprs_info(self.gprs)})'
200200

201201

202+
class smi_stats:
203+
def __init__(self):
204+
self.clear()
205+
206+
def clear(self):
207+
self.count = 0
208+
self.mean = 0
209+
self.m2 = 0
210+
self.stdev = 0
211+
self.outliers = 0
212+
213+
#
214+
# Computes the standard deviation using the Welford's online algorithm
215+
#
216+
def update_stats(self, duration):
217+
self.count += 1
218+
difference = duration - self.mean
219+
self.mean += difference / self.count
220+
self.m2 += difference * (duration - self.mean)
221+
variance = self.m2 / self.count
222+
self.stdev = math.sqrt(variance)
223+
224+
def get_info(self):
225+
info = f'average {round(self.mean)} stddev {self.stdev:.2f} checked {self.count}'
226+
return info
227+
228+
#
229+
# Combines the statistics of the two data sets using parallel variance computation
230+
#
231+
def combine(self, partial):
232+
self.outliers += partial.outliers
233+
total_count = self.count + partial.count
234+
difference = partial.mean - self.mean
235+
self.mean = (self.mean * self.count + partial.mean * partial.count) / total_count
236+
self.m2 += partial.m2 + difference**2 * self.count * partial.count / total_count
237+
self.count = total_count
238+
variance = self.m2 / self.count
239+
self.stdev = math.sqrt(variance)
240+
241+
202242
class scan_track:
203243
def __init__(self):
244+
self.current_smi_stats = smi_stats()
245+
self.history_smi_stats = smi_stats()
204246
self.clear()
205-
self.hist_smi_duration = 0
206-
self.hist_smi_num = 0
207-
self.outliers_hist = 0
208247
self.helper = OsHelper().get_default_helper()
209248
self.helper.init()
210249
self.smi_count = self.get_smi_count()
211-
self.needs_calibration = True
212-
self.calib_samples = 0
213-
self.stdev = 0
214-
self.m2 = 0
215-
self.stdev_hist = 0
216-
self.m2_hist = 0
217250

218251
def __del__(self):
219252
self.helper.close()
@@ -251,73 +284,47 @@ def find_address_in_regs(self, gprs):
251284
return key
252285

253286
def clear(self):
254-
self.max = smi_info(0)
255-
self.min = smi_info(2**32 - 1)
256287
self.outlier = smi_info(0)
257-
self.avg_smi_duration = 0
258-
self.avg_smi_num = 0
259-
self.outliers = 0
260288
self.code = None
261-
self.confirmed = False
289+
self.contents_changed = False
262290
self.needs_calibration = True
263291
self.calib_samples = 0
264-
self.stdev = 0
265-
self.m2 = 0
292+
self.current_smi_stats.clear()
266293

267-
def add(self, duration, code, data, gprs, confirmed=False):
294+
def add(self, duration, code, data, gprs, contents_changed=False):
268295
if not self.code:
269296
self.code = code
270297
outlier = self.is_outlier(duration)
271298
if not outlier:
272-
self.update_stdev(duration)
273-
if duration > self.max.duration:
274-
self.max.update(duration, code, data, gprs.copy())
275-
elif duration < self.min.duration:
276-
self.min.update(duration, code, data, gprs.copy())
299+
self.current_smi_stats.update_stats(duration)
277300
elif self.is_slow_outlier(duration):
278-
self.outliers += 1
279-
self.outliers_hist += 1
301+
self.current_smi_stats.outliers += 1
280302
self.outlier.update(duration, code, data, gprs.copy())
281-
self.confirmed = confirmed
282-
283-
#
284-
# Computes the standard deviation using the Welford's online algorithm
285-
#
286-
def update_stdev(self, value):
287-
self.avg_smi_num += 1
288-
self.hist_smi_num += 1
289-
difference = value - self.avg_smi_duration
290-
difference_hist = value - self.hist_smi_duration
291-
self.avg_smi_duration += difference / self.avg_smi_num
292-
self.hist_smi_duration += difference_hist / self.hist_smi_num
293-
self.m2 += difference * (value - self.avg_smi_duration)
294-
self.m2_hist += difference_hist * (value - self.hist_smi_duration)
295-
variance = self.m2 / self.avg_smi_num
296-
variance_hist = self.m2_hist / self.hist_smi_num
297-
self.stdev = math.sqrt(variance)
298-
self.stdev_hist = math.sqrt(variance_hist)
303+
self.contents_changed = contents_changed
299304

300305
def update_calibration(self, duration):
301306
if not self.needs_calibration:
302307
return
303-
self.update_stdev(duration)
308+
self.current_smi_stats.update_stats(duration)
304309
self.calib_samples += 1
305310
if self.calib_samples >= SCAN_CALIB_SAMPLES:
306311
self.needs_calibration = False
307312

308313
def is_slow_outlier(self, value):
309314
ret = False
310-
if value > self.avg_smi_duration + OUTLIER_STD_DEV * self.stdev:
315+
if value > self.current_smi_stats.mean + OUTLIER_STD_DEV * self.current_smi_stats.stdev:
311316
ret = True
312-
if value > self.hist_smi_duration + OUTLIER_STD_DEV * self.stdev_hist:
317+
if self.history_smi_stats.count and \
318+
value > self.history_smi_stats.mean + OUTLIER_STD_DEV * self.history_smi_stats.stdev:
313319
ret = True
314320
return ret
315321

316322
def is_fast_outlier(self, value):
317323
ret = False
318-
if value < self.avg_smi_duration - OUTLIER_STD_DEV * self.stdev:
324+
if value < self.current_smi_stats.mean - OUTLIER_STD_DEV * self.current_smi_stats.stdev:
319325
ret = True
320-
if value < self.hist_smi_duration - OUTLIER_STD_DEV * self.stdev_hist:
326+
if self.history_smi_stats.count and \
327+
value < self.history_smi_stats.mean - OUTLIER_STD_DEV * self.history_smi_stats.stdev:
321328
ret = True
322329
return ret
323330

@@ -332,18 +339,17 @@ def is_outlier(self, value):
332339
return ret
333340

334341
def skip(self):
335-
return self.outliers or self.confirmed
342+
return self.current_smi_stats.outliers or self.contents_changed
336343

337344
def found_outlier(self):
338-
return bool(self.outliers)
345+
return bool(self.current_smi_stats.outliers)
339346

340347
def get_total_outliers(self):
341-
return self.outliers_hist
348+
return self.history_smi_stats.outliers
342349

343350
def get_info(self):
344-
avg = self.avg_smi_duration or self.hist_smi_duration
345-
info = f'average {round(avg)} stddev {self.stdev:.2f} checked {self.avg_smi_num + self.outliers}'
346-
if self.outliers:
351+
info = self.current_smi_stats.get_info()
352+
if self.current_smi_stats.outliers:
347353
info += f'\n Identified outlier: {self.outlier.get_info()}'
348354
return info
349355

@@ -354,6 +360,9 @@ def log_smi_result(self, logger):
354360
else:
355361
logger.log(f'[*] {msg}')
356362

363+
def update_history_stats(self):
364+
self.history_smi_stats.combine(self.current_smi_stats)
365+
357366

358367
class smi_desc:
359368
def __init__(self):
@@ -699,6 +708,7 @@ def test_fuzz(self, thread_id, smic_start, smic_end, _addr, _addr1, scan_mode=Fa
699708
break
700709
if scan_mode:
701710
scan.log_smi_result(self.logger)
711+
scan.update_history_stats()
702712
scan.clear()
703713

704714
return bad_ptr_cnt, scan

0 commit comments

Comments
 (0)