Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion extensions/percentile/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use turso_ext::{register_extension, AggFunc, AggregateDerive, Value};

register_extension! {
aggregates: { Median, Percentile, PercentileCont, PercentileDisc }
aggregates: { Median, Percentile, PercentileCont, PercentileDisc, StandardDeviation }
}

#[derive(AggregateDerive)]
Expand Down Expand Up @@ -191,3 +191,47 @@ impl AggFunc for PercentileDisc {
Ok(Value::from_float(values[index]))
}
}

/// Standard Deviation implementation using Welford's algorithm
/// Formula:
///
/// s = sqrt( M2 / (n - 1) )
///
/// Where:
/// - `n` = number of observations
/// - `M2` = sum of squared deviations
#[derive(AggregateDerive)]
struct StandardDeviation;

impl AggFunc for StandardDeviation {
type State = (u64, f64, f64); // Tracks the count, mean and sum of squared differences from the mean
type Error = &'static str;
const NAME: &'static str = "stddev";
const ARGS: i32 = 1;

fn step(state: &mut Self::State, args: &[Value]) {
let (count, mean, m2) = state;

if let Some(x) = args.first().and_then(Value::to_float) {
*count += 1;

// compute deviation from old mean
let delta = x - *mean;
*mean += delta / *count as f64;

// update sum of squared differences
let delta_2 = x - *mean;
*m2 += delta * delta_2;
}
}

fn finalize(state: Self::State) -> Result<Value, Self::Error> {
let (count, _mean, m2) = state;
if count < 2 {
return Ok(Value::null());
}

let variance = m2 / (count - 1) as f64;
Ok(Value::from_float(variance.sqrt()))
}
}
24 changes: 24 additions & 0 deletions testing/cli_tests/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,25 @@ def test_aggregates():
)
turso.run_test_fn("SELECT percentile_cont(value, 0.25) from test;", validate_percentile1)
turso.run_test_fn("SELECT percentile_disc(value, 0.55) from test;", validate_percentile_disc)

turso.run_test_fn(
"SELECT stddev(value) from test;", lambda res: res == "21.6024689946929", "stddev aggregate works on test table"
)
turso.run_test_fn(
"select stddev(value) from numbers;",
lambda res: res == "2.44948974278318",
"stddev aggregate works on numbers table",
)
turso.run_test_fn(
"select stddev(value) from (select value from test limit 1);",
null,
"stddev returns null with < 2 rows",
)
turso.run_test_fn(
"select stddev(percent) from (select percent from test limit 2);",
lambda res: res == "0.0",
"stddev aggregate works on 2 rows",
)
turso.quit()


Expand Down Expand Up @@ -214,6 +233,11 @@ def test_grouped_aggregates():
lambda res: "10.0\n30.0\n50.0\n70.0" == res,
"grouped aggregate percentile_disc function works",
)
turso.run_test_fn(
"SELECT stddev(value) FROM test GROUP BY category HAVING COUNT(*) >= 2 ORDER BY category;",
lambda res: res == "7.07106781186548\n10.0",
"grouped stddev aggregate function works",
)
turso.quit()


Expand Down
Loading