Skip to content

Commit 98c3f03

Browse files
github-actions[bot]amory
andauthored
branch-3.0: [improve](function) support collect_list with nested types param #47965 (#48114)
Cherry-picked from #47965 Co-authored-by: amory <[email protected]>
1 parent c7010e1 commit 98c3f03

File tree

4 files changed

+873
-1
lines changed

4 files changed

+873
-1
lines changed

be/src/vec/aggregate_functions/aggregate_function_collect.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const DataTyp
4848
AggregateFunctionCollectListData<T, HasLimit>, HasLimit, std::false_type>>(
4949
argument_types, result_is_nullable);
5050
}
51+
} else if (!distinct) {
52+
// void type means support array/map/struct type for collect_list
53+
return creator_without_type::create<AggregateFunctionCollect<
54+
AggregateFunctionCollectListData<void, HasLimit>, HasLimit, std::false_type>>(
55+
argument_types, result_is_nullable);
5156
}
5257
return nullptr;
5358
}
@@ -92,6 +97,9 @@ AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& n
9297
if constexpr (ShowNull::value) {
9398
return do_create_agg_function_collect<void, HasLimit, ShowNull>(
9499
distinct, argument_types, result_is_nullable);
100+
} else {
101+
return do_create_agg_function_collect<void, HasLimit, ShowNull>(
102+
distinct, argument_types, result_is_nullable);
95103
}
96104
}
97105

be/src/vec/aggregate_functions/aggregate_function_collect.h

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ struct AggregateFunctionCollectListData {
193193
PaddedPODArray<ElementType> data;
194194
Int64 max_size = -1;
195195

196+
AggregateFunctionCollectListData() {}
197+
AggregateFunctionCollectListData(const DataTypes& argument_types) {}
198+
196199
size_t size() const { return data.size(); }
197200

198201
void add(const IColumn& column, size_t row_num) {
@@ -305,6 +308,67 @@ struct AggregateFunctionCollectListData<StringRef, HasLimit> {
305308
}
306309
};
307310

311+
template <typename HasLimit>
312+
struct AggregateFunctionCollectListData<void, HasLimit> {
313+
using ElementType = StringRef;
314+
using Self = AggregateFunctionCollectListData<void, HasLimit>;
315+
MutableColumnPtr column_data;
316+
Int64 max_size = -1;
317+
318+
AggregateFunctionCollectListData() {}
319+
AggregateFunctionCollectListData(const DataTypes& argument_types) {
320+
DataTypePtr column_type = argument_types[0];
321+
column_data = column_type->create_column();
322+
}
323+
324+
size_t size() const { return column_data->size(); }
325+
326+
void add(const IColumn& column, size_t row_num) { column_data->insert_from(column, row_num); }
327+
328+
void merge(const AggregateFunctionCollectListData& rhs) {
329+
if constexpr (HasLimit::value) {
330+
if (max_size == -1) {
331+
max_size = rhs.max_size;
332+
}
333+
max_size = rhs.max_size;
334+
335+
column_data->insert_range_from(
336+
*rhs.column_data, 0,
337+
std::min(assert_cast<size_t, TypeCheckOnRelease::DISABLE>(
338+
static_cast<size_t>(max_size - size())),
339+
rhs.size()));
340+
} else {
341+
column_data->insert_range_from(*rhs.column_data, 0, rhs.size());
342+
}
343+
}
344+
345+
void write(BufferWritable& buf) const {
346+
const size_t size = column_data->size();
347+
write_binary(size, buf);
348+
for (size_t i = 0; i < size; i++) {
349+
write_string_binary(column_data->get_data_at(i), buf);
350+
}
351+
write_var_int(max_size, buf);
352+
}
353+
354+
void read(BufferReadable& buf) {
355+
size_t size = 0;
356+
read_binary(size, buf);
357+
column_data->reserve(size);
358+
359+
StringRef s;
360+
for (size_t i = 0; i < size; i++) {
361+
read_string_binary(s, buf);
362+
column_data->insert_data(s.data, s.size);
363+
}
364+
read_var_int(max_size, buf);
365+
}
366+
367+
void reset() { column_data->clear(); }
368+
369+
void insert_result_into(IColumn& to) const { to.insert_range_from(*column_data, 0, size()); }
370+
};
371+
308372
template <typename T>
309373
struct AggregateFunctionArrayAggData {
310374
using ElementType = T;
@@ -622,7 +686,11 @@ class AggregateFunctionCollect
622686
new (place) Data();
623687
}
624688
} else {
625-
new (place) Data();
689+
if constexpr (std::is_same_v<Data, AggregateFunctionCollectListData<void, HasLimit>>) {
690+
new (place) Data(argument_types);
691+
} else {
692+
new (place) Data();
693+
}
626694
}
627695
}
628696

0 commit comments

Comments
 (0)