Skip to content

Commit 51c414d

Browse files
Oliver Xu, facebook-github-bot
Oliver Xu
authored and committed
feat(noisy_avg): Add support for all numeric input types (facebookincubator#13709)
Summary: Pull Request resolved: facebookincubator#13709 ### Summary This diff adds support for all numeric input types to the `noisy_avg` aggregation function. ### Code Changes The diff modifies two files: 1. `NoisyAvgGaussianAggregationTest.cpp`: Adds new test cases for `bigint`, `decimal`, and `real` input types. 2. `NoisyAvgGaussianAggregate.cpp`: Updates the `update` method to handle different input types using `VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH`. ### Impact This diff allows the `noisy_avg` aggregation function to work with a wider range of input types, making it more versatile and useful for various use cases. Differential Revision: D76209005
1 parent a3f2668 commit 51c414d

File tree

2 files changed

+111
-15
lines changed

2 files changed

+111
-15
lines changed

velox/functions/prestosql/aggregates/NoisyAvgGaussianAggregate.cpp

Lines changed: 77 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,12 @@ class NoisyAvgGaussianAggregate : public exec::Aggregate {
204204
noiseScale = static_cast<double>(decodedNoiseScale_.valueAt<uint64_t>(i));
205205
}
206206
accumulator->checkAndSetNoiseScale(noiseScale);
207-
accumulator->updateCount(1);
208-
accumulator->updateSum(decodedValue_.valueAt<double>(i));
207+
208+
// Update sum and count. check input value and dispatch to corresponding
209+
// type.
210+
auto inputType = args[0]->typeKind();
211+
VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
212+
updateTemplate, inputType, accumulator, decodedValue_, i);
209213
}
210214

211215
void updateAccumulatorFromIntermediateResult(
@@ -224,26 +228,84 @@ class NoisyAvgGaussianAggregate : public exec::Aggregate {
224228
accumulator->checkAndSetNoiseScale(otherAccumulator.getNoiseScale());
225229
}
226230
}
231+
232+
// Template helper function to update accumulator, can support all numeric
233+
// data types. Only used in this class.
234+
template <TypeKind TData>
235+
void updateTemplate(
236+
AccumulatorType* accumulator,
237+
const DecodedVector& decodedValue,
238+
vector_size_t i) {
239+
using T = typename TypeTraits<TData>::NativeType;
240+
// Handle decimal types separately.
241+
if constexpr (std::is_same_v<T, int64_t> || std::is_same_v<T, int128_t>) {
242+
const auto& type = decodedValue.base()->type();
243+
if (type->isDecimal()) {
244+
auto value = decodedValue.valueAt<T>(i);
245+
auto scale = type->isShortDecimal() ? type->asShortDecimal().scale()
246+
: type->asLongDecimal().scale();
247+
double doubleValue = static_cast<double>(value) / pow(10, scale);
248+
249+
accumulator->updateSum(doubleValue);
250+
accumulator->updateCount(1);
251+
return;
252+
}
253+
}
254+
// Handle other types.
255+
if constexpr (
256+
std::is_same_v<T, TypeTraits<TypeKind::TIMESTAMP>> ||
257+
std::is_same_v<T, TypeTraits<TypeKind::VARBINARY>> ||
258+
std::is_same_v<T, TypeTraits<TypeKind::VARCHAR>> ||
259+
std::is_same_v<T, facebook::velox::StringView> ||
260+
std::is_same_v<T, facebook::velox::Timestamp>) {
261+
VELOX_FAIL("NoisySumGaussianAggregate does not support this data type.");
262+
} else {
263+
accumulator->updateSum(static_cast<double>(decodedValue.valueAt<T>(i)));
264+
accumulator->updateCount(1);
265+
}
266+
}
227267
};
228268
} // namespace
229269

230270
void registerNoisyAvgGaussianAggregate(
231271
const std::string& prefix,
232272
bool withCompanionFunctions,
233273
bool overwrite) {
234-
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
235-
exec::AggregateFunctionSignatureBuilder()
236-
.returnType("double")
237-
.intermediateType("varbinary")
238-
.argumentType("double") // input type
239-
.argumentType("double") // noise scale
240-
.build(),
241-
exec::AggregateFunctionSignatureBuilder()
242-
.returnType("double")
243-
.intermediateType("varbinary")
244-
.argumentType("double") // input type
245-
.argumentType("bigint") // noise scale
246-
.build()};
274+
// Helper function to create a signature builder with return and
275+
// intermediate types
276+
auto createBuilder = []() {
277+
return exec::AggregateFunctionSignatureBuilder()
278+
.returnType("double")
279+
.intermediateType("varbinary");
280+
};
281+
282+
// List of possible argument types.
283+
const std::vector<std::string> simpleDataTypes = {
284+
"tinyint", "smallint", "integer", "bigint", "real", "double"};
285+
const std::vector<std::string> noiseScaleTypes = {"double", "bigint"};
286+
287+
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
288+
289+
// Generate signatures for all type combinations.
290+
for (const auto& noiseScaleType : noiseScaleTypes) {
291+
// Handle simple types.
292+
for (const auto& dataType : simpleDataTypes) {
293+
signatures.push_back(createBuilder()
294+
.argumentType(dataType)
295+
.argumentType(noiseScaleType)
296+
.build());
297+
}
298+
299+
// Handle decimal types separately.
300+
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
301+
.integerVariable("a_precision")
302+
.integerVariable("a_scale")
303+
.returnType("double")
304+
.intermediateType("varbinary")
305+
.argumentType("DECIMAL(a_precision, a_scale)")
306+
.argumentType(noiseScaleType)
307+
.build());
308+
}
247309

248310
auto name = prefix + kNoisyAvgGaussian;
249311
exec::registerAggregateFunction(

velox/functions/prestosql/aggregates/tests/NoisyAvgGaussianAggregationTest.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ class NoisyAvgGaussianAggregationTest
1515

1616
RowTypePtr doubleRowType_{
1717
ROW({"c0", "c1", "c2"}, {DOUBLE(), DOUBLE(), DOUBLE()})};
18+
RowTypePtr bigintRowType_{
19+
ROW({"c0", "c1", "c2"}, {BIGINT(), BIGINT(), BIGINT()})};
20+
RowTypePtr decimalRowType_{
21+
ROW({"c0", "c1", "c2"},
22+
{DECIMAL(20, 5), DECIMAL(20, 5), DECIMAL(20, 5)})};
23+
RowTypePtr realRowType_{ROW({"c0", "c1", "c2"}, {REAL(), REAL(), REAL()})};
24+
RowTypePtr integerRowType_{
25+
ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), INTEGER()})};
26+
RowTypePtr smallintRowType_{
27+
ROW({"c0", "c1", "c2"}, {SMALLINT(), SMALLINT(), SMALLINT()})};
28+
RowTypePtr tinyintRowType_{
29+
ROW({"c0", "c1", "c2"}, {TINYINT(), TINYINT(), TINYINT()})};
1830
};
1931

2032
TEST_F(NoisyAvgGaussianAggregationTest, basicNoNoise) {
@@ -111,4 +123,26 @@ TEST_F(NoisyAvgGaussianAggregationTest, aggregateNullsNoNoise) {
111123
vectors, {"c0"}, {"noisy_avg_gaussian(c1, 0.0)"}, {expectedResult});
112124
}
113125

126+
// Runs noisy_avg_gaussian with a noise scale of 0.0 over each numeric row
// type declared on the fixture and compares the result against the DuckDB
// query "SELECT AVG(c2) FROM tmp" on the same data.
TEST_F(NoisyAvgGaussianAggregationTest, numericInputTypeTestNoNoise) {
  const std::vector<RowTypePtr> inputRowTypes{
      doubleRowType_,
      bigintRowType_,
      decimalRowType_,
      realRowType_,
      integerRowType_,
      smallintRowType_,
      tinyintRowType_};

  for (const auto& inputRowType : inputRowTypes) {
    // The DuckDB table is rebuilt per type so the reference query always
    // sees the same rows as the aggregation under test.
    auto data = makeVectors(inputRowType, 3, 3);
    createDuckDbTable(data);

    testAggregations(
        data, {}, {"noisy_avg_gaussian(c2, 0.0)"}, "SELECT AVG(c2) FROM tmp");
  }
}
147+
114148
} // namespace facebook::velox::aggregate::test

0 commit comments

Comments
 (0)