Release hash aggregation memory on output #25879

Open · wants to merge 3 commits into master
@@ -138,6 +138,27 @@ public byte[] getChunk(byte[] pointer, int pointerOffset)
return chunks.get(chunkIndex);
}

public void freeChunksBefore(byte[] pointer, int pointerOffset)
{
int chunkIndex = getChunkIndex(pointer, pointerOffset);
if (chunks.isEmpty()) {
verify(chunkIndex == 0);
return;
}
checkIndex(chunkIndex, chunks.size());
// Release all previous chunks until a null chunk is encountered, which indicates that it and any
// earlier chunks have already been released
int releaseIndex = chunkIndex - 1;
while (releaseIndex >= 0) {
byte[] releaseChunk = chunks.set(releaseIndex, null);
if (releaseChunk == null) {
break;
}
chunksRetainedSizeInBytes -= releaseChunk.length;
releaseIndex--;
}
}

// growth factor for each chunk doubles up to 512KB, then increases by 1.5x for each chunk after that
private static long nextChunkSize(long previousChunkSize)
{
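
The release loop above walks backward from the chunk just before the one still in use, nulling out entries until it reaches a chunk that was already freed. A minimal, self-contained sketch of that bookkeeping (a hypothetical ChunkReleaseSketch class, not part of this PR) behaves like this:

    import java.util.ArrayList;
    import java.util.List;

    // Sketch only: mirrors the freeChunksBefore loop above with plain Java collections
    class ChunkReleaseSketch
    {
        private final List<byte[]> chunks = new ArrayList<>();
        private long chunksRetainedSizeInBytes;

        void addChunk(int size)
        {
            chunks.add(new byte[size]);
            chunksRetainedSizeInBytes += size;
        }

        void freeChunksBefore(int chunkIndex)
        {
            // Walk backward, releasing chunks until a null entry marks a
            // prefix that a previous call already released
            for (int releaseIndex = chunkIndex - 1; releaseIndex >= 0; releaseIndex--) {
                byte[] releaseChunk = chunks.set(releaseIndex, null);
                if (releaseChunk == null) {
                    break;
                }
                chunksRetainedSizeInBytes -= releaseChunk.length;
            }
        }
    }

Because each chunk is nulled at most once, a sequence of calls with increasing chunk indexes does O(total chunks) work overall, not per call.
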
@@ -126,6 +126,13 @@ public int getGroupCount()
return nextGroupId;
}

@Override
public void startReleasingOutput()
{
dictionaryLookBack = null;
currentPageSizeInBytes = 0;
}

@Override
public void appendValuesTo(int groupId, PageBuilder pageBuilder)
{
@@ -129,6 +129,16 @@ public int getGroupCount()
return flatHash.size();
}

@Override
public void startReleasingOutput()
{
currentHashes = null;
dictionaryLookBack = null;
Arrays.fill(currentBlocks, null);
currentPageSizeInBytes = 0;
flatHash.startReleasingOutput();
}

@Override
public void appendValuesTo(int groupId, PageBuilder pageBuilder)
{
@@ -393,11 +403,10 @@ public AddDictionaryPageWork(Block[] blocks)
{
verify(canProcessDictionary(blocks), "invalid call to addDictionaryPage");
this.dictionaryBlock = (DictionaryBlock) blocks[0];

- this.dictionaries = Arrays.stream(blocks)
- .map(block -> (DictionaryBlock) block)
- .map(DictionaryBlock::getDictionary)
- .toArray(Block[]::new);
+ this.dictionaries = blocks;
+ for (int i = 0; i < dictionaries.length; i++) {
+ dictionaries[i] = ((DictionaryBlock) dictionaries[i]).getDictionary();
+ }
updateDictionaryLookBack(dictionaries[0]);
}

@@ -510,7 +519,7 @@ class GetNonDictionaryGroupIdsWork
public GetNonDictionaryGroupIdsWork(Block[] blocks)
{
this.blocks = blocks;
- this.groupIds = new int[currentBlocks[0].getPositionCount()];
+ this.groupIds = new int[blocks[0].getPositionCount()];
}

@Override
@@ -620,13 +629,12 @@ public GetDictionaryGroupIdsWork(Block[] blocks)
verify(canProcessDictionary(blocks), "invalid call to processDictionary");

this.dictionaryBlock = (DictionaryBlock) blocks[0];
- this.groupIds = new int[dictionaryBlock.getPositionCount()];
- this.dictionaries = Arrays.stream(blocks)
- .map(block -> (DictionaryBlock) block)
- .map(DictionaryBlock::getDictionary)
- .toArray(Block[]::new);
+ this.dictionaries = blocks;
+ for (int i = 0; i < dictionaries.length; i++) {
+ dictionaries[i] = ((DictionaryBlock) dictionaries[i]).getDictionary();
+ }
updateDictionaryLookBack(dictionaries[0]);
+ this.groupIds = new int[dictionaryBlock.getPositionCount()];
}

@Override
52 changes: 47 additions & 5 deletions core/trino-main/src/main/java/io/trino/operator/FlatHash.java
@@ -124,8 +124,8 @@ public FlatHash(FlatHash other)
this.mask = other.mask;
this.nextGroupId = other.nextGroupId;
this.maxFill = other.maxFill;
- this.control = Arrays.copyOf(other.control, other.control.length);
- this.groupIdsByHash = Arrays.copyOf(other.groupIdsByHash, other.groupIdsByHash.length);
+ this.control = other.control == null ? null : Arrays.copyOf(other.control, other.control.length);
+ this.groupIdsByHash = other.groupIdsByHash == null ? null : Arrays.copyOf(other.groupIdsByHash, other.groupIdsByHash.length);
this.fixedSizeRecords = Arrays.stream(other.fixedSizeRecords)
.map(fixedSizeRecords -> fixedSizeRecords == null ? null : Arrays.copyOf(fixedSizeRecords, fixedSizeRecords.length))
.toArray(byte[][]::new);
@@ -153,13 +153,36 @@ public int getCapacity()
return capacity;
}

/**
* Releases memory associated with the hash table that is no longer necessary to produce output. Subsequent
* calls to insert new elements are rejected, and calls to {@link FlatHash#appendTo(int, BlockBuilder[])} will
* incrementally release memory associated with prior groupId values, assuming that the caller only produces
* output in sequential groupId order.
*/
public void startReleasingOutput()
{
if (isReleasingOutput()) {
throw new IllegalStateException("already releasing output");
}
control = null;
groupIdsByHash = null;
}

private boolean isReleasingOutput()
{
return control == null;
}

public long hashPosition(int groupId)
{
- if (groupId < 0) {
- throw new IllegalArgumentException("groupId is negative");
+ if (groupId < 0 || groupId >= nextGroupId) {
+ throw new IllegalArgumentException("groupId out of range: " + groupId);
}
byte[] fixedSizeRecords = getFixedSizeRecords(groupId);
int fixedRecordOffset = getFixedRecordOffset(groupId);
if (isReleasingOutput() && fixedSizeRecords == null) {
throw new IllegalStateException("groupId already released");
}
if (cacheHashValue) {
return (long) LONG_HANDLE.get(fixedSizeRecords, fixedRecordOffset);
}
@@ -182,7 +205,8 @@ public void appendTo(int groupId, BlockBuilder[] blockBuilders)
{
checkArgument(groupId < nextGroupId, "groupId out of range");

- byte[] fixedSizeRecords = getFixedSizeRecords(groupId);
+ int recordGroupIndex = recordGroupIndexForGroupId(groupId);
+ byte[] fixedSizeRecords = this.fixedSizeRecords[recordGroupIndex];
int recordOffset = getFixedRecordOffset(groupId);

byte[] variableWidthChunk = null;
@@ -202,6 +226,18 @@ public void appendTo(int groupId, BlockBuilder[] blockBuilders)
if (hasPrecomputedHash) {
BIGINT.writeLong(blockBuilders[blockBuilders.length - 1], (long) LONG_HANDLE.get(fixedSizeRecords, recordOffset));
}
// Release memory from the previous fixed size records batch
if (isReleasingOutput() && recordOffset == 0 && recordGroupIndex > 0) {
byte[] releasedRecords = this.fixedSizeRecords[recordGroupIndex - 1];
this.fixedSizeRecords[recordGroupIndex - 1] = null;
if (releasedRecords == null) {
throw new IllegalStateException("already released previous record batch");
}
fixedRecordGroupsRetainedSize -= sizeOf(releasedRecords);
if (variableWidthData != null) {
variableWidthData.freeChunksBefore(fixedSizeRecords, recordOffset + variableWidthOffset);
}
}
}

public void computeHashes(Block[] blocks, long[] hashes, int offset, int length)
@@ -251,6 +287,9 @@ public int putIfAbsent(Block[] blocks, int position, long hash)

private int getIndex(Block[] blocks, int position, long hash)
{
if (isReleasingOutput()) {
throw new IllegalStateException("already releasing output");
}
byte hashPrefix = (byte) (hash & 0x7F | 0x80);
int bucket = bucket((int) (hash >> 7));

@@ -351,6 +390,9 @@ private void setControl(int index, byte hashPrefix)

public boolean ensureAvailableCapacity(int batchSize)
{
if (isReleasingOutput()) {
throw new IllegalStateException("already releasing output");
}
long requiredMaxFill = nextGroupId + batchSize;
if (requiredMaxFill >= maxFill) {
long minimumRequiredCapacity = (requiredMaxFill + 1) * 16 / 15;
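
Taken together, the FlatHash changes define a two-phase lifecycle: inserts are legal only before startReleasingOutput(), and afterwards appendTo must be driven with increasing groupId values so that each crossing into a new fixed-size record group frees the previous batch. A hedged sketch of the intended drain loop, using only methods visible in this diff (BlockBuilder setup elided, so this is a fragment rather than a complete program):

    // Sketch only: drain a populated FlatHash in sequential groupId order so that
    // completed record batches (and their variable-width chunks) are freed as
    // soon as output moves past them
    static void drainSequentially(FlatHash flatHash, BlockBuilder[] blockBuilders)
    {
        flatHash.startReleasingOutput(); // further inserts now throw IllegalStateException
        for (int groupId = 0; groupId < flatHash.size(); groupId++) {
            // releases the previous fixed-size record batch whenever groupId
            // crosses into a new record group
            flatHash.appendTo(groupId, blockBuilders);
        }
    }

Reading out of order is not supported in this mode: once a batch has been freed, hashPosition for a groupId in that batch throws "groupId already released", as the test below exercises through getRawHash.
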
@@ -109,6 +109,14 @@ static GroupByHash createGroupByHash(

void appendValuesTo(int groupId, PageBuilder pageBuilder);

/**
* Signals that no more entries will be inserted and that, after this point, only calls to
* {@link GroupByHash#appendValuesTo(int, PageBuilder)} with sequential groupId values will be observed,
* allowing the implementation to release memory associated with structures required for inserts or with
* values that have already been output.
*/
void startReleasingOutput();

Work<?> addPage(Page page);

/**
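
This contract is what the hash aggregation builder below relies on for final aggregations. A condensed sketch of the caller-side sequence (assuming a single input page and a single output page, with outputTypes as a hypothetical stand-in for the real output type list):

    // Sketch only: insert phase, then switch to the release-on-output phase
    Work<?> work = groupByHash.addPage(page);
    while (!work.process()) {
        // yield to the driver and retry until the page is fully added
    }
    groupByHash.startReleasingOutput(); // addPage is rejected from this point on

    PageBuilder pageBuilder = new PageBuilder(outputTypes);
    for (int groupId = 0; groupId < groupByHash.getGroupCount(); groupId++) {
        // sequential groupIds let the implementation free batches that have
        // already been appended to the output
        groupByHash.appendValuesTo(groupId, pageBuilder);
        pageBuilder.declarePosition();
    }
    Page output = pageBuilder.build();
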
@@ -50,6 +50,12 @@ public void appendValuesTo(int groupId, PageBuilder pageBuilder)
throw new UnsupportedOperationException("NoChannelGroupByHash does not support appendValuesTo");
}

@Override
public void startReleasingOutput()
{
throw new UnsupportedOperationException("NoChannelGroupByHash does not support startReleasingOutput");
}

@Override
public Work<?> addPage(Page page)
{
@@ -248,12 +248,18 @@ public WorkProcessor<Page> buildResult()
for (GroupedAggregator groupedAggregator : groupedAggregators) {
groupedAggregator.prepareFinal();
}
- return buildResult(consecutiveGroupIds(), new PageBuilder(buildTypes()), false);
+ // Only incrementally release memory for final aggregations, since partial aggregations have a fixed
+ // memory limit and can be expected to fully flush and release their output quickly
+ boolean releaseMemoryOnOutput = !partial;
+ if (releaseMemoryOnOutput) {
+ groupByHash.startReleasingOutput();
+ }
+ return buildResult(consecutiveGroupIds(), new PageBuilder(buildTypes()), false, releaseMemoryOnOutput);
}

public WorkProcessor<Page> buildSpillResult()
{
- return buildResult(hashSortedGroupIds(), new PageBuilder(buildSpillTypes()), true);
+ return buildResult(hashSortedGroupIds(), new PageBuilder(buildSpillTypes()), true, false);
}

public List<Type> buildSpillTypes()
@@ -273,7 +279,7 @@ public int getCapacity()
return groupByHash.getCapacity();
}

- private WorkProcessor<Page> buildResult(IntIterator groupIds, PageBuilder pageBuilder, boolean appendRawHash)
+ private WorkProcessor<Page> buildResult(IntIterator groupIds, PageBuilder pageBuilder, boolean appendRawHash, boolean releaseMemoryOnOutput)
{
int rawHashIndex = groupByChannels.length + groupedAggregators.size();
return WorkProcessor.create(() -> {
@@ -300,6 +306,11 @@ private WorkProcessor<Page> buildResult(IntIterator groupIds, PageBuilder pageBu
}
}

// Update memory usage after producing each page of output
if (releaseMemoryOnOutput) {
updateMemory();
}

return ProcessState.ofResult(pageBuilder.build());
});
}
@@ -54,6 +54,12 @@ public int getGroupCount()
return maxGroupId + 1;
}

@Override
public void startReleasingOutput()
{
throw new UnsupportedOperationException("Not yet supported");
}

@Override
public void appendValuesTo(int groupId, PageBuilder pageBuilder)
{
@@ -46,6 +46,7 @@
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.type.TypeTestUtils.getHashBlock;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public class TestGroupByHash
{
@@ -333,6 +334,60 @@ public void testUpdateMemoryBigint()
assertThat(rehashCount.get()).isEqualTo(2 * BIGINT_EXPECTED_REHASH);
}

@Test
public void testReleaseMemoryOnOutput()
{
Type type = VARCHAR;
// enough values to span multiple FlatGroupByHash fixed-size record groups
Block valuesBlock = createStringSequenceBlock(0, 1_000_000);

GroupByHash groupByHash = createGroupByHash(ImmutableList.of(type), selectGroupByHashMode(false, false, ImmutableList.of(type)), 10_000, false, new FlatHashStrategyCompiler(new TypeOperators()), () -> true);
assertThat(groupByHash.addPage(new Page(valuesBlock)).process()).isTrue();
assertThat(groupByHash.getGroupCount()).isEqualTo(valuesBlock.getPositionCount());

long memoryUsageAfterInput = groupByHash.getEstimatedSize();
groupByHash.startReleasingOutput();
// memory usage should have decreased from dropping the hash table
long memoryUsageAfterReleasingOutput = groupByHash.getEstimatedSize();
// the control and groupIdsByHash arrays are released immediately
assertThat(memoryUsageAfterReleasingOutput).isLessThan(memoryUsageAfterInput);

// no more inputs accepted after switching to releasing output
assertThatThrownBy(() -> groupByHash.addPage(new Page(valuesBlock)).process())
.isInstanceOf(IllegalStateException.class)
.hasMessage("already releasing output");
assertThatThrownBy(() -> groupByHash.startReleasingOutput())
.isInstanceOf(IllegalStateException.class)
.hasMessage("already releasing output");

PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(type));
int groupId = 0;
// the first 1024 FlatGroupByHash records fall within the first fixed-size record group
for (; groupId < 1024; groupId++) {
groupByHash.appendValuesTo(groupId, pageBuilder);
pageBuilder.declarePosition();
}
pageBuilder.build();
// No memory released yet after completing the first group
assertThat(groupByHash.getEstimatedSize()).isEqualTo(memoryUsageAfterReleasingOutput);

groupByHash.appendValuesTo(groupId++, pageBuilder);
pageBuilder.declarePosition();
// Memory released
long memoryUsageAfterFirstRelease = groupByHash.getEstimatedSize();
assertThat(memoryUsageAfterFirstRelease).isLessThan(memoryUsageAfterReleasingOutput);
assertThatThrownBy(() -> groupByHash.getRawHash(0))
.isInstanceOf(IllegalStateException.class)
.hasMessage("groupId already released");

for (; groupId < valuesBlock.getPositionCount(); groupId++) {
groupByHash.appendValuesTo(groupId, pageBuilder);
pageBuilder.declarePosition();
}
// More memory released
assertThat(groupByHash.getEstimatedSize()).isLessThan(memoryUsageAfterFirstRelease);
}

@Test
public void testMemoryReservationYield()
{