Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNI for HISTOGRAM and MERGE_HISTOGRAM aggregations #14154

Merged
merged 109 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
109 commits
Select commit Hold shift + click to select a range
e385fda
Add `COUNT_FREQUENCY` and `MERGE_FREQUENCY` aggregations
ttnghia Sep 6, 2023
e3df8d4
Change the new aggregations to `HISTOGRAM` and `MERGE_HISTOGRAM`
ttnghia Sep 6, 2023
7bc7f91
Update copyright year
ttnghia Sep 7, 2023
0fd2000
Implement interface for the new aggregations
ttnghia Sep 7, 2023
1b04436
Add new files
ttnghia Sep 7, 2023
1977d69
Add skeleton APIs
ttnghia Sep 11, 2023
6fa93fc
Extract hash_reduce_by_row
ttnghia Sep 11, 2023
d11dd7f
Adopt `hash_reduce_by_row` in `distinct_reduce`
ttnghia Sep 12, 2023
b632570
Merge branch 'branch-23.10' into percentile
ttnghia Sep 12, 2023
e58f3e3
Rename struct and simplify code
ttnghia Sep 12, 2023
3cf1948
Refactor `hash_reduce_by_row`
ttnghia Sep 12, 2023
8488646
Rewrite `hash_reduce_by_row.cuh`
ttnghia Sep 12, 2023
1994684
Rename and rewrite `distinct_reduce.hpp`
ttnghia Sep 12, 2023
5dcbac9
Rewrite `distinct.cu`
ttnghia Sep 12, 2023
6236fcc
Rewrite `distinct_reduce.cu`
ttnghia Sep 12, 2023
42a778f
Rewrite `hash_reduce_by_row.cuh`
ttnghia Sep 12, 2023
584ff8d
Minor changes
ttnghia Sep 12, 2023
4a3d60d
Fix style
ttnghia Sep 12, 2023
4e74119
Merge branch 'refactor_hash_reduce' into percentile
ttnghia Sep 12, 2023
34cb488
Fix comment
ttnghia Sep 12, 2023
c863b53
Merge branch 'refactor_hash_reduce' into percentile
ttnghia Sep 12, 2023
e73c07f
Move file
ttnghia Sep 12, 2023
40e8730
Merge `distinct_reduce.*` into `distinct.cu`
ttnghia Sep 12, 2023
95e4463
Move file
ttnghia Sep 12, 2023
723ae4c
Merge `distinct_reduce.*` into `distinct.cu`
ttnghia Sep 12, 2023
921243e
Merge branch 'refactor_hash_reduce' into percentile
ttnghia Sep 12, 2023
8fb7a9e
Revert "Merge `distinct_reduce.*` into `distinct.cu`"
ttnghia Sep 12, 2023
65427c8
Rename function
ttnghia Sep 12, 2023
bcc2db4
Merge branch 'refactor_hash_reduce' into percentile
ttnghia Sep 12, 2023
0c0c7ac
Fix output type
ttnghia Sep 12, 2023
c066276
Merge branch 'refactor_hash_reduce' into percentile
ttnghia Sep 12, 2023
01cc1c2
Move file
ttnghia Sep 13, 2023
f5a6a1a
Rename function
ttnghia Sep 13, 2023
dfbb720
Merge branch 'refactor_hash_reduce' into percentile
ttnghia Sep 13, 2023
924a2d6
Implement histogram reduction
ttnghia Sep 13, 2023
a1b516e
Support partial count
ttnghia Sep 13, 2023
e196ab4
Return list scalar of structs
ttnghia Sep 13, 2023
09f68af
Add factory functions for histogram and merge histogram
ttnghia Sep 13, 2023
f107d98
Fix aggregation dispatcher
ttnghia Sep 13, 2023
cc185d8
Fix bug
ttnghia Sep 13, 2023
547be01
Working test
ttnghia Sep 13, 2023
4d93b1e
Implement merge histogram
ttnghia Sep 13, 2023
6d8be79
Add test for merge histogram
ttnghia Sep 13, 2023
2d08539
Cleanup
ttnghia Sep 13, 2023
7999c7e
Cleanup
ttnghia Sep 13, 2023
2d47048
Add tests with nulls
ttnghia Sep 13, 2023
824dcad
Add sliced input tests
ttnghia Sep 13, 2023
3fb43f4
Fix sliced input
ttnghia Sep 13, 2023
ee229a0
Add binding for `HISTOGRAM` and `MERGE_HISTOGRAM` aggregations
ttnghia Sep 13, 2023
0ece05d
Merge branch 'branch-23.10' into percentile
ttnghia Sep 14, 2023
b71c7a8
Fix compiling issue
ttnghia Sep 14, 2023
1edeb4c
Remove header
ttnghia Sep 14, 2023
75c35c4
Change test types
ttnghia Sep 15, 2023
35f6374
Merge branch 'jni_histogram' into percentile
ttnghia Sep 18, 2023
c6c2c43
Rewrite tests
ttnghia Sep 18, 2023
b5dd22a
Misc
ttnghia Sep 18, 2023
17b8975
Cleanup
ttnghia Sep 18, 2023
c0b245f
Revert changes
ttnghia Sep 18, 2023
a8b3696
Add more assert statements
ttnghia Sep 18, 2023
a7fee30
Clean up tests
ttnghia Sep 18, 2023
829017a
Add docs
ttnghia Sep 18, 2023
e53042e
Rewrite docs
ttnghia Sep 18, 2023
49608ab
Add a helper file
ttnghia Sep 18, 2023
08aac0e
Rewrite histogram
ttnghia Sep 18, 2023
aaaf347
Add docs
ttnghia Sep 18, 2023
2f5b343
Remove file
ttnghia Sep 18, 2023
c11f939
Rewrite docs
ttnghia Sep 18, 2023
d10842e
Change docs
ttnghia Sep 18, 2023
6abc7b5
Add headers
ttnghia Sep 19, 2023
f833f58
Implement groupby histogram and merge histogram aggs
ttnghia Sep 19, 2023
ef308e8
Update header copyright
ttnghia Sep 19, 2023
70e624d
Rename function
ttnghia Sep 19, 2023
ee91b2e
Fix typos
ttnghia Sep 19, 2023
7c51faa
Add file
ttnghia Sep 19, 2023
270bcb8
Add docs
ttnghia Sep 19, 2023
6447877
Add empty tests
ttnghia Sep 19, 2023
c766e43
Implement histogram tests
ttnghia Sep 19, 2023
baddf18
Move tests
ttnghia Sep 20, 2023
0afad9c
Rename alias
ttnghia Sep 20, 2023
c05e595
Add target types
ttnghia Sep 20, 2023
8653053
Add empty return
ttnghia Sep 20, 2023
8d6fdfe
MISC
ttnghia Sep 20, 2023
d1fbda4
Add more assertions
ttnghia Sep 20, 2023
199d97b
Implement unit tests for groupby histogram
ttnghia Sep 20, 2023
4b0983e
Reimplement merge histogram
ttnghia Sep 20, 2023
0a8a03d
Implement unit tests for merge histogram
ttnghia Sep 20, 2023
201d432
Fix empty output for merge histogram
ttnghia Sep 20, 2023
edf6816
Fix empty input test
ttnghia Sep 20, 2023
8ac649e
Remove comment
ttnghia Sep 20, 2023
04965fa
Cleanup
ttnghia Sep 20, 2023
63ef1fa
Fix docs
ttnghia Sep 20, 2023
d31de20
Rewrite docs
ttnghia Sep 20, 2023
34a4268
Rewrite histogram.cu
ttnghia Sep 20, 2023
502a3da
Fix typo
ttnghia Sep 20, 2023
61377e0
Fix header
ttnghia Sep 20, 2023
dd72159
Revert changes
ttnghia Sep 20, 2023
56516e9
Merge branch 'branch-23.10' into percentile
ttnghia Sep 20, 2023
00c9c79
Merge remote-tracking branch 'nghia/percentile' into percentile
ttnghia Sep 21, 2023
424196b
Add empty input handling
ttnghia Sep 21, 2023
26238dd
Rename function and change return type
ttnghia Sep 21, 2023
5001cbd
Merge remote-tracking branch 'nghia/percentile' into percentile
ttnghia Sep 21, 2023
e701908
Merge branch 'branch-23.10' into jni_histogram
ttnghia Sep 21, 2023
76f77a0
Format
ttnghia Sep 21, 2023
31093cd
Add docs and reduction aggregations
ttnghia Sep 21, 2023
2ce59d1
Update copyright years
ttnghia Sep 21, 2023
39ce6d1
Merge remote-tracking branch 'nghia/percentile' into percentile
ttnghia Sep 21, 2023
ad09d30
Revert "Add binding for `HISTOGRAM` and `MERGE_HISTOGRAM` aggregations"
ttnghia Sep 21, 2023
69218fb
Add Java tests
ttnghia Sep 21, 2023
8f42bc7
Merge branch 'branch-23.10' into jni_histogram
ttnghia Sep 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions java/src/main/java/ai/rapids/cudf/Aggregation.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,7 +68,9 @@ enum Kind {
DENSE_RANK(29),
PERCENT_RANK(30),
TDIGEST(31), // This can take a delta argument for accuracy level
MERGE_TDIGEST(32); // This can take a delta argument for accuracy level
MERGE_TDIGEST(32), // This can take a delta argument for accuracy level
HISTOGRAM(33),
MERGE_HISTOGRAM(34);

final int nativeId;

Expand Down Expand Up @@ -918,6 +920,26 @@ static TDigestAggregation mergeTDigest(int delta) {
return new TDigestAggregation(Kind.MERGE_TDIGEST, delta);
}

static final class HistogramAggregation extends NoParamAggregation {
private HistogramAggregation() {
super(Kind.HISTOGRAM);
}
}

static final class MergeHistogramAggregation extends NoParamAggregation {
private MergeHistogramAggregation() {
super(Kind.MERGE_HISTOGRAM);
}
}

static HistogramAggregation histogram() {
return new HistogramAggregation();
}

static MergeHistogramAggregation mergeHistogram() {
return new MergeHistogramAggregation();
}

/**
* Create one of the aggregations that only needs a kind, no other parameters. This does not
* work for all types and for code safety reasons each kind is added separately.
Expand Down
24 changes: 23 additions & 1 deletion java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -315,4 +315,26 @@ public static GroupByAggregation createTDigest(int delta) {
public static GroupByAggregation mergeTDigest(int delta) {
return new GroupByAggregation(Aggregation.mergeTDigest(delta));
}

/**
* Histogram aggregation, computing the frequencies for each unique row.
*
* A histogram is given as a lists column, in which the first child stores unique rows from
* the input values and the second child stores their corresponding frequencies.
*
* @return A lists of structs column in which each list contains a histogram corresponding to
* an input key.
*/
public static GroupByAggregation histogram() {
return new GroupByAggregation(Aggregation.histogram());
}

/**
* MergeHistogram aggregation, to merge multiple histograms.
*
* @return A new histogram in which the frequencies of the unique rows are sum up.
*/
public static GroupByAggregation mergeHistogram() {
return new GroupByAggregation(Aggregation.mergeHistogram());
}
}
20 changes: 19 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ReductionAggregation.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -286,4 +286,22 @@ public static ReductionAggregation mergeSets(NullEquality nullEquality, NaNEqual
return new ReductionAggregation(Aggregation.mergeSets(nullEquality, nanEquality));
}

/**
* Create HistogramAggregation, computing the frequencies for each unique row.
*
* @return A structs column in which the first child stores unique rows from the input and the
* second child stores their corresponding frequencies.
*/
public static ReductionAggregation histogram() {
return new ReductionAggregation(Aggregation.histogram());
}

/**
* Create MergeHistogramAggregation, to merge multiple histograms.
*
* @return A new histogram in which the frequencies of the unique rows are sum up.
*/
public static ReductionAggregation mergeHistogram() {
return new ReductionAggregation(Aggregation.mergeHistogram());
}
}
7 changes: 6 additions & 1 deletion java/src/main/native/src/AggregationJni.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -90,6 +90,11 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv
case 30: // ANSI SQL PERCENT_RANK
return cudf::make_rank_aggregation(cudf::rank_method::MIN, {}, cudf::null_policy::INCLUDE,
{}, cudf::rank_percentage::ONE_NORMALIZED);
case 33: // HISTOGRAM
return cudf::make_histogram_aggregation();
case 34: // MERGE_HISTOGRAM
return cudf::make_merge_histogram_aggregation();

default: throw std::logic_error("Unsupported No Parameter Aggregation Operation");
}
}();
Expand Down
109 changes: 109 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4129,6 +4129,115 @@ void testMergeTDigestReduction() {
}
}

@Test
void testGroupbyHistogram() {
StructType histogramStruct = new StructType(false,
new BasicType(false, DType.INT32), // values
new BasicType(false, DType.INT64)); // frequencies
ListType histogramList = new ListType(false, histogramStruct);

// key = 0: values = [2, 2, -3, -2, 2]
// key = 1: values = [2, 0, 5, 2, 1]
// key = 2: values = [-3, 1, 1, 2, 2]
try (Table input = new Table.TestBuilder()
.column(2, 0, 2, 1, 1, 1, 0, 0, 0, 1, 2, 2, 1, 0, 2)
.column(-3, 2, 1, 2, 0, 5, 2, -3, -2, 2, 1, 2, 1, 2, 2)
.build();
Table result = input.groupBy(0)
.aggregate(GroupByAggregation.histogram().onColumn(1));
Table sortedResult = result.orderBy(OrderByArg.asc(0));
ColumnVector sortedOutHistograms = sortedResult.getColumn(1).listSortRows(false, false);

ColumnVector expectedKeys = ColumnVector.fromInts(0, 1, 2);
ColumnVector expectedHistograms = ColumnVector.fromLists(histogramList,
Arrays.asList(new StructData(-3, 1L), new StructData(-2, 1L), new StructData(2, 3L)),
Arrays.asList(new StructData(0, 1L), new StructData(1, 1L), new StructData(2, 2L),
new StructData(5, 1L)),
Arrays.asList(new StructData(-3, 1L), new StructData(1, 2L), new StructData(2, 2L)))
) {
assertColumnsAreEqual(expectedKeys, sortedResult.getColumn(0));
assertColumnsAreEqual(expectedHistograms, sortedOutHistograms);
}
}

@Test
void testGroupbyMergeHistogram() {
StructType histogramStruct = new StructType(false,
new BasicType(false, DType.INT32), // values
new BasicType(false, DType.INT64)); // frequencies
ListType histogramList = new ListType(false, histogramStruct);

// key = 0: histograms = [[<-3, 1>, <-2, 1>, <2, 3>], [<0, 1>, <1, 1>], [<-3, 3>, <0, 1>, <1, 2>]]
// key = 1: histograms = [[<-2, 1>, <1, 3>, <2, 2>], [<0, 2>, <1, 1>, <2, 2>]]
try (Table input = new Table.TestBuilder()
.column(0, 1, 0, 1, 0)
.column(histogramStruct,
new StructData[]{new StructData(-3, 1L), new StructData(-2, 1L), new StructData(2, 3L)},
new StructData[]{new StructData(-2, 1L), new StructData(1, 3L), new StructData(2, 2L)},
new StructData[]{new StructData(0, 1L), new StructData(1, 1L)},
new StructData[]{new StructData(0, 2L), new StructData(1, 1L), new StructData(2, 2L)},
new StructData[]{new StructData(-3, 3L), new StructData(0, 1L), new StructData(1, 2L)})
.build();
Table result = input.groupBy(0)
.aggregate(GroupByAggregation.mergeHistogram().onColumn(1));
Table sortedResult = result.orderBy(OrderByArg.asc(0));
ColumnVector sortedOutHistograms = sortedResult.getColumn(1).listSortRows(false, false);

ColumnVector expectedKeys = ColumnVector.fromInts(0, 1);
ColumnVector expectedHistograms = ColumnVector.fromLists(histogramList,
Arrays.asList(new StructData(-3, 4L), new StructData(-2, 1L), new StructData(0, 2L),
new StructData(1, 3L), new StructData(2, 3L)),
Arrays.asList(new StructData(-2, 1L), new StructData(0, 2L), new StructData(1, 4L),
new StructData(2, 4L)))
) {
assertColumnsAreEqual(expectedKeys, sortedResult.getColumn(0));
assertColumnsAreEqual(expectedHistograms, sortedOutHistograms);
}
}

@Test
void testReductionHistogram() {
StructType histogramStruct = new StructType(false,
new BasicType(false, DType.INT32), // values
new BasicType(false, DType.INT64)); // frequencies

try (ColumnVector input = ColumnVector.fromInts(-3, 2, 1, 2, 0, 5, 2, -3, -2, 2, 1);
Scalar result = input.reduce(ReductionAggregation.histogram(), DType.LIST);
ColumnVector resultCV = result.getListAsColumnView().copyToColumnVector();
Table resultTable = new Table(resultCV);
Table sortedResult = resultTable.orderBy(OrderByArg.asc(0));

ColumnVector expectedHistograms = ColumnVector.fromStructs(histogramStruct,
new StructData(-3, 2L), new StructData(-2, 1L), new StructData(0, 1L),
new StructData(1, 2L), new StructData(2, 4L), new StructData(5, 1L))
) {
assertColumnsAreEqual(expectedHistograms, sortedResult.getColumn(0));
}
}

@Test
void testReductionMergeHistogram() {
StructType histogramStruct = new StructType(false,
new BasicType(false, DType.INT32), // values
new BasicType(false, DType.INT64)); // frequencies

try (ColumnVector input = ColumnVector.fromStructs(histogramStruct,
new StructData(-3, 2L), new StructData(2, 1L), new StructData(1, 1L),
new StructData(2, 2L), new StructData(0, 4L), new StructData(5, 1L),
new StructData(2, 2L), new StructData(-3, 3L), new StructData(-2, 5L),
new StructData(2, 3L), new StructData(1, 4L));
Scalar result = input.reduce(ReductionAggregation.mergeHistogram(), DType.LIST);
ColumnVector resultCV = result.getListAsColumnView().copyToColumnVector();
Table resultTable = new Table(resultCV);
Table sortedResult = resultTable.orderBy(OrderByArg.asc(0));

ColumnVector expectedHistograms = ColumnVector.fromStructs(histogramStruct,
new StructData(-3, 5L), new StructData(-2, 5L), new StructData(0, 4L),
new StructData(1, 5L), new StructData(2, 8L), new StructData(5, 1L))
) {
assertColumnsAreEqual(expectedHistograms, sortedResult.getColumn(0));
}
}
@Test
void testGroupByMinMaxDecimal() {
try (Table t1 = new Table.TestBuilder()
Expand Down