Skip to content

Commit 9c34efb

Browse files
committed
Release hash aggregation memory on output
Incrementally releases memory from FlatGroupByHash when HashAggregationOperator starts producing output.
1 parent 8f72331 commit 9c34efb

File tree

9 files changed

+149
-6
lines changed

9 files changed

+149
-6
lines changed

core/trino-main/src/main/java/io/trino/operator/AppendOnlyVariableWidthData.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,27 @@ public byte[] getChunk(byte[] pointer, int pointerOffset)
138138
return chunks.get(chunkIndex);
139139
}
140140

141+
public void freeChunksBefore(byte[] pointer, int pointerOffset)
142+
{
143+
int chunkIndex = getChunkIndex(pointer, pointerOffset);
144+
if (chunks.isEmpty()) {
145+
verify(chunkIndex == 0);
146+
return;
147+
}
148+
checkIndex(chunkIndex, chunks.size());
149+
// Release any previous chunks until a null chunk is encountered, which means it and any previous
150+
// batches have already been released
151+
int releaseIndex = chunkIndex - 1;
152+
while (releaseIndex >= 0) {
153+
byte[] releaseChunk = chunks.set(releaseIndex, null);
154+
if (releaseChunk == null) {
155+
break;
156+
}
157+
chunksRetainedSizeInBytes -= releaseChunk.length;
158+
releaseIndex--;
159+
}
160+
}
161+
141162
// growth factor for each chunk doubles up to 512KB, then increases by 1.5x for each chunk after that
142163
private static long nextChunkSize(long previousChunkSize)
143164
{

core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ public int getGroupCount()
126126
return nextGroupId;
127127
}
128128

129+
// BigintGroupByHash implementation of GroupByHash.startReleasingOutput()
@Override
130+
public void startReleasingOutput()
131+
{
132+
// NOOP: this implementation performs no action when output starts being produced
133+
}
134+
129135
@Override
130136
public void appendValuesTo(int groupId, PageBuilder pageBuilder)
131137
{

core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,12 @@ public int getGroupCount()
129129
return flatHash.size();
130130
}
131131

132+
// FlatGroupByHash implementation of GroupByHash.startReleasingOutput()
@Override
133+
public void startReleasingOutput()
134+
{
135+
// forward the request to the backing flatHash, which does the actual work
flatHash.startReleasingOutput();
136+
}
137+
132138
@Override
133139
public void appendValuesTo(int groupId, PageBuilder pageBuilder)
134140
{

core/trino-main/src/main/java/io/trino/operator/FlatHash.java

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ public FlatHash(FlatHash other)
124124
this.mask = other.mask;
125125
this.nextGroupId = other.nextGroupId;
126126
this.maxFill = other.maxFill;
127-
this.control = Arrays.copyOf(other.control, other.control.length);
128-
this.groupIdsByHash = Arrays.copyOf(other.groupIdsByHash, other.groupIdsByHash.length);
127+
this.control = other.control == null ? null : Arrays.copyOf(other.control, other.control.length);
128+
this.groupIdsByHash = other.groupIdsByHash == null ? null : Arrays.copyOf(other.groupIdsByHash, other.groupIdsByHash.length);
129129
this.fixedSizeRecords = Arrays.stream(other.fixedSizeRecords)
130130
.map(fixedSizeRecords -> fixedSizeRecords == null ? null : Arrays.copyOf(fixedSizeRecords, fixedSizeRecords.length))
131131
.toArray(byte[][]::new);
@@ -153,11 +153,28 @@ public int getCapacity()
153153
return capacity;
154154
}
155155

156+
public void startReleasingOutput()
157+
{
158+
if (isReleasingOutput()) {
159+
throw new IllegalStateException("already releasing output");
160+
}
161+
control = null;
162+
groupIdsByHash = null;
163+
}
164+
165+
/**
 * Returns true when the control array is null, which is the state
 * startReleasingOutput() leaves this hash in.
 */
public boolean isReleasingOutput()
166+
{
167+
return control == null;
168+
}
169+
156170
public long hashPosition(int groupId)
157171
{
158172
if (groupId < 0) {
159173
throw new IllegalArgumentException("groupId is negative");
160174
}
175+
if (isReleasingOutput()) {
176+
throw new IllegalStateException("already releasing output");
177+
}
161178
byte[] fixedSizeRecords = getFixedSizeRecords(groupId);
162179
int fixedRecordOffset = getFixedRecordOffset(groupId);
163180
if (cacheHashValue) {
@@ -182,7 +199,8 @@ public void appendTo(int groupId, BlockBuilder[] blockBuilders)
182199
{
183200
checkArgument(groupId < nextGroupId, "groupId out of range");
184201

185-
byte[] fixedSizeRecords = getFixedSizeRecords(groupId);
202+
int recordGroupIndex = recordGroupIndexForGroupId(groupId);
203+
byte[] fixedSizeRecords = this.fixedSizeRecords[recordGroupIndex];
186204
int recordOffset = getFixedRecordOffset(groupId);
187205

188206
byte[] variableWidthChunk = null;
@@ -202,6 +220,18 @@ public void appendTo(int groupId, BlockBuilder[] blockBuilders)
202220
if (hasPrecomputedHash) {
203221
BIGINT.writeLong(blockBuilders[blockBuilders.length - 1], (long) LONG_HANDLE.get(fixedSizeRecords, recordOffset));
204222
}
223+
// Release memory from the previous fixed size records batch
224+
if (isReleasingOutput() && recordOffset == 0 && recordGroupIndex > 0) {
225+
byte[] releasedRecords = this.fixedSizeRecords[recordGroupIndex - 1];
226+
this.fixedSizeRecords[recordGroupIndex - 1] = null;
227+
if (releasedRecords == null) {
228+
throw new IllegalStateException("already released previous record batch");
229+
}
230+
fixedRecordGroupsRetainedSize -= sizeOf(releasedRecords);
231+
if (variableWidthData != null) {
232+
variableWidthData.freeChunksBefore(fixedSizeRecords, recordOffset + variableWidthOffset);
233+
}
234+
}
205235
}
206236

207237
public void computeHashes(Block[] blocks, long[] hashes, int offset, int length)
@@ -251,6 +281,9 @@ public int putIfAbsent(Block[] blocks, int position, long hash)
251281

252282
private int getIndex(Block[] blocks, int position, long hash)
253283
{
284+
if (isReleasingOutput()) {
285+
throw new IllegalStateException("already releasing output");
286+
}
254287
byte hashPrefix = (byte) (hash & 0x7F | 0x80);
255288
int bucket = bucket((int) (hash >> 7));
256289

@@ -351,6 +384,9 @@ private void setControl(int index, byte hashPrefix)
351384

352385
public boolean ensureAvailableCapacity(int batchSize)
353386
{
387+
if (isReleasingOutput()) {
388+
throw new IllegalStateException("already releasing output");
389+
}
354390
long requiredMaxFill = nextGroupId + batchSize;
355391
if (requiredMaxFill >= maxFill) {
356392
long minimumRequiredCapacity = (requiredMaxFill + 1) * 16 / 15;

core/trino-main/src/main/java/io/trino/operator/GroupByHash.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ static GroupByHash createGroupByHash(
109109

110110
void appendValuesTo(int groupId, PageBuilder pageBuilder);
111111

112+
/**
 * Signals that the caller is about to start producing output from this hash.
 * InMemoryHashAggregationBuilder.buildResult() invokes this for non-partial aggregations
 * before emitting groups. Implementations differ: some do nothing, some throw
 * UnsupportedOperationException.
 */
void startReleasingOutput();
113+
112114
Work<?> addPage(Page page);
113115

114116
/**

core/trino-main/src/main/java/io/trino/operator/NoChannelGroupByHash.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ public void appendValuesTo(int groupId, PageBuilder pageBuilder)
5050
throw new UnsupportedOperationException("NoChannelGroupByHash does not support appendValuesTo");
5151
}
5252

53+
// Not supported by NoChannelGroupByHash: always throws UnsupportedOperationException
@Override
54+
public void startReleasingOutput()
55+
{
56+
throw new UnsupportedOperationException("NoChannelGroupByHash does not support startReleasingOutput");
57+
}
58+
5359
@Override
5460
public Work<?> addPage(Page page)
5561
{

core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,18 @@ public WorkProcessor<Page> buildResult()
248248
for (GroupedAggregator groupedAggregator : groupedAggregators) {
249249
groupedAggregator.prepareFinal();
250250
}
251-
return buildResult(consecutiveGroupIds());
251+
// Only incrementally release memory for final aggregations, since partial aggregations have a fixed
252+
// memory limit and can be expected to fully flush and release their output quickly
253+
boolean releaseOutputMemory = !partial;
254+
if (releaseOutputMemory) {
255+
groupByHash.startReleasingOutput();
256+
}
257+
return buildResult(consecutiveGroupIds(), releaseOutputMemory);
252258
}
253259

254260
// Builds the result over hashSortedGroupIds(); after this change it passes false for
// releaseMemoryOnOutput, so memory usage is not updated after each output page on this path
public WorkProcessor<Page> buildHashSortedResult()
255261
{
256-
return buildResult(hashSortedGroupIds());
262+
return buildResult(hashSortedGroupIds(), false);
257263
}
258264

259265
public List<Type> buildSpillTypes()
@@ -271,7 +277,7 @@ public int getCapacity()
271277
return groupByHash.getCapacity();
272278
}
273279

274-
private WorkProcessor<Page> buildResult(IntIterator groupIds)
280+
private WorkProcessor<Page> buildResult(IntIterator groupIds, boolean releaseMemoryOnOutput)
275281
{
276282
PageBuilder pageBuilder = new PageBuilder(buildTypes());
277283
return WorkProcessor.create(() -> {
@@ -294,6 +300,11 @@ private WorkProcessor<Page> buildResult(IntIterator groupIds)
294300
}
295301
}
296302

303+
// Update memory usage after producing each page of output
304+
if (releaseMemoryOnOutput) {
305+
updateMemory();
306+
}
307+
297308
return ProcessState.ofResult(pageBuilder.build());
298309
});
299310
}

core/trino-main/src/test/java/io/trino/operator/CyclingGroupByHash.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ public int getGroupCount()
5454
return maxGroupId + 1;
5555
}
5656

57+
// Not implemented by this test GroupByHash: always throws UnsupportedOperationException
@Override
58+
public void startReleasingOutput()
59+
{
60+
throw new UnsupportedOperationException("Not yet supported");
61+
}
62+
5763
@Override
5864
public void appendValuesTo(int groupId, PageBuilder pageBuilder)
5965
{

core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import static io.trino.spi.type.VarcharType.VARCHAR;
4747
import static io.trino.type.TypeTestUtils.getHashBlock;
4848
import static org.assertj.core.api.Assertions.assertThat;
49+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4950

5051
public class TestGroupByHash
5152
{
@@ -333,6 +334,54 @@ public void testUpdateMemoryBigint()
333334
assertThat(rehashCount.get()).isEqualTo(2 * BIGINT_EXPECTED_REHASH);
334335
}
335336

337+
@Test
338+
public void testReleaseMemoryOnOutput()
339+
{
340+
Type type = VARCHAR;
341+
// values expands into multiple FlatGroupByHash fixed record groups
342+
Block valuesBlock = createStringSequenceBlock(0, 1_000_000);
343+
344+
GroupByHash groupByHash = createGroupByHash(ImmutableList.of(type), selectGroupByHashMode(false, false, ImmutableList.of(type)), 10_000, false, new FlatHashStrategyCompiler(new TypeOperators()), () -> true);
345+
assertThat(groupByHash.addPage(new Page(valuesBlock)).process()).isTrue();
346+
assertThat(groupByHash.getGroupCount()).isEqualTo(valuesBlock.getPositionCount());
347+
348+
long memoryUsageAfterInput = groupByHash.getEstimatedSize();
349+
groupByHash.startReleasingOutput();
350+
// memory usage should have decreased from dropping the hash table
351+
long memoryUsageAfterReleasingOutput = groupByHash.getEstimatedSize();
352+
// single immediate release of memory for the control and groupId by hash values
353+
assertThat(memoryUsageAfterReleasingOutput).isLessThan(memoryUsageAfterInput);
354+
355+
// no more inputs accepted after switching to releasing output
356+
assertThatThrownBy(() -> groupByHash.addPage(new Page(valuesBlock)).process())
357+
.isInstanceOf(IllegalStateException.class)
358+
.hasMessage("already releasing output");
359+
360+
PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(type));
361+
int groupId = 0;
362+
// FlatGroupByHash first 1024 records are within the first record group
363+
for (; groupId < 1024; groupId++) {
364+
groupByHash.appendValuesTo(groupId, pageBuilder);
365+
pageBuilder.declarePosition();
366+
}
367+
pageBuilder.build();
368+
// No memory released yet after completing the first group
369+
assertThat(groupByHash.getEstimatedSize()).isEqualTo(memoryUsageAfterReleasingOutput);
370+
371+
groupByHash.appendValuesTo(groupId++, pageBuilder);
372+
pageBuilder.declarePosition();
373+
// Memory released
374+
long memoryUsageAfterFirstRelease = groupByHash.getEstimatedSize();
375+
assertThat(memoryUsageAfterFirstRelease).isLessThan(memoryUsageAfterReleasingOutput);
376+
377+
for (; groupId < valuesBlock.getPositionCount(); groupId++) {
378+
groupByHash.appendValuesTo(groupId, pageBuilder);
379+
pageBuilder.declarePosition();
380+
}
381+
// More memory released
382+
assertThat(groupByHash.getEstimatedSize()).isLessThan(memoryUsageAfterFirstRelease);
383+
}
384+
336385
@Test
337386
public void testMemoryReservationYield()
338387
{

0 commit comments

Comments
 (0)