-
Notifications
You must be signed in to change notification settings - Fork 5.5k
feat(function): Add array_union_sum aggregation function #26842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Reviewer's GuideImplements a new array_union_sum SQL aggregation function for non-decimal numeric arrays, including execution engine support, state management, serialization, registration in the built-in function set, tests, and documentation/release notes updates. Sequence diagram for array_union_sum aggregation executionsequenceDiagram
actor User
participant Parser as SqlParser
participant Planner as QueryPlanner
participant Engine as ExecutionEngine
participant Acc as ArrayUnionSumAccumulator
participant State as ArrayUnionSumState
participant Result as ArrayUnionSumResult
User->>Parser: Submit SELECT array_union_sum(arr) ...
Parser->>Planner: ParsedQuery(plan with array_union_sum)
Planner->>Engine: PhysicalPlan(using ArrayUnionSumAggregation)
loop For each input page
Engine->>Acc: addInput(page)
loop For each row with array value
Acc->>+State: input(elementType, state, arrayBlock)
alt state.get() is null (first array)
State-->>State: create ArrayUnionSumResult.create(elementType, adder, arrayBlock)
State-->>Acc: updated state with SingleArrayBlock
else subsequent arrays
State->>Result: get()
Result->>Result: unionSum(arrayBlock)
Result-->>State: new AccumulatedValues
end
end
end
note over Engine,Acc: Partial aggregation results may be combined across threads/nodes
Engine->>Acc: combine(partialState1, partialState2)
Acc->>State: combine(state, otherState)
State->>Result: get()
Result->>Result: unionSum(otherState.get())
Result-->>State: merged AccumulatedValues
Engine->>Acc: evaluateFinal()
Acc->>State: output(state, blockBuilder)
alt state.get() is null
State-->>Engine: appendNull()
else non null
State->>Result: get()
Result->>Result: serialize(outBlockBuilder)
Result-->>Engine: final array block
end
Engine-->>User: QueryResult(with aggregated array_union_sum column)
Class diagram for array_union_sum aggregation componentsclassDiagram
class ArrayUnionSumAggregation {
+String NAME
+ArrayUnionSumAggregation()
+String getDescription()
+BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
-static BuiltInAggregationFunctionImplementation generateAggregation(Type elementType, ArrayType outputType)
-static List~ParameterMetadata~ createInputParameterMetadata(Type inputType)
+static void input(Type elementType, ArrayUnionSumState state, Block arrayBlock)
+static void combine(ArrayUnionSumState state, ArrayUnionSumState otherState)
+static void output(ArrayUnionSumState state, BlockBuilder out)
-static MethodHandle INPUT_FUNCTION
-static MethodHandle COMBINE_FUNCTION
-static MethodHandle OUTPUT_FUNCTION
}
class ArrayUnionSumResult {
-Type elementType
-Adder adder
+ArrayUnionSumResult(Type elementType, Adder adder)
+static ArrayUnionSumResult create(Type elementType, Adder adder, Block arrayBlock)
+Type getElementType()
+void serialize(BlockBuilder out)
+ArrayUnionSumResult unionSum(ArrayUnionSumResult other)
+ArrayUnionSumResult unionSum(Block arrayBlock)
+long getRetainedSizeInBytes()
+int size()
+void appendValue(int i, BlockBuilder blockBuilder)
+boolean isValueNull(int i)
+Block getValueBlock()
+int getValueBlockIndex(int i)
+static void appendValue(Type elementType, Block block, int position, BlockBuilder blockBuilder)
}
class ArrayUnionSumResult_SingleArrayBlock {
-Block arrayBlock
+SingleArrayBlock(Type elementType, Adder adder, Block arrayBlock)
+int size()
+void appendValue(int i, BlockBuilder blockBuilder)
+boolean isValueNull(int i)
+long getRetainedSizeInBytes()
+Block getValueBlock()
+int getValueBlockIndex(int i)
}
class ArrayUnionSumResult_AccumulatedValues {
-Block valueBlock
+AccumulatedValues(Type elementType, Adder adder, Block valueBlock)
+int size()
+void appendValue(int i, BlockBuilder blockBuilder)
+boolean isValueNull(int i)
+long getRetainedSizeInBytes()
+Block getValueBlock()
+int getValueBlockIndex(int i)
}
class ArrayUnionSumState {
<<interface>>
+ArrayUnionSumResult get()
+void set(ArrayUnionSumResult value)
+void addMemoryUsage(long memory)
+Type getElementType()
+Adder getAdder()
}
class ArrayUnionSumStateFactory {
+ArrayUnionSumStateFactory(Type elementType)
+ArrayUnionSumState createSingleState()
+Class~? extends ArrayUnionSumState~ getSingleStateClass()
+ArrayUnionSumState createGroupedState()
+Class~? extends ArrayUnionSumState~ getGroupedStateClass()
-Type elementType
-Adder adder
-static Adder LONG_ADDER
-static Adder DOUBLE_ADDER
-static Adder FLOAT_ADDER
-static Adder getAdder(Type type)
}
class ArrayUnionSumStateFactory_GroupedState {
-Type elementType
-Adder adder
-ObjectBigArray~ArrayUnionSumResult~ results
-long size
+GroupedState(Type elementType, Adder adder)
+void ensureCapacity(long size)
+ArrayUnionSumResult get()
+void set(ArrayUnionSumResult value)
+void addMemoryUsage(long memory)
+Type getElementType()
+long getEstimatedSize()
+Adder getAdder()
}
class ArrayUnionSumStateFactory_SingleState {
-Type elementType
-Adder adder
-ArrayUnionSumResult result
+SingleState(Type elementType, Adder adder)
+ArrayUnionSumResult get()
+void set(ArrayUnionSumResult value)
+void addMemoryUsage(long memory)
+Type getElementType()
+long getEstimatedSize()
+Adder getAdder()
}
class ArrayUnionSumStateSerializer {
-ArrayType arrayType
+ArrayUnionSumStateSerializer(ArrayType arrayType)
+Type getSerializedType()
+void serialize(ArrayUnionSumState state, BlockBuilder out)
+void deserialize(Block block, int index, ArrayUnionSumState state)
}
class Adder {
<<interface>>
+void writeSum(Type type, Block block1, int position1, Block block2, int position2, BlockBuilder blockBuilder)
}
class SqlAggregationFunction
class AccumulatorStateFactory
class AccumulatorStateSerializer
class AccumulatorState
class AbstractGroupedAccumulatorState
class ArrayType
class Type
class Block
class BlockBuilder
ArrayUnionSumAggregation --> ArrayUnionSumState : uses
ArrayUnionSumAggregation --> ArrayUnionSumStateFactory : creates
ArrayUnionSumAggregation --> ArrayUnionSumStateSerializer : uses
ArrayUnionSumAggregation --> ArrayUnionSumResult : uses
ArrayUnionSumAggregation --|> SqlAggregationFunction
ArrayUnionSumResult <|-- ArrayUnionSumResult_SingleArrayBlock
ArrayUnionSumResult <|-- ArrayUnionSumResult_AccumulatedValues
ArrayUnionSumResult --> Adder : uses
ArrayUnionSumResult --> Type : elementType
ArrayUnionSumResult --> Block : holds
ArrayUnionSumState <|.. ArrayUnionSumStateFactory_GroupedState
ArrayUnionSumState <|.. ArrayUnionSumStateFactory_SingleState
ArrayUnionSumStateFactory --|> AccumulatorStateFactory
ArrayUnionSumStateFactory --> ArrayUnionSumStateFactory_GroupedState : creates
ArrayUnionSumStateFactory --> ArrayUnionSumStateFactory_SingleState : creates
ArrayUnionSumStateFactory --> Adder : configures
ArrayUnionSumStateFactory_GroupedState --|> AbstractGroupedAccumulatorState
ArrayUnionSumStateFactory_GroupedState --> ArrayUnionSumResult : stores
ArrayUnionSumStateFactory_SingleState --> ArrayUnionSumResult : stores
ArrayUnionSumStateSerializer --|> AccumulatorStateSerializer
ArrayUnionSumStateSerializer --> ArrayUnionSumResult : reconstructs
ArrayUnionSumStateSerializer --> ArrayType : uses
ArrayUnionSumState --|> AccumulatorState
ArrayType --> Type
Block --> Type
BlockBuilder --> Type
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
Adds a new SQL aggregation function that combines arrays by summing values at corresponding indices. Features: - Result array length is the maximum of all input arrays - Missing elements treated as 0 - Null values coalesced to 0 - Supports BIGINT, INTEGER, SMALLINT, TINYINT, DOUBLE, and REAL Files: - ArrayUnionSumAggregation.java - Main aggregation logic - ArrayUnionSumResult.java - Result container with union-sum logic - ArrayUnionSumState*.java - State management for aggregation - TestArrayUnionSumResult.java - Unit tests (14 test cases) - Updated docs and integration tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey - I've found 4 issues, and left some high level feedback:
- In
ArrayUnionSumAggregation.combine, callingstate.get().unionSum(otherState.get())whenotherState.get()is null will throw; add an early return whenotherState.get() == nullto mirror typical combine semantics and avoid NPEs. - The
inputfunction currently treats every row as having a non-null array block; if the contract is that NULL arrays should be skipped (as described in the PR), add an explicit check to ignore null input values rather than creating anArrayUnionSumResultfor them.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `ArrayUnionSumAggregation.combine`, calling `state.get().unionSum(otherState.get())` when `otherState.get()` is null will throw; add an early return when `otherState.get() == null` to mirror typical combine semantics and avoid NPEs.
- The `input` function currently treats every row as having a non-null array block; if the contract is that NULL arrays should be skipped (as described in the PR), add an explicit check to ignore null input values rather than creating an `ArrayUnionSumResult` for them.
## Individual Comments
### Comment 1
<location> `presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/ArrayUnionSumAggregation.java:130-135` </location>
<code_context>
+ arrayUnionSumResult = ArrayUnionSumResult.create(elementType, state.getAdder(), arrayBlock);
+ state.set(arrayUnionSumResult);
+ }
+ else {
+ startSize = arrayUnionSumResult.getRetainedSizeInBytes();
+ state.set(state.get().unionSum(arrayBlock));
+ }
+
+ state.addMemoryUsage(arrayUnionSumResult.getRetainedSizeInBytes() - startSize);
+ }
+
</code_context>
<issue_to_address>
**issue (bug_risk):** Memory accounting in input() uses the old result size instead of the new one
In the `else` branch, `arrayUnionSumResult` still points to the pre-union result. After `state.set(state.get().unionSum(arrayBlock));`, the new size is only available via `state.get()`, so using `arrayUnionSumResult.getRetainedSizeInBytes()` keeps using the old size and the delta is always 0.
You could fix this by capturing the current state and its size, creating the new result, then updating memory based on the new result, e.g.:
```java
else {
ArrayUnionSumResult current = state.get();
startSize = current.getRetainedSizeInBytes();
ArrayUnionSumResult newResult = current.unionSum(arrayBlock);
state.set(newResult);
state.addMemoryUsage(newResult.getRetainedSizeInBytes() - startSize);
}
```
or by recomputing the size from `state.get()` after the `set`.
</issue_to_address>
### Comment 2
<location> `presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/ArrayUnionSumAggregation.java:140-142` </location>
<code_context>
+
+ public static void combine(ArrayUnionSumState state, ArrayUnionSumState otherState)
+ {
+ if (state.get() == null) {
+ state.set(otherState.get());
+ return;
+ }
+
</code_context>
<issue_to_address>
**issue (bug_risk):** combine() does not update memory usage when the target state is initially null
When `state.get() == null`, you assign `otherState.get()` and return without updating memory usage, so combining into an empty state under-reports the memory of the incoming result. You should add the retained size of `otherState.get()` in this branch (as you do in the non-null case) to keep memory accounting consistent.
</issue_to_address>
### Comment 3
<location> `presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java:6339` </location>
<code_context>
+ ImmutableList.of(),
+ parseTypeSignature("array<T>"),
+ ImmutableList.of(parseTypeSignature("array<T>")));
+ }
+
+ @Override
</code_context>
<issue_to_address>
**suggestion (testing):** Consider adding query tests for NULL arrays and empty arrays to cover documented behavior
`testArrayUnionSum` currently only covers arrays with elements and element-level NULLs. Please also add tests that:
- Mix `CAST(NULL AS array<bigint>)` with non-null arrays and assert that NULL-array rows are skipped, producing the same result as aggregating only non-null arrays.
- Mix `CAST(array[] AS array<bigint>)` with non-empty arrays and assert that empty arrays don’t affect the result, and that queries with only empty arrays succeed without errors.
This will exercise the documented semantics for NULL and empty arrays.
</issue_to_address>
### Comment 4
<location> `presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/TestArrayUnionSumResult.java:40` </location>
<code_context>
+ private static final ArrayType ARRAY_DOUBLE = new ArrayType(DOUBLE);
+ private static final ArrayType ARRAY_REAL = new ArrayType(REAL);
+
+ @Test
+ public void testBasicUnionSum()
+ {
</code_context>
<issue_to_address>
**suggestion (testing):** Add a unit test for handling empty input arrays in ArrayUnionSumResult
Current tests don’t cover the case where `positionCount = 0`. Please add a test that builds an empty `Block` (BIGINT, and optionally other types), constructs an `ArrayUnionSumResult` from it, asserts `size() == 0`, verifies serialization produces an array block with 0 elements, and (optionally) unions this empty result/block with a non-empty one to confirm the non-empty result is preserved. This guards against implementations that assume `size() > 0` and validates empty-result serialization behavior.
Suggested implementation:
```java
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
```
```java
public class TestArrayUnionSumResult
{
private static final ArrayType ARRAY_BIGINT = new ArrayType(BIGINT);
private static final ArrayType ARRAY_DOUBLE = new ArrayType(DOUBLE);
private static final ArrayType ARRAY_REAL = new ArrayType(REAL);
@Test
public void testEmptyInputArrays()
{
// Build an empty BIGINT array block
BlockBuilder arrayBlockBuilder = ARRAY_BIGINT.createBlockBuilder(null, 0);
Block emptyArrayBlock = arrayBlockBuilder.build();
// Create an ArrayUnionSumState and update it with the empty block
ArrayUnionSumState state = new ArrayUnionSumStateFactory(BIGINT).createSingleState();
state.setResult(new ArrayUnionSumResult(ARRAY_BIGINT, BIGINT, emptyArrayBlock));
ArrayUnionSumResult result = state.getResult();
// Verify empty result
assertNotNull(result);
assertEquals(result.size(), 0);
// Verify serialization produces an empty array block
Block serialized = result.serialize();
assertEquals(serialized.getPositionCount(), 0);
// Build a non-empty BIGINT array block containing [1, 2]
BlockBuilder nonEmptyArrayBlockBuilder = ARRAY_BIGINT.createBlockBuilder(null, 1);
BlockBuilder elementBlockBuilder = BIGINT.createBlockBuilder(null, 2);
BIGINT.writeLong(elementBlockBuilder, 1L);
BIGINT.writeLong(elementBlockBuilder, 2L);
nonEmptyArrayBlockBuilder.writeObject(elementBlockBuilder.build()).closeEntry();
Block nonEmptyArrayBlock = nonEmptyArrayBlockBuilder.build();
ArrayUnionSumState nonEmptyState = new ArrayUnionSumStateFactory(BIGINT).createSingleState();
nonEmptyState.setResult(new ArrayUnionSumResult(ARRAY_BIGINT, BIGINT, nonEmptyArrayBlock));
ArrayUnionSumResult nonEmptyResult = nonEmptyState.getResult();
// Union empty with non-empty and verify non-empty is preserved
ArrayUnionSumResult unionResult = result.union(nonEmptyResult);
assertEquals(unionResult.size(), nonEmptyResult.size());
Block unionSerialized = unionResult.serialize();
assertEquals(unionSerialized.getPositionCount(), nonEmptyArrayBlock.getPositionCount());
}
@Test
public void testBasicUnionSum()
{
ArrayUnionSumState state = new ArrayUnionSumStateFactory(BIGINT).createSingleState();
// Create array [1, 2, 3]
```
The above test assumes the following APIs exist and are accessible in this test:
1. `Block` and `BlockBuilder` are already imported (typically from `com.facebook.presto.spi.block`).
2. `ArrayUnionSumResult` has:
- A constructor `ArrayUnionSumResult(ArrayType arrayType, Type elementType, Block arrayBlock)`.
- Methods `int size()`, `Block serialize()`, and `ArrayUnionSumResult union(ArrayUnionSumResult other)`.
3. `ArrayUnionSumState` has `setResult(ArrayUnionSumResult)` and `getResult()` methods.
4. `ArrayUnionSumStateFactory` has a constructor `ArrayUnionSumStateFactory(Type elementType)` and a `createSingleState()` method.
If the actual APIs differ slightly, adjust the constructor and method calls in `testEmptyInputArrays()` to match the real `ArrayUnionSumResult` / state APIs while preserving the test semantics:
- Construct a state/result from an empty input array.
- Assert `size() == 0`.
- Assert the serialized block has `positionCount == 0`.
- Union the empty result with a non-empty one and assert the non-empty union behavior.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| else { | ||
| startSize = arrayUnionSumResult.getRetainedSizeInBytes(); | ||
| state.set(state.get().unionSum(arrayBlock)); | ||
| } | ||
|
|
||
| state.addMemoryUsage(arrayUnionSumResult.getRetainedSizeInBytes() - startSize); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (bug_risk): Memory accounting in input() uses the old result size instead of the new one
In the else branch, arrayUnionSumResult still points to the pre-union result. After state.set(state.get().unionSum(arrayBlock));, the new size is only available via state.get(), so using arrayUnionSumResult.getRetainedSizeInBytes() keeps using the old size and the delta is always 0.
You could fix this by capturing the current state and its size, creating the new result, then updating memory based on the new result, e.g.:
else {
ArrayUnionSumResult current = state.get();
startSize = current.getRetainedSizeInBytes();
ArrayUnionSumResult newResult = current.unionSum(arrayBlock);
state.set(newResult);
state.addMemoryUsage(newResult.getRetainedSizeInBytes() - startSize);
}or by recomputing the size from state.get() after the set.
| if (state.get() == null) { | ||
| state.set(otherState.get()); | ||
| return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (bug_risk): combine() does not update memory usage when the target state is initially null
When state.get() == null, you assign otherState.get() and return without updating memory usage, so combining into an empty state under-reports the memory of the incoming result. You should add the retained size of otherState.get() in this branch (as you do in the non-null case) to keep memory accounting consistent.
Description
Adds a new SQL aggregation function
array_union_sumthat combines multiple arrays by summing values at corresponding indices. This function is analogous tomap_union_sumbut operates on arrays instead of maps.Example usage:
Implementation details:
Motivation and Context
Users often need to aggregate arrays element-wise, such as:
Currently, this requires complex workarounds using unnest, zip, and manual reconstruction. The array_union_sum function provides a clean, efficient solution similar to the existing map_union_sum function.
Impact
Public API changes:
Behavior:
Supported types:
Performance:
Test Plan
Contributor checklist
Release Notes
== RELEASE NOTES ==
General
array_union_sumaggregation function that combines arrays by summing values at corresponding indices. Supports all non-decimal numeric types (BIGINT, INTEGER, DOUBLE, REAL). The result array length is the maximum of all input arrays, with missing elements treated as 0 and NULL values coalesced to 0.