Skip to content

Commit b0a1d1e

Browse files
authored
Add STDEV() aggregate function (#1614)
Add a new aggregate function `STDEV(X)` which computes the (sample) standard deviation, such that a user will not have to repetitively type `math:sqrt(sum(math:pow((X - avg(X)), 2)) / (count(*) - 1))`. This is not part of the SPARQL standard, but also doesn't cause any conflicts.
1 parent 155718d commit b0a1d1e

21 files changed

+2755
-2420
lines changed

src/engine/GroupBy.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "engine/sparqlExpressions/SampleExpression.h"
2121
#include "engine/sparqlExpressions/SparqlExpression.h"
2222
#include "engine/sparqlExpressions/SparqlExpressionGenerators.h"
23+
#include "engine/sparqlExpressions/StdevExpression.h"
2324
#include "global/RuntimeParameters.h"
2425
#include "index/Index.h"
2526
#include "index/IndexImpl.h"
@@ -1026,6 +1027,8 @@ GroupBy::isSupportedAggregate(sparqlExpression::SparqlExpression* expr) {
10261027
if (auto val = dynamic_cast<GroupConcatExpression*>(expr)) {
10271028
return H{GROUP_CONCAT, val->getSeparator()};
10281029
}
1030+
// NOTE: The STDEV function is not suitable for lazy and hash map
1031+
// optimizations.
10291032
if (dynamic_cast<SampleExpression*>(expr)) return H{SAMPLE};
10301033

10311034
// `expr` is an unsupported aggregate

src/engine/sparqlExpressions/AggregateExpression.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "engine/sparqlExpressions/AggregateExpression.h"
77

88
#include "engine/sparqlExpressions/GroupConcatExpression.h"
9+
#include "engine/sparqlExpressions/StdevExpression.h"
910

1011
namespace sparqlExpression::detail {
1112

@@ -180,6 +181,11 @@ AggregateExpression<AggregateOperation, FinalOperation>::getVariableForCount()
180181
// Explicit instantiation for the AVG expression.
181182
template class AggregateExpression<AvgOperation, decltype(avgFinalOperation)>;
182183

184+
// Explicit instantiation for the STDEV expression.
185+
template class AggregateExpression<AvgOperation, decltype(stdevFinalOperation)>;
186+
template class DeviationAggExpression<AvgOperation,
187+
decltype(stdevFinalOperation)>;
188+
183189
// Explicit instantiations for the other aggregate expressions.
184190
#define INSTANTIATE_AGG_EXP(Function, ValueGetter) \
185191
template class AggregateExpression< \

src/engine/sparqlExpressions/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_library(sparqlExpressions
66
SampleExpression.cpp
77
RelationalExpressions.cpp
88
AggregateExpression.cpp
9+
StdevExpression.cpp
910
RegexExpression.cpp
1011
NumericUnaryExpressions.cpp
1112
NumericBinaryExpressions.cpp
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright 2024, University of Freiburg,
2+
// Chair of Algorithms and Data Structures.
3+
// Author: Christoph Ullinger <[email protected]>
4+
5+
#include "engine/sparqlExpressions/StdevExpression.h"
6+
7+
#include "engine/sparqlExpressions/SparqlExpressionTypes.h"
8+
9+
namespace sparqlExpression::detail {
10+
11+
// _____________________________________________________________________________
12+
ExpressionResult DeviationExpression::evaluate(
13+
EvaluationContext* context) const {
14+
// Helper: Extracts a double or int (as double) from a variant
15+
auto numValVisitor = []<typename T>(const T& value) -> std::optional<double> {
16+
if constexpr (ad_utility::isSimilar<T, double> ||
17+
ad_utility::isSimilar<T, int64_t>) {
18+
return static_cast<double>(value);
19+
} else {
20+
return std::nullopt;
21+
}
22+
};
23+
24+
// Helper to replace child expression results with their squared deviation
25+
auto devImpl = [context, numValVisitor](
26+
bool& undef,
27+
VectorWithMemoryLimit<IdOrLiteralOrIri>& exprResult,
28+
auto generator) {
29+
double sum = 0.0;
30+
// Intermediate storage of the results returned from the child
31+
// expression
32+
VectorWithMemoryLimit<double> childResults{context->_allocator};
33+
34+
// Collect values as doubles
35+
for (auto& inp : generator) {
36+
const auto& n = detail::NumericValueGetter{}(std::move(inp), context);
37+
auto v = std::visit(numValVisitor, n);
38+
if (v.has_value()) {
39+
childResults.push_back(v.value());
40+
sum += v.value();
41+
} else {
42+
// There is a non-numeric value in the input. Therefore the entire
43+
// result will be undef.
44+
undef = true;
45+
return;
46+
}
47+
context->cancellationHandle_->throwIfCancelled();
48+
}
49+
50+
// Calculate squared deviation and save for result
51+
double avg = sum / static_cast<double>(context->size());
52+
for (size_t i = 0; i < childResults.size(); i++) {
53+
exprResult.at(i) = IdOrLiteralOrIri{
54+
ValueId::makeFromDouble(std::pow(childResults.at(i) - avg, 2))};
55+
}
56+
};
57+
58+
// Visitor for child expression result
59+
auto impl = [context,
60+
devImpl](SingleExpressionResult auto&& el) -> ExpressionResult {
61+
// Prepare space for result
62+
VectorWithMemoryLimit<IdOrLiteralOrIri> exprResult{context->_allocator};
63+
exprResult.resize(context->size());
64+
bool undef = false;
65+
66+
auto generator =
67+
detail::makeGenerator(AD_FWD(el), context->size(), context);
68+
devImpl(undef, exprResult, std::move(generator));
69+
70+
if (undef) {
71+
return IdOrLiteralOrIri{Id::makeUndefined()};
72+
}
73+
return exprResult;
74+
};
75+
76+
auto childRes = child_->evaluate(context);
77+
return std::visit(impl, std::move(childRes));
78+
};
79+
80+
} // namespace sparqlExpression::detail
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// Copyright 2024, University of Freiburg,
2+
// Chair of Algorithms and Data Structures.
3+
// Author: Christoph Ullinger <[email protected]>
4+
5+
#pragma once
6+
7+
#include <cmath>
8+
#include <functional>
9+
#include <memory>
10+
#include <variant>
11+
12+
#include "engine/sparqlExpressions/AggregateExpression.h"
13+
#include "engine/sparqlExpressions/LiteralExpression.h"
14+
#include "engine/sparqlExpressions/NaryExpression.h"
15+
#include "engine/sparqlExpressions/SparqlExpression.h"
16+
#include "engine/sparqlExpressions/SparqlExpressionTypes.h"
17+
#include "engine/sparqlExpressions/SparqlExpressionValueGetters.h"
18+
#include "global/ValueId.h"
19+
20+
namespace sparqlExpression {
21+
22+
namespace detail {
23+
24+
/// The STDEV Expression
25+
26+
// Helper expression: The individual deviation squares. A DeviationExpression
27+
// over X corresponds to the value (X - AVG(X))^2.
28+
class DeviationExpression : public SparqlExpression {
29+
private:
30+
Ptr child_;
31+
32+
public:
33+
explicit DeviationExpression(Ptr&& child) : child_{std::move(child)} {}
34+
35+
// __________________________________________________________________________
36+
ExpressionResult evaluate(EvaluationContext* context) const override;
37+
38+
// __________________________________________________________________________
39+
AggregateStatus isAggregate() const override {
40+
return SparqlExpression::AggregateStatus::NoAggregate;
41+
}
42+
43+
// __________________________________________________________________________
44+
[[nodiscard]] string getCacheKey(
45+
const VariableToColumnMap& varColMap) const override {
46+
return absl::StrCat("[ SQ.DEVIATION ]", child_->getCacheKey(varColMap));
47+
}
48+
49+
private:
50+
// _________________________________________________________________________
51+
std::span<SparqlExpression::Ptr> childrenImpl() override {
52+
return {&child_, 1};
53+
}
54+
};
55+
56+
// Separate subclass of AggregateOperation, that replaces its child with a
57+
// DeviationExpression of this child. Everything else is left untouched.
58+
template <typename AggregateOperation,
59+
typename FinalOperation = decltype(identity)>
60+
class DeviationAggExpression
61+
: public AggregateExpression<AggregateOperation, FinalOperation> {
62+
public:
63+
// __________________________________________________________________________
64+
DeviationAggExpression(bool distinct, SparqlExpression::Ptr&& child,
65+
AggregateOperation aggregateOp = AggregateOperation{})
66+
: AggregateExpression<AggregateOperation, FinalOperation>(
67+
distinct, std::make_unique<DeviationExpression>(std::move(child)),
68+
aggregateOp){};
69+
};
70+
71+
// The final operation for dividing by degrees of freedom and calculation square
72+
// root after summing up the squared deviation
73+
inline auto stdevFinalOperation = [](const NumericValue& aggregation,
74+
size_t numElements) {
75+
auto divAndRoot = [](double value, double degreesOfFreedom) {
76+
if (degreesOfFreedom <= 0) {
77+
return 0.0;
78+
} else {
79+
return std::sqrt(value / degreesOfFreedom);
80+
}
81+
};
82+
return makeNumericExpressionForAggregate<decltype(divAndRoot)>()(
83+
aggregation, NumericValue{static_cast<double>(numElements) - 1});
84+
};
85+
86+
// The actual Standard Deviation Expression
87+
// Mind the explicit instantiation of StdevExpressionBase in
88+
// AggregateExpression.cpp
89+
using StdevExpressionBase =
90+
DeviationAggExpression<AvgOperation, decltype(stdevFinalOperation)>;
91+
class StdevExpression : public StdevExpressionBase {
92+
using StdevExpressionBase::StdevExpressionBase;
93+
ValueId resultForEmptyGroup() const override { return Id::makeFromDouble(0); }
94+
};
95+
96+
} // namespace detail
97+
98+
using detail::StdevExpression;
99+
100+
} // namespace sparqlExpression

src/parser/sparqlParser/SparqlQleverVisitor.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "engine/sparqlExpressions/RegexExpression.h"
2424
#include "engine/sparqlExpressions/RelationalExpressions.h"
2525
#include "engine/sparqlExpressions/SampleExpression.h"
26+
#include "engine/sparqlExpressions/StdevExpression.h"
2627
#include "engine/sparqlExpressions/UuidExpressions.h"
2728
#include "parser/GraphPatternOperation.h"
2829
#include "parser/RdfParser.h"
@@ -2372,6 +2373,8 @@ ExpressionPtr Visitor::visit(Parser::AggregateContext* ctx) {
23722373
}
23732374

23742375
return makePtr.operator()<GroupConcatExpression>(std::move(separator));
2376+
} else if (functionName == "stdev") {
2377+
return makePtr.operator()<StdevExpression>();
23752378
} else {
23762379
AD_CORRECTNESS_CHECK(functionName == "sample");
23772380
return makePtr.operator()<SampleExpression>();

src/parser/sparqlParser/SparqlQleverVisitor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "engine/sparqlExpressions/AggregateExpression.h"
1212
#include "engine/sparqlExpressions/NaryExpression.h"
13+
#include "engine/sparqlExpressions/StdevExpression.h"
1314
#include "parser/data/GraphRef.h"
1415
#undef EOF
1516
#include "parser/sparqlParser/generated/SparqlAutomaticVisitor.h"

src/parser/sparqlParser/generated/SparqlAutomatic.g4

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,7 @@ aggregate : COUNT '(' DISTINCT? ( '*' | expression ) ')'
582582
| MIN '(' DISTINCT? expression ')'
583583
| MAX '(' DISTINCT? expression ')'
584584
| AVG '(' DISTINCT? expression ')'
585+
| STDEV '(' DISTINCT? expression ')'
585586
| SAMPLE '(' DISTINCT? expression ')'
586587
| GROUP_CONCAT '(' DISTINCT? expression ( ';' SEPARATOR '=' string )? ')' ;
587588

@@ -763,6 +764,7 @@ SUM : S U M;
763764
MIN : M I N;
764765
MAX : M A X;
765766
AVG : A V G;
767+
STDEV : S T D E V ;
766768
SAMPLE : S A M P L E;
767769
SEPARATOR : S E P A R A T O R;
768770

src/parser/sparqlParser/generated/SparqlAutomatic.interp

Lines changed: 3 additions & 1 deletion
Large diffs are not rendered by default.

src/parser/sparqlParser/generated/SparqlAutomatic.tokens

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -136,43 +136,44 @@ SUM=135
136136
MIN=136
137137
MAX=137
138138
AVG=138
139-
SAMPLE=139
140-
SEPARATOR=140
141-
IRI_REF=141
142-
PNAME_NS=142
143-
PNAME_LN=143
144-
BLANK_NODE_LABEL=144
145-
VAR1=145
146-
VAR2=146
147-
LANGTAG=147
148-
PREFIX_LANGTAG=148
149-
INTEGER=149
150-
DECIMAL=150
151-
DOUBLE=151
152-
INTEGER_POSITIVE=152
153-
DECIMAL_POSITIVE=153
154-
DOUBLE_POSITIVE=154
155-
INTEGER_NEGATIVE=155
156-
DECIMAL_NEGATIVE=156
157-
DOUBLE_NEGATIVE=157
158-
EXPONENT=158
159-
STRING_LITERAL1=159
160-
STRING_LITERAL2=160
161-
STRING_LITERAL_LONG1=161
162-
STRING_LITERAL_LONG2=162
163-
ECHAR=163
164-
NIL=164
165-
ANON=165
166-
PN_CHARS_U=166
167-
VARNAME=167
168-
PN_PREFIX=168
169-
PN_LOCAL=169
170-
PLX=170
171-
PERCENT=171
172-
HEX=172
173-
PN_LOCAL_ESC=173
174-
WS=174
175-
COMMENTS=175
139+
STDEV=139
140+
SAMPLE=140
141+
SEPARATOR=141
142+
IRI_REF=142
143+
PNAME_NS=143
144+
PNAME_LN=144
145+
BLANK_NODE_LABEL=145
146+
VAR1=146
147+
VAR2=147
148+
LANGTAG=148
149+
PREFIX_LANGTAG=149
150+
INTEGER=150
151+
DECIMAL=151
152+
DOUBLE=152
153+
INTEGER_POSITIVE=153
154+
DECIMAL_POSITIVE=154
155+
DOUBLE_POSITIVE=155
156+
INTEGER_NEGATIVE=156
157+
DECIMAL_NEGATIVE=157
158+
DOUBLE_NEGATIVE=158
159+
EXPONENT=159
160+
STRING_LITERAL1=160
161+
STRING_LITERAL2=161
162+
STRING_LITERAL_LONG1=162
163+
STRING_LITERAL_LONG2=163
164+
ECHAR=164
165+
NIL=165
166+
ANON=166
167+
PN_CHARS_U=167
168+
VARNAME=168
169+
PN_PREFIX=169
170+
PN_LOCAL=170
171+
PLX=171
172+
PERCENT=172
173+
HEX=173
174+
PN_LOCAL_ESC=174
175+
WS=175
176+
COMMENTS=176
176177
'*'=1
177178
'('=2
178179
')'=3

0 commit comments

Comments
 (0)