diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 9e9de52dee..1c2e8cf21e 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -115,11 +115,7 @@ update_table_metadata, ) from pyiceberg.table.update.schema import UpdateSchema -from pyiceberg.table.update.snapshot import ( - ManageSnapshots, - UpdateSnapshot, - _FastAppendFiles, -) +from pyiceberg.table.update.snapshot import ExpireSnapshots, ManageSnapshots, UpdateSnapshot, _FastAppendFiles from pyiceberg.table.update.spec import UpdateSpec from pyiceberg.table.update.statistics import UpdateStatistics from pyiceberg.transforms import IdentityTransform @@ -1079,6 +1075,15 @@ def manage_snapshots(self) -> ManageSnapshots: """ return ManageSnapshots(transaction=Transaction(self, autocommit=True)) + def expire_snapshots(self) -> ExpireSnapshots: + """ + Shorthand to run expire snapshots by id or by a timestamp. + + Use table.expire_snapshots().().commit() to run a specific operation. + Use table.expire_snapshots().().().commit() to run multiple operations. + """ + return ExpireSnapshots(transaction=Transaction(self, autocommit=True)) + def update_statistics(self) -> UpdateStatistics: """ Shorthand to run statistics management operations like add statistics and remove statistics. diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index b53c331758..275f1a56c9 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -55,6 +55,7 @@ from pyiceberg.partitioning import ( PartitionSpec, ) +from pyiceberg.table.refs import SnapshotRefType from pyiceberg.table.snapshots import ( Operation, Snapshot, @@ -66,6 +67,7 @@ AddSnapshotUpdate, AssertRefSnapshotId, RemoveSnapshotRefUpdate, + RemoveSnapshotsUpdate, SetSnapshotRefUpdate, TableRequirement, TableUpdate, @@ -739,6 +741,7 @@ class ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]): ms.create_tag(snapshot_id1, "Tag_A").create_tag(snapshot_id2, "Tag_B") """ + _snapshot_ids_to_expire: Set[int] = set() _updates: Tuple[TableUpdate, ...] = () _requirements: Tuple[TableRequirement, ...] = () @@ -843,3 +846,69 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots: This for method chaining """ return self._remove_ref_snapshot(ref_name=branch_name) + + +class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): + """ + Expire snapshots by ID. + + Use table.expire_snapshots().().commit() to run a specific operation. + Use table.expire_snapshots().().().commit() to run multiple operations. + Pending changes are applied on commit. + """ + + _snapshot_ids_to_expire: Set[int] = set() + _updates: Tuple[TableUpdate, ...] = () + _requirements: Tuple[TableRequirement, ...] = () + + def _commit(self) -> UpdatesAndRequirements: + """ + Commit the staged updates and requirements. + + This will remove the snapshots with the given IDs. + + Returns: + Tuple of updates and requirements to be committed, + as required by the calling parent apply functions. + """ + update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire) + self._updates += (update,) + return self._updates, self._requirements + + def _get_protected_snapshot_ids(self) -> Set[int]: + """ + Get the IDs of protected snapshots. + + These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration. + + Returns: + Set of protected snapshot IDs to exclude from expiration. + """ + protected_ids: Set[int] = set() + + for ref in self._transaction.table_metadata.refs.values(): + if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]: + protected_ids.add(ref.snapshot_id) + + return protected_ids + + def expire_snapshot_by_id(self, snapshot_id: int) -> ExpireSnapshots: + """ + Expire a snapshot by its ID. + + This will mark the snapshot for expiration. + + Args: + snapshot_id (int): The ID of the snapshot to expire. + Returns: + This for method chaining. + """ + if self._transaction.table_metadata.snapshot_by_id(snapshot_id) is None: + raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.") + + if snapshot_id in self._get_protected_snapshot_ids(): + raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.") + + self._snapshot_ids_to_expire.add(snapshot_id) + + return self diff --git a/tests/table/test_expire_snapshots.py b/tests/table/test_expire_snapshots.py new file mode 100644 index 0000000000..c2702edb3d --- /dev/null +++ b/tests/table/test_expire_snapshots.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest.mock import MagicMock +from uuid import uuid4 + +from pyiceberg.table import CommitTableResponse, Table + + +def test_expire_snapshot(table_v2: Table) -> None: + EXPIRE_SNAPSHOT = 3051729675574597004 + KEEP_SNAPSHOT = 3055729675574597004 + # Mock the catalog's commit_table method + mock_response = CommitTableResponse( + # Use the table's current metadata but keep only the snapshot not to be expired + metadata=table_v2.metadata.model_copy(update={"snapshots": [KEEP_SNAPSHOT]}), + metadata_location="mock://metadata/location", + uuid=uuid4(), + ) + + # Mock the catalog object and its commit_table method to return the mock response + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = mock_response + + # Print snapshot IDs for debugging + print(f"Snapshot IDs before expiration: {[snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots]}") + + # Assert fixture data to validate test assumptions + assert len(table_v2.metadata.snapshots) == 2 + assert len(table_v2.metadata.snapshot_log) == 2 + assert len(table_v2.metadata.refs) == 2 + + # Expire the snapshot directly without using a transaction + try: + table_v2.expire_snapshots().expire_snapshot_by_id(EXPIRE_SNAPSHOT).commit() + except Exception as e: + raise AssertionError(f"Commit failed with error: {e}") from e + + # Assert that commit_table was called once + table_v2.catalog.commit_table.assert_called_once() + + # Assert the expired snapshot ID is no longer present + remaining_snapshots = table_v2.metadata.snapshots + assert EXPIRE_SNAPSHOT not in remaining_snapshots + + # Assert the length of snapshots after expiration + assert len(table_v2.metadata.snapshots) == 1