From ba18a0200c441faaad72e867f83f66c759267608 Mon Sep 17 00:00:00 2001 From: geruh Date: Mon, 29 Dec 2025 01:22:35 -0800 Subject: [PATCH 1/3] feat: Add support set current snapshot Co-authored-by: Chinmay Bhat <12948588+chinmay-bhat@users.noreply.github.com> --- pyiceberg/table/__init__.py | 9 +- pyiceberg/table/update/snapshot.py | 45 +++++ tests/integration/test_snapshot_operations.py | 88 +++++++++ tests/table/test_manage_snapshots.py | 179 ++++++++++++++++++ 4 files changed, 319 insertions(+), 2 deletions(-) create mode 100644 tests/table/test_manage_snapshots.py diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2e26a4ccc2..8c249a362f 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -275,7 +275,12 @@ def __exit__(self, exctype: type[BaseException] | None, excinst: BaseException | if exctype is None and excinst is None and exctb is None: self.commit_transaction() - def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequirement, ...] = ()) -> Transaction: + def _apply( + self, + updates: tuple[TableUpdate, ...], + requirements: tuple[TableRequirement, ...] = (), + commit_transaction_if_autocommit: bool = True, + ) -> Transaction: """Check if the requirements are met, and applies the updates to the metadata.""" for requirement in requirements: requirement.validate(self.table_metadata) @@ -289,7 +294,7 @@ def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequ if type(new_requirement) not in existing_requirements: self._requirements = self._requirements + (new_requirement,) - if self._autocommit: + if self._autocommit and commit_transaction_if_autocommit: self.commit_transaction() return self diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index e89cd45d34..6e468285d2 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -843,6 +843,13 @@ def _commit(self) -> UpdatesAndRequirements: """Apply the pending changes and commit.""" return self._updates, self._requirements + def _commit_if_ref_updates_exist(self) -> None: + """Commit any pending ref updates to the transaction.""" + if self._updates: + self._transaction._apply(*self._commit(), commit_transaction_if_autocommit=False) + self._updates = () + self._requirements = () + def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots: """Remove a snapshot ref. @@ -941,6 +948,44 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots: """ return self._remove_ref_snapshot(ref_name=branch_name) + def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | None = None) -> ManageSnapshots: + """Set the current snapshot to a specific snapshot ID or ref. + + Args: + snapshot_id: The ID of the snapshot to set as current. + ref_name: The snapshot reference (branch or tag) to set as current. + + Returns: + This for method chaining. + + Raises: + ValueError: If neither or both arguments are provided, or if the snapshot/ref does not exist. + """ + self._commit_if_ref_updates_exist() + + if (snapshot_id is None) == (ref_name is None): + raise ValueError("Either snapshot_id or ref_name must be provided, not both") + + target_snapshot_id: int + if snapshot_id is not None: + target_snapshot_id = snapshot_id + else: + if ref_name not in self._transaction.table_metadata.refs: + raise ValueError(f"Cannot find matching snapshot ID for ref: {ref_name}") + target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id + + if self._transaction.table_metadata.snapshot_by_id(target_snapshot_id) is None: + raise ValueError(f"Cannot set current snapshot to unknown snapshot id: {target_snapshot_id}") + + update, requirement = self._transaction._set_ref_snapshot( + snapshot_id=target_snapshot_id, + ref_name=MAIN_BRANCH, + type="branch", + ) + self._updates += update + self._requirements += requirement + return self + class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): """Expire snapshots by ID. diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 1b7f2d3a29..2f0447ec52 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -72,3 +72,91 @@ def test_remove_branch(catalog: Catalog) -> None: # now, remove the branch tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit() assert tbl.metadata.refs.get(branch_name, None) is None + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_current_snapshot(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + + # first get the current snapshot and an older one + current_snapshot_id = tbl.history()[-1].snapshot_id + older_snapshot_id = tbl.history()[-2].snapshot_id + + # set the current snapshot to the older one + tbl.manage_snapshots().set_current_snapshot(snapshot_id=older_snapshot_id).commit() + + tbl = catalog.load_table(identifier) + updated_snapshot = tbl.current_snapshot() + assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id + + # restore table + tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit() + tbl = catalog.load_table(identifier) + restored_snapshot = tbl.current_snapshot() + assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_current_snapshot_by_ref(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + + # first get the current snapshot and an older one + current_snapshot_id = tbl.history()[-1].snapshot_id + older_snapshot_id = tbl.history()[-2].snapshot_id + assert older_snapshot_id != current_snapshot_id + + # create a tag pointing to the older snapshot + tag_name = "my-tag" + tbl.manage_snapshots().create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name).commit() + + # set current snapshot using the tag name + tbl = catalog.load_table(identifier) + tbl.manage_snapshots().set_current_snapshot(ref_name=tag_name).commit() + + tbl = catalog.load_table(identifier) + updated_snapshot = tbl.current_snapshot() + assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id + + # restore table + tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit() + tbl = catalog.load_table(identifier) + tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit() + assert tbl.metadata.refs.get(tag_name, None) is None + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + + current_snapshot_id = tbl.history()[-1].snapshot_id + older_snapshot_id = tbl.history()[-2].snapshot_id + assert older_snapshot_id != current_snapshot_id + + # create a tag and use it to set current snapshot + tag_name = "my-tag" + ( + tbl.manage_snapshots() + .create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name) + .set_current_snapshot(ref_name=tag_name) + .commit() + ) + + tbl = catalog.load_table(identifier) + updated_snapshot = tbl.current_snapshot() + assert updated_snapshot + assert updated_snapshot.snapshot_id == older_snapshot_id + + # restore table + tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit() + tbl = catalog.load_table(identifier) + tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit() + assert tbl.metadata.refs.get(tag_name, None) is None diff --git a/tests/table/test_manage_snapshots.py b/tests/table/test_manage_snapshots.py new file mode 100644 index 0000000000..93301a01c7 --- /dev/null +++ b/tests/table/test_manage_snapshots.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from pyiceberg.table import CommitTableResponse, Table +from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate + + +def _mock_commit_response(table: Table) -> CommitTableResponse: + return CommitTableResponse( + metadata=table.metadata, + metadata_location="s3://bucket/tbl", + uuid=uuid4(), + ) + + +def _get_updates(mock_catalog: MagicMock) -> tuple[TableUpdate, ...]: + args, _ = mock_catalog.commit_table.call_args + return args[2] + + +def test_set_current_snapshot_basic(table_v2: Table) -> None: + snapshot_one = 3051729675574597004 + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + table_v2.manage_snapshots().set_current_snapshot(snapshot_id=snapshot_one).commit() + + table_v2.catalog.commit_table.assert_called_once() + + updates = _get_updates(table_v2.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 1 + update = set_ref_updates[0] + assert update.snapshot_id == snapshot_one + assert update.ref_name == "main" + assert update.type == "branch" + + +def test_set_current_snapshot_unknown_id(table_v2: Table) -> None: + invalid_snapshot_id = 1234567890000 + table_v2.catalog = MagicMock() + + with pytest.raises(ValueError, match="Cannot set current snapshot to unknown snapshot id"): + table_v2.manage_snapshots().set_current_snapshot(snapshot_id=invalid_snapshot_id).commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_set_current_snapshot_to_current(table_v2: Table) -> None: + current_snapshot = table_v2.current_snapshot() + assert current_snapshot is not None + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + table_v2.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot.snapshot_id).commit() + + table_v2.catalog.commit_table.assert_called_once() + + +def test_set_current_snapshot_chained_with_tag(table_v2: Table) -> None: + snapshot_one = 3051729675574597004 + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + (table_v2.manage_snapshots().set_current_snapshot(snapshot_id=snapshot_one).create_tag(snapshot_one, "my-tag").commit()) + + table_v2.catalog.commit_table.assert_called_once() + + updates = _get_updates(table_v2.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 2 + assert {u.ref_name for u in set_ref_updates} == {"main", "my-tag"} + + +def test_set_current_snapshot_with_extensive_snapshots(table_v2_with_extensive_snapshots: Table) -> None: + snapshots = table_v2_with_extensive_snapshots.metadata.snapshots + assert len(snapshots) > 100 + + target_snapshot = snapshots[50].snapshot_id + + table_v2_with_extensive_snapshots.catalog = MagicMock() + table_v2_with_extensive_snapshots.catalog.commit_table.return_value = _mock_commit_response(table_v2_with_extensive_snapshots) + + table_v2_with_extensive_snapshots.manage_snapshots().set_current_snapshot(snapshot_id=target_snapshot).commit() + + table_v2_with_extensive_snapshots.catalog.commit_table.assert_called_once() + + updates = _get_updates(table_v2_with_extensive_snapshots.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 1 + assert set_ref_updates[0].snapshot_id == target_snapshot + + +def test_set_current_snapshot_by_ref_name(table_v2: Table) -> None: + current_snapshot = table_v2.current_snapshot() + assert current_snapshot is not None + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + table_v2.manage_snapshots().set_current_snapshot(ref_name="main").commit() + + updates = _get_updates(table_v2.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 1 + assert set_ref_updates[0].snapshot_id == current_snapshot.snapshot_id + assert set_ref_updates[0].ref_name == "main" + + +def test_set_current_snapshot_unknown_ref(table_v2: Table) -> None: + table_v2.catalog = MagicMock() + + with pytest.raises(ValueError, match="Cannot find matching snapshot ID for ref: nonexistent"): + table_v2.manage_snapshots().set_current_snapshot(ref_name="nonexistent").commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_set_current_snapshot_requires_one_argument(table_v2: Table) -> None: + table_v2.catalog = MagicMock() + + with pytest.raises(ValueError, match="Either snapshot_id or ref_name must be provided, not both"): + table_v2.manage_snapshots().set_current_snapshot().commit() + + with pytest.raises(ValueError, match="Either snapshot_id or ref_name must be provided, not both"): + table_v2.manage_snapshots().set_current_snapshot(snapshot_id=123, ref_name="main").commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None: + snapshot_one = 3051729675574597004 + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + # create a tag and immediately use it to set current snapshot + ( + table_v2.manage_snapshots() + .create_tag(snapshot_id=snapshot_one, tag_name="new-tag") + .set_current_snapshot(ref_name="new-tag") + .commit() + ) + + table_v2.catalog.commit_table.assert_called_once() + + updates = _get_updates(table_v2.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + # should have the tag and the main branch update + assert len(set_ref_updates) == 2 + assert {u.ref_name for u in set_ref_updates} == {"new-tag", "main"} + + # The main branch should point to the same snapshot as the tag + main_update = next(u for u in set_ref_updates if u.ref_name == "main") + assert main_update.snapshot_id == snapshot_one From 4b193f9113f676f4d668768a78414646f4ef2b23 Mon Sep 17 00:00:00 2001 From: geruh Date: Fri, 2 Jan 2026 16:18:19 -0800 Subject: [PATCH 2/3] feat: Add support for rolling back to snapshot Co-authored-by: Chinmay Bhat <12948588+chinmay-bhat@users.noreply.github.com> --- pyiceberg/table/update/snapshot.py | 33 +++++ tests/integration/test_snapshot_operations.py | 26 ++++ tests/table/test_manage_snapshots.py | 126 ++++++++++++++++++ 3 files changed, 185 insertions(+) diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 6e468285d2..626df237c8 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -64,6 +64,7 @@ Snapshot, SnapshotSummaryCollector, Summary, + ancestors_of, update_snapshot_summaries, ) from pyiceberg.table.update import ( @@ -986,6 +987,38 @@ def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | N self._requirements += requirement return self + def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: + """Rollback the table to the given snapshot id. + + The snapshot needs to be an ancestor of the current table state. + + Args: + snapshot_id (int): rollback to this snapshot_id that used to be current. + Returns: + This for method chaining + Raises: + ValueError: If the snapshot does not exist or is not an ancestor of the current table state. + """ + if not self._transaction.table_metadata.snapshot_by_id(snapshot_id): + raise ValueError(f"Cannot roll back to unknown snapshot id: {snapshot_id}") + + if not self._is_current_ancestor(snapshot_id): + raise ValueError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}") + + return self.set_current_snapshot(snapshot_id=snapshot_id) + + def _is_current_ancestor(self, snapshot_id: int) -> bool: + return snapshot_id in self._current_ancestors() + + def _current_ancestors(self) -> set[int]: + return { + a.snapshot_id + for a in ancestors_of( + self._transaction._table.current_snapshot(), + self._transaction.table_metadata, + ) + } + class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): """Expire snapshots by ID. diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 2f0447ec52..68cec645ac 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -160,3 +160,29 @@ def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None: tbl = catalog.load_table(identifier) tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit() assert tbl.metadata.refs.get(tag_name, None) is None + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_rollback_to_snapshot(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + assert len(tbl.history()) > 2 + + # get the current snapshot and an ancestor + current_snapshot_id = tbl.history()[-1].snapshot_id + ancestor_snapshot_id = tbl.history()[-2].snapshot_id + assert ancestor_snapshot_id != current_snapshot_id + + # rollback to the ancestor snapshot + tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit() + + tbl = catalog.load_table(identifier) + updated_snapshot = tbl.current_snapshot() + assert updated_snapshot and updated_snapshot.snapshot_id == ancestor_snapshot_id + + # restore table + tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit() + tbl = catalog.load_table(identifier) + restored_snapshot = tbl.current_snapshot() + assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id diff --git a/tests/table/test_manage_snapshots.py b/tests/table/test_manage_snapshots.py index 93301a01c7..20c23fc91c 100644 --- a/tests/table/test_manage_snapshots.py +++ b/tests/table/test_manage_snapshots.py @@ -19,6 +19,7 @@ import pytest +from pyiceberg.io import load_file_io from pyiceberg.table import CommitTableResponse, Table from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate @@ -177,3 +178,128 @@ def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None: # The main branch should point to the same snapshot as the tag main_update = next(u for u in set_ref_updates if u.ref_name == "main") assert main_update.snapshot_id == snapshot_one + + +def test_rollback_to_snapshot(table_v2: Table) -> None: + ancestor_snapshot_id = 3051729675574597004 + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit() + + table_v2.catalog.commit_table.assert_called_once() + + updates = _get_updates(table_v2.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 1 + update = set_ref_updates[0] + assert update.snapshot_id == ancestor_snapshot_id + assert update.ref_name == "main" + assert update.type == "branch" + + +def test_rollback_to_snapshot_unknown_id(table_v2: Table) -> None: + invalid_snapshot_id = 1234567890000 + table_v2.catalog = MagicMock() + + with pytest.raises(ValueError, match="Cannot roll back to unknown snapshot id"): + table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=invalid_snapshot_id).commit() + + table_v2.catalog.commit_table.assert_not_called() + + +def test_rollback_to_snapshot_not_ancestor(table_v2: Table) -> None: + from pyiceberg.table.metadata import TableMetadataV2 + + # create a table with a branching snapshot history: + snapshot_a = 1 + snapshot_b = 2 # current + snapshot_c = 3 # branch from a, not ancestor of b + + metadata_dict = { + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + "last-sequence-number": 3, + "last-updated-ms": 1602638573590, + "last-column-id": 1, + "current-schema-id": 0, + "schemas": [{"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": []}], + "last-partition-id": 999, + "default-sort-order-id": 0, + "current-snapshot-id": snapshot_b, + "snapshots": [ + { + "snapshot-id": snapshot_a, + "timestamp-ms": 1000, + "sequence-number": 1, + "manifest-list": "s3://a/1.avro", + }, + { + "snapshot-id": snapshot_b, + "parent-snapshot-id": snapshot_a, + "timestamp-ms": 2000, + "sequence-number": 2, + "manifest-list": "s3://a/2.avro", + }, + { + "snapshot-id": snapshot_c, + "parent-snapshot-id": snapshot_a, + "timestamp-ms": 3000, + "sequence-number": 3, + "manifest-list": "s3://a/3.avro", + }, + ], + } + + from pyiceberg.table import Table + + branching_table = Table( + identifier=("db", "table"), + metadata=TableMetadataV2(**metadata_dict), + metadata_location="s3://bucket/test/metadata.json", + io=load_file_io(), + catalog=MagicMock(), + ) + + # snapshot_c exists but is not an ancestor of snapshot_b (current) + with pytest.raises(ValueError, match="Cannot roll back to snapshot, not an ancestor of the current state"): + branching_table.manage_snapshots().rollback_to_snapshot(snapshot_id=snapshot_c).commit() + + +def test_rollback_to_snapshot_chained_with_tag(table_v2: Table) -> None: + ancestor_snapshot_id = 3051729675574597004 + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + ( + table_v2.manage_snapshots() + .create_tag(snapshot_id=ancestor_snapshot_id, tag_name="before-rollback") + .rollback_to_snapshot(snapshot_id=ancestor_snapshot_id) + .commit() + ) + + table_v2.catalog.commit_table.assert_called_once() + + updates = _get_updates(table_v2.catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 2 + ref_names = {u.ref_name for u in set_ref_updates} + assert ref_names == {"before-rollback", "main"} + + +def test_rollback_to_current_snapshot(table_v2: Table) -> None: + current_snapshot = table_v2.current_snapshot() + assert current_snapshot is not None + + table_v2.catalog = MagicMock() + table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2) + + table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=current_snapshot.snapshot_id).commit() + table_v2.catalog.commit_table.assert_called_once() From a04e503f02fff86d2b685b2e23b3898c5851af68 Mon Sep 17 00:00:00 2001 From: geruh Date: Fri, 2 Jan 2026 16:38:10 -0800 Subject: [PATCH 3/3] feat: Add support for rolling back to timestamp Co-authored-by: Chinmay Bhat <12948588+chinmay-bhat@users.noreply.github.com> --- pyiceberg/table/snapshots.py | 21 +++++ pyiceberg/table/update/snapshot.py | 28 +++++++ tests/integration/test_snapshot_operations.py | 29 +++++++ tests/table/test_manage_snapshots.py | 75 ++++++++++++++++++ tests/table/test_snapshots.py | 77 +++++++++++++++++++ 5 files changed, 230 insertions(+) diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 4ef1645df6..7ebb20a6ad 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -472,3 +472,24 @@ def ancestors_between(from_snapshot: Snapshot | None, to_snapshot: Snapshot, tab break else: yield from ancestors_of(to_snapshot, table_metadata) + + +def latest_ancestor_before_timestamp(table_metadata: TableMetadata, timestamp_ms: int) -> Snapshot | None: + """Find the latest ancestor snapshot whose timestamp is before the provided timestamp. + + Args: + table_metadata: The table metadata for a table + timestamp_ms: lookup snapshots before this timestamp + + Returns: + The latest ancestor snapshot older than the timestamp, or None if not found. + """ + result: Snapshot | None = None + result_timestamp: int = 0 + + for ancestor in ancestors_of(table_metadata.current_snapshot(), table_metadata): + if timestamp_ms > ancestor.timestamp_ms > result_timestamp: + result = ancestor + result_timestamp = ancestor.timestamp_ms + + return result diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 626df237c8..57b25ebbd1 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -65,6 +65,7 @@ SnapshotSummaryCollector, Summary, ancestors_of, + latest_ancestor_before_timestamp, update_snapshot_summaries, ) from pyiceberg.table.update import ( @@ -1007,6 +1008,33 @@ def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots: return self.set_current_snapshot(snapshot_id=snapshot_id) + def rollback_to_timestamp(self, timestamp_ms: int) -> ManageSnapshots: + """Rollback the table to the latest snapshot before the given timestamp. + + Finds the latest ancestor snapshot whose timestamp is before the given timestamp and rolls back to it. + + Args: + timestamp_ms: Rollback to the latest snapshot before this timestamp in milliseconds. + Returns: + This for method chaining + Raises: + ValueError: If no valid snapshot exists older than the given timestamp. + """ + self._commit_if_ref_updates_exist() + + snapshot = latest_ancestor_before_timestamp(self._transaction.table_metadata, timestamp_ms) + if snapshot is None: + raise ValueError(f"Cannot roll back, no valid snapshot older than: {timestamp_ms}") + + update, requirement = self._transaction._set_ref_snapshot( + snapshot_id=snapshot.snapshot_id, + ref_name=MAIN_BRANCH, + type=str(SnapshotRefType.BRANCH), + ) + self._updates += update + self._requirements += requirement + return self + def _is_current_ancestor(self, snapshot_id: int) -> bool: return snapshot_id in self._current_ancestors() diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py index 68cec645ac..6aae91d6e0 100644 --- a/tests/integration/test_snapshot_operations.py +++ b/tests/integration/test_snapshot_operations.py @@ -186,3 +186,32 @@ def test_rollback_to_snapshot(catalog: Catalog) -> None: tbl = catalog.load_table(identifier) restored_snapshot = tbl.current_snapshot() assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_rollback_to_timestamp(catalog: Catalog) -> None: + identifier = "default.test_table_snapshot_operations" + tbl = catalog.load_table(identifier) + + current_snapshot = tbl.current_snapshot() + assert current_snapshot is not None + assert current_snapshot.parent_snapshot_id is not None + + parent_snapshot = tbl.metadata.snapshot_by_id(current_snapshot.parent_snapshot_id) + assert parent_snapshot is not None + + # rollback_to_timestamp finds the latest ancestor with timestamp less than given timestamp + tbl.manage_snapshots().rollback_to_timestamp(timestamp_ms=current_snapshot.timestamp_ms).commit() + + tbl = catalog.load_table(identifier) + updated_snapshot = tbl.current_snapshot() + assert updated_snapshot is not None + assert updated_snapshot.snapshot_id == parent_snapshot.snapshot_id + + # restore table + tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot.snapshot_id).commit() + tbl = catalog.load_table(identifier) + restored_snapshot = tbl.current_snapshot() + assert restored_snapshot is not None + assert restored_snapshot.snapshot_id == current_snapshot.snapshot_id diff --git a/tests/table/test_manage_snapshots.py b/tests/table/test_manage_snapshots.py index 20c23fc91c..e3022fc3d7 100644 --- a/tests/table/test_manage_snapshots.py +++ b/tests/table/test_manage_snapshots.py @@ -303,3 +303,78 @@ def test_rollback_to_current_snapshot(table_v2: Table) -> None: table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=current_snapshot.snapshot_id).commit() table_v2.catalog.commit_table.assert_called_once() + + +def test_rollback_to_timestamp() -> None: + from pyiceberg.table.metadata import TableMetadataV2 + + metadata_dict = { + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + "last-sequence-number": 4, + "last-updated-ms": 1602638573590, + "last-column-id": 1, + "current-schema-id": 0, + "schemas": [{"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": []}], + "last-partition-id": 999, + "default-sort-order-id": 0, + "sort-orders": [{"order-id": 0, "fields": []}], + "current-snapshot-id": 4, + "snapshots": [ + {"snapshot-id": 1, "timestamp-ms": 1000, "sequence-number": 1, "manifest-list": "s3://a/1.avro"}, + { + "snapshot-id": 2, + "parent-snapshot-id": 1, + "timestamp-ms": 2000, + "sequence-number": 2, + "manifest-list": "s3://a/2.avro", + }, + { + "snapshot-id": 3, + "parent-snapshot-id": 2, + "timestamp-ms": 3000, + "sequence-number": 3, + "manifest-list": "s3://a/3.avro", + }, + { + "snapshot-id": 4, + "parent-snapshot-id": 3, + "timestamp-ms": 4000, + "sequence-number": 4, + "manifest-list": "s3://a/4.avro", + }, + ], + } + + mock_catalog = MagicMock() + table = Table( + identifier=("db", "table"), + metadata=TableMetadataV2(**metadata_dict), + metadata_location="s3://bucket/test/metadata.json", + io=load_file_io(), + catalog=mock_catalog, + ) + mock_catalog.commit_table.return_value = _mock_commit_response(table) + + # verify we find the ancestor before timestamp 2500 + table.manage_snapshots().rollback_to_timestamp(timestamp_ms=2500).commit() + + updates = _get_updates(mock_catalog) + set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)] + + assert len(set_ref_updates) == 1 + assert set_ref_updates[0].snapshot_id == 2 + assert set_ref_updates[0].ref_name == "main" + + +def test_rollback_to_timestamp_no_valid_snapshot(table_v2: Table) -> None: + # The oldest snapshot is at timestamp 1515100955770 + table_v2.catalog = MagicMock() + + with pytest.raises(ValueError, match="Cannot roll back, no valid snapshot older than"): + table_v2.manage_snapshots().rollback_to_timestamp(timestamp_ms=1515100955770).commit() + + table_v2.catalog.commit_table.assert_not_called() diff --git a/tests/table/test_snapshots.py b/tests/table/test_snapshots.py index d26562ad8f..4aa9521b78 100644 --- a/tests/table/test_snapshots.py +++ b/tests/table/test_snapshots.py @@ -30,6 +30,7 @@ Summary, ancestors_between, ancestors_of, + latest_ancestor_before_timestamp, update_snapshot_summaries, ) from pyiceberg.transforms import IdentityTransform @@ -456,3 +457,79 @@ def test_ancestors_between(table_v2_with_extensive_snapshots: Table) -> None: ) == 2000 ) + + +def test_latest_ancestor_before_timestamp() -> None: + from pyiceberg.table.metadata import TableMetadataV2 + + # Create metadata with 4 snapshots at ordered timestamps + metadata = TableMetadataV2( + **{ + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + "last-sequence-number": 4, + "last-updated-ms": 1602638573590, + "last-column-id": 1, + "current-schema-id": 0, + "schemas": [{"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": []}], + "last-partition-id": 999, + "default-sort-order-id": 0, + "sort-orders": [{"order-id": 0, "fields": []}], + "current-snapshot-id": 4, + "snapshots": [ + { + "snapshot-id": 1, + "timestamp-ms": 1000, + "sequence-number": 1, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/1.avro", + }, + { + "snapshot-id": 2, + "parent-snapshot-id": 1, + "timestamp-ms": 2000, + "sequence-number": 2, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/2.avro", + }, + { + "snapshot-id": 3, + "parent-snapshot-id": 2, + "timestamp-ms": 3000, + "sequence-number": 3, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/3.avro", + }, + { + "snapshot-id": 4, + "parent-snapshot-id": 3, + "timestamp-ms": 4000, + "sequence-number": 4, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/4.avro", + }, + ], + } + ) + + result = latest_ancestor_before_timestamp(metadata, 3500) + assert result is not None + assert result.snapshot_id == 3 + + result = latest_ancestor_before_timestamp(metadata, 2500) + assert result is not None + assert result.snapshot_id == 2 + + result = latest_ancestor_before_timestamp(metadata, 5000) + assert result is not None + assert result.snapshot_id == 4 + + result = latest_ancestor_before_timestamp(metadata, 3000) + assert result is not None + assert result.snapshot_id == 2 + + result = latest_ancestor_before_timestamp(metadata, 1000) + assert result is None