@@ -204,8 +204,12 @@ class NoisyAvgGaussianAggregate : public exec::Aggregate {
204
204
noiseScale = static_cast <double >(decodedNoiseScale_.valueAt <uint64_t >(i));
205
205
}
206
206
accumulator->checkAndSetNoiseScale (noiseScale);
207
- accumulator->updateCount (1 );
208
- accumulator->updateSum (decodedValue_.valueAt <double >(i));
207
+
208
+ // Update sum and count. check input value and dispatch to corresponding
209
+ // type.
210
+ auto inputType = args[0 ]->typeKind ();
211
+ VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH (
212
+ updateTemplate, inputType, accumulator, decodedValue_, i);
209
213
}
210
214
211
215
void updateAccumulatorFromIntermediateResult (
@@ -224,26 +228,84 @@ class NoisyAvgGaussianAggregate : public exec::Aggregate {
224
228
accumulator->checkAndSetNoiseScale (otherAccumulator.getNoiseScale ());
225
229
}
226
230
}
231
+
232
+ // Template helper function to update accumulator, can support all numeric
233
+ // data types. Only used in this class.
234
+ template <TypeKind TData>
235
+ void updateTemplate (
236
+ AccumulatorType* accumulator,
237
+ const DecodedVector& decodedValue,
238
+ vector_size_t i) {
239
+ using T = typename TypeTraits<TData>::NativeType;
240
+ // Handle decimal types separately.
241
+ if constexpr (std::is_same_v<T, int64_t > || std::is_same_v<T, int128_t >) {
242
+ const auto & type = decodedValue.base ()->type ();
243
+ if (type->isDecimal ()) {
244
+ auto value = decodedValue.valueAt <T>(i);
245
+ auto scale = type->isShortDecimal () ? type->asShortDecimal ().scale ()
246
+ : type->asLongDecimal ().scale ();
247
+ double doubleValue = static_cast <double >(value) / pow (10 , scale);
248
+
249
+ accumulator->updateSum (doubleValue);
250
+ accumulator->updateCount (1 );
251
+ return ;
252
+ }
253
+ }
254
+ // Handle other types.
255
+ if constexpr (
256
+ std::is_same_v<T, TypeTraits<TypeKind::TIMESTAMP>> ||
257
+ std::is_same_v<T, TypeTraits<TypeKind::VARBINARY>> ||
258
+ std::is_same_v<T, TypeTraits<TypeKind::VARCHAR>> ||
259
+ std::is_same_v<T, facebook::velox::StringView> ||
260
+ std::is_same_v<T, facebook::velox::Timestamp>) {
261
+ VELOX_FAIL (" NoisySumGaussianAggregate does not support this data type." );
262
+ } else {
263
+ accumulator->updateSum (static_cast <double >(decodedValue.valueAt <T>(i)));
264
+ accumulator->updateCount (1 );
265
+ }
266
+ }
227
267
};
228
268
} // namespace
229
269
230
270
void registerNoisyAvgGaussianAggregate (
231
271
const std::string& prefix,
232
272
bool withCompanionFunctions,
233
273
bool overwrite) {
234
- std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
235
- exec::AggregateFunctionSignatureBuilder ()
236
- .returnType (" double" )
237
- .intermediateType (" varbinary" )
238
- .argumentType (" double" ) // input type
239
- .argumentType (" double" ) // noise scale
240
- .build (),
241
- exec::AggregateFunctionSignatureBuilder ()
242
- .returnType (" double" )
243
- .intermediateType (" varbinary" )
244
- .argumentType (" double" ) // input type
245
- .argumentType (" bigint" ) // noise scale
246
- .build ()};
274
+ // Helper function to create a signature builder with return and
275
+ // intermediate types
276
+ auto createBuilder = []() {
277
+ return exec::AggregateFunctionSignatureBuilder ()
278
+ .returnType (" double" )
279
+ .intermediateType (" varbinary" );
280
+ };
281
+
282
+ // List of possible argument types.
283
+ const std::vector<std::string> simpleDataTypes = {
284
+ " tinyint" , " smallint" , " integer" , " bigint" , " real" , " double" };
285
+ const std::vector<std::string> noiseScaleTypes = {" double" , " bigint" };
286
+
287
+ std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
288
+
289
+ // Generate signatures for all type combinations.
290
+ for (const auto & noiseScaleType : noiseScaleTypes) {
291
+ // Handle simple types.
292
+ for (const auto & dataType : simpleDataTypes) {
293
+ signatures.push_back (createBuilder ()
294
+ .argumentType (dataType)
295
+ .argumentType (noiseScaleType)
296
+ .build ());
297
+ }
298
+
299
+ // Handle decimal types separately.
300
+ signatures.push_back (exec::AggregateFunctionSignatureBuilder ()
301
+ .integerVariable (" a_precision" )
302
+ .integerVariable (" a_scale" )
303
+ .returnType (" double" )
304
+ .intermediateType (" varbinary" )
305
+ .argumentType (" DECIMAL(a_precision, a_scale)" )
306
+ .argumentType (noiseScaleType)
307
+ .build ());
308
+ }
247
309
248
310
auto name = prefix + kNoisyAvgGaussian ;
249
311
exec::registerAggregateFunction (
0 commit comments