Skip to content

Commit 39b6cf1

Browse files
committed
responding to comments
1 parent 6ba98f5 commit 39b6cf1

File tree

2 files changed

+71
-62
lines changed
  • iceberg/openhouse/internalcatalog/src/main/java/com/linkedin/openhouse/internal/catalog
  • integrations/spark/spark-3.5/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest

2 files changed

+71
-62
lines changed

iceberg/openhouse/internalcatalog/src/main/java/com/linkedin/openhouse/internal/catalog/SnapshotDiffApplier.java

Lines changed: 63 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -68,11 +68,12 @@ public TableMetadata applySnapshots(TableMetadata base, TableMetadata metadata)
6868

6969
// Compute diff (all maps created once in constructor)
7070
SnapshotDiff diff =
71-
new SnapshotDiff(providedSnapshots, providedRefs, existingSnapshots, existingRefs);
71+
new SnapshotDiff(
72+
providedSnapshots, providedRefs, existingSnapshots, existingRefs, metadata);
7273

7374
// Validate, apply, record metrics, build
7475
diff.validate(base);
75-
TableMetadata.Builder builder = diff.applyTo(metadata);
76+
TableMetadata.Builder builder = diff.applyTo();
7677
diff.recordMetrics(builder);
7778
return builder.build();
7879
}
@@ -88,12 +89,14 @@ private class SnapshotDiff {
8889
private final Map<String, SnapshotRef> providedRefs;
8990
private final List<Snapshot> existingSnapshots;
9091
private final Map<String, SnapshotRef> existingRefs;
92+
private final TableMetadata metadata;
9193

9294
// Computed maps (created once)
93-
private final Map<Long, Snapshot> providedById;
94-
private final Map<Long, Snapshot> existingById;
95-
private final Set<Long> existingBranchIds;
96-
private final Set<Long> providedBranchIds;
95+
private final Map<Long, Snapshot> providedSnapshotByIds;
96+
private final Map<Long, Snapshot> existingSnapshotByIds;
97+
private final Set<Long> metadataSnapshotIds;
98+
private final Set<Long> existingBranchRefIds;
99+
private final Set<Long> providedBranchRefIds;
97100

98101
// Categorized snapshots
99102
private final List<Snapshot> wapSnapshots;
@@ -114,64 +117,80 @@ private class SnapshotDiff {
114117
List<Snapshot> providedSnapshots,
115118
Map<String, SnapshotRef> providedRefs,
116119
List<Snapshot> existingSnapshots,
117-
Map<String, SnapshotRef> existingRefs) {
120+
Map<String, SnapshotRef> existingRefs,
121+
TableMetadata metadata) {
118122
this.providedSnapshots = providedSnapshots;
119123
this.providedRefs = providedRefs;
120124
this.existingSnapshots = existingSnapshots;
121125
this.existingRefs = existingRefs;
126+
this.metadata = metadata;
122127

123128
// Compute all maps once
124-
this.providedById =
129+
this.providedSnapshotByIds =
125130
providedSnapshots.stream().collect(Collectors.toMap(Snapshot::snapshotId, s -> s));
126-
this.existingById =
131+
this.existingSnapshotByIds =
127132
existingSnapshots.stream().collect(Collectors.toMap(Snapshot::snapshotId, s -> s));
128-
this.existingBranchIds =
133+
this.metadataSnapshotIds =
134+
metadata.snapshots().stream().map(Snapshot::snapshotId).collect(Collectors.toSet());
135+
this.existingBranchRefIds =
129136
existingRefs.values().stream().map(SnapshotRef::snapshotId).collect(Collectors.toSet());
130-
this.providedBranchIds =
137+
this.providedBranchRefIds =
131138
providedRefs.values().stream().map(SnapshotRef::snapshotId).collect(Collectors.toSet());
132139

133-
// Compute categorization (order matters: cherry-picked filters WAP)
134-
List<Snapshot> initialWapSnapshots = computeWapSnapshots();
140+
// Compute categorization - process in dependency order
141+
// 1. Cherry-picked has highest priority (includes WAP being published)
142+
// 2. WAP snapshots (staged, not published)
143+
// 3. Regular snapshots (everything else)
135144
this.cherryPickedSnapshots = computeCherryPickedSnapshots();
136-
this.wapSnapshots = filterWapFromCherryPicked(initialWapSnapshots);
137-
this.regularSnapshots = computeRegularSnapshots();
145+
Set<Long> cherryPickedIds =
146+
cherryPickedSnapshots.stream().map(Snapshot::snapshotId).collect(Collectors.toSet());
147+
148+
this.wapSnapshots = computeWapSnapshots(cherryPickedIds);
149+
Set<Long> wapIds =
150+
wapSnapshots.stream().map(Snapshot::snapshotId).collect(Collectors.toSet());
151+
152+
this.regularSnapshots = computeRegularSnapshots(cherryPickedIds, wapIds);
138153

139154
// Compute changes
140155
this.newSnapshots =
141156
providedSnapshots.stream()
142-
.filter(s -> !existingById.containsKey(s.snapshotId()))
157+
.filter(s -> !existingSnapshotByIds.containsKey(s.snapshotId()))
143158
.collect(Collectors.toList());
144159
this.deletedSnapshots =
145160
existingSnapshots.stream()
146-
.filter(s -> !providedById.containsKey(s.snapshotId()))
161+
.filter(s -> !providedSnapshotByIds.containsKey(s.snapshotId()))
147162
.collect(Collectors.toList());
148163
this.branchUpdates = computeBranchUpdates();
149164
this.deletedIds =
150165
deletedSnapshots.stream().map(Snapshot::snapshotId).collect(Collectors.toSet());
151166
this.newRegularSnapshots =
152167
regularSnapshots.stream().filter(newSnapshots::contains).collect(Collectors.toList());
153168
this.staleRefs = Sets.difference(existingRefs.keySet(), providedRefs.keySet());
154-
this.existingAfterDeletionIds = Sets.difference(existingById.keySet(), deletedIds);
169+
this.existingAfterDeletionIds = Sets.difference(existingSnapshotByIds.keySet(), deletedIds);
155170
this.unreferencedNewSnapshots =
156171
providedSnapshots.stream()
157172
.filter(
158173
s ->
159174
!existingAfterDeletionIds.contains(s.snapshotId())
160-
&& !providedBranchIds.contains(s.snapshotId()))
175+
&& !providedBranchRefIds.contains(s.snapshotId())
176+
&& !metadataSnapshotIds.contains(s.snapshotId()))
161177
.collect(Collectors.toList());
162178
}
163179

164-
private List<Snapshot> computeWapSnapshots() {
165-
Set<Long> allBranchIds =
166-
java.util.stream.Stream.concat(existingBranchIds.stream(), providedBranchIds.stream())
180+
private List<Snapshot> computeWapSnapshots(Set<Long> excludeCherryPicked) {
181+
// Depends on: cherry-picked IDs (to exclude WAP snapshots being published)
182+
Set<Long> allBranchRefIds =
183+
java.util.stream.Stream.concat(
184+
existingBranchRefIds.stream(), providedBranchRefIds.stream())
167185
.collect(Collectors.toSet());
168186

169187
return providedSnapshots.stream()
188+
.filter(s -> !excludeCherryPicked.contains(s.snapshotId()))
170189
.filter(
171190
s ->
172191
s.summary() != null
173192
&& s.summary().containsKey(SnapshotSummary.STAGED_WAP_ID_PROP)
174-
&& !allBranchIds.contains(s.snapshotId()))
193+
&& !allBranchRefIds.contains(s.snapshotId()))
175194
.collect(Collectors.toList());
176195
}
177196

@@ -185,7 +204,7 @@ private List<Snapshot> computeCherryPickedSnapshots() {
185204
return providedSnapshots.stream()
186205
.filter(
187206
provided -> {
188-
Snapshot existing = existingById.get(provided.snapshotId());
207+
Snapshot existing = existingSnapshotByIds.get(provided.snapshotId());
189208
if (existing == null) {
190209
return false;
191210
}
@@ -204,30 +223,19 @@ private List<Snapshot> computeCherryPickedSnapshots() {
204223
boolean hasWapId =
205224
provided.summary() != null
206225
&& provided.summary().containsKey(SnapshotSummary.STAGED_WAP_ID_PROP);
207-
boolean wasStaged = !existingBranchIds.contains(provided.snapshotId());
208-
boolean isNowOnBranch = providedBranchIds.contains(provided.snapshotId());
226+
boolean wasStaged = !existingBranchRefIds.contains(provided.snapshotId());
227+
boolean isNowOnBranch = providedBranchRefIds.contains(provided.snapshotId());
209228
return hasWapId && wasStaged && isNowOnBranch;
210229
})
211230
.collect(Collectors.toList());
212231
}
213232

214-
private List<Snapshot> filterWapFromCherryPicked(List<Snapshot> initialWapSnapshots) {
215-
Set<Long> cherryPickedIds =
216-
cherryPickedSnapshots.stream().map(Snapshot::snapshotId).collect(Collectors.toSet());
217-
return initialWapSnapshots.stream()
218-
.filter(s -> !cherryPickedIds.contains(s.snapshotId()))
219-
.collect(Collectors.toList());
220-
}
221-
222-
private List<Snapshot> computeRegularSnapshots() {
223-
Set<Long> excludedIds =
224-
java.util.stream.Stream.concat(
225-
wapSnapshots.stream().map(Snapshot::snapshotId),
226-
cherryPickedSnapshots.stream().map(Snapshot::snapshotId))
227-
.collect(Collectors.toSet());
228-
233+
private List<Snapshot> computeRegularSnapshots(
234+
Set<Long> excludeCherryPicked, Set<Long> excludeWap) {
235+
// Depends on: cherry-picked and WAP IDs (everything else is regular)
229236
return providedSnapshots.stream()
230-
.filter(s -> !excludedIds.contains(s.snapshotId()))
237+
.filter(s -> !excludeCherryPicked.contains(s.snapshotId()))
238+
.filter(s -> !excludeWap.contains(s.snapshotId()))
231239
.collect(Collectors.toList());
232240
}
233241

@@ -349,7 +357,7 @@ private void validateDeletedSnapshotsNotReferenced() {
349357
}
350358
}
351359

352-
TableMetadata.Builder applyTo(TableMetadata metadata) {
360+
TableMetadata.Builder applyTo() {
353361
TableMetadata.Builder builder = TableMetadata.buildFrom(metadata);
354362

355363
// Remove deleted snapshots
@@ -366,15 +374,20 @@ TableMetadata.Builder applyTo(TableMetadata metadata) {
366374
// Set branch pointers
367375
providedRefs.forEach(
368376
(branchName, ref) -> {
369-
Snapshot snapshot = providedById.get(ref.snapshotId());
377+
Snapshot snapshot = providedSnapshotByIds.get(ref.snapshotId());
370378
if (snapshot == null) {
371379
throw new InvalidIcebergSnapshotException(
372380
String.format(
373381
"Branch %s references non-existent snapshot %s",
374382
branchName, ref.snapshotId()));
375383
}
376384

377-
if (existingAfterDeletionIds.contains(snapshot.snapshotId())) {
385+
// Check if snapshot is already in metadata (after deletions)
386+
boolean snapshotExistsInMetadata =
387+
metadataSnapshotIds.contains(snapshot.snapshotId())
388+
&& !deletedIds.contains(snapshot.snapshotId());
389+
390+
if (snapshotExistsInMetadata) {
378391
SnapshotRef existingRef = metadata.refs().get(branchName);
379392
if (existingRef == null || existingRef.snapshotId() != ref.snapshotId()) {
380393
builder.setRef(branchName, ref);
@@ -391,7 +404,7 @@ void recordMetrics(TableMetadata.Builder builder) {
391404
int appendedCount =
392405
(int)
393406
regularSnapshots.stream()
394-
.filter(s -> !existingById.containsKey(s.snapshotId()))
407+
.filter(s -> !existingSnapshotByIds.containsKey(s.snapshotId()))
395408
.count();
396409
int stagedCount = wapSnapshots.size();
397410
int cherryPickedCount = cherryPickedSnapshots.size();
@@ -451,16 +464,9 @@ void recordMetrics(TableMetadata.Builder builder) {
451464
* @return Comma-separated string of snapshot IDs, or empty string if list is empty
452465
*/
453466
private String formatSnapshotIds(List<Snapshot> snapshots) {
454-
if (snapshots.isEmpty()) {
455-
return "";
456-
}
457-
StringBuilder sb = new StringBuilder();
458-
for (int i = 0; i < snapshots.size(); i++) {
459-
if (i > 0) {
460-
sb.append(',');
461-
}
462-
sb.append(snapshots.get(i).snapshotId());
463-
}
464-
return sb.toString();
467+
return snapshots.stream()
468+
.map(Snapshot::snapshotId)
469+
.map(String::valueOf)
470+
.collect(Collectors.joining(","));
465471
}
466472
}

integrations/spark/spark-3.5/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/catalogtest/BranchTestSpark3_5.java

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -103,8 +103,9 @@ public void testBasicBranchOperations() throws Exception {
103103
List<Row> refs =
104104
spark.sql("SELECT name FROM " + tableName + ".refs ORDER BY name").collectAsList();
105105
assertEquals(2, refs.size());
106-
assertEquals("feature_a", refs.get(0).getString(0));
107-
assertEquals("main", refs.get(1).getString(0));
106+
Set<String> refNames = refs.stream().map(row -> row.getString(0)).collect(Collectors.toSet());
107+
assertTrue(refNames.contains("feature_a"));
108+
assertTrue(refNames.contains("main"));
108109
}
109110
}
110111

@@ -2004,7 +2005,8 @@ public void testBackwardCompatibilityMainBranchOnly() throws Exception {
20042005
assertEquals(3, spark.sql("SELECT * FROM " + tableName + "").collectAsList().size());
20052006
List<Row> refs = spark.sql("SELECT name FROM " + tableName + ".refs").collectAsList();
20062007
assertEquals(1, refs.size());
2007-
assertEquals("main", refs.get(0).getString(0));
2008+
Set<String> refNames = refs.stream().map(row -> row.getString(0)).collect(Collectors.toSet());
2009+
assertTrue(refNames.contains("main"));
20082010

20092011
// Traditional snapshot queries should work
20102012
assertTrue(
@@ -2279,8 +2281,9 @@ public void testErrorInsertToNonExistentBranch() throws Exception {
22792281
List<Row> refs =
22802282
spark.sql("SELECT name FROM " + tableName + ".refs ORDER BY name").collectAsList();
22812283
assertEquals(2, refs.size());
2282-
assertEquals("feature_a", refs.get(0).getString(0));
2283-
assertEquals("main", refs.get(1).getString(0));
2284+
Set<String> refNames = refs.stream().map(row -> row.getString(0)).collect(Collectors.toSet());
2285+
assertTrue(refNames.contains("feature_a"));
2286+
assertTrue(refNames.contains("main"));
22842287
}
22852288
}
22862289

0 commit comments

Comments
 (0)