Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,12 @@ def __exit__(self, exctype: type[BaseException] | None, excinst: BaseException |
if exctype is None and excinst is None and exctb is None:
self.commit_transaction()

def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequirement, ...] = ()) -> Transaction:
def _apply(
self,
updates: tuple[TableUpdate, ...],
requirements: tuple[TableRequirement, ...] = (),
commit_transaction_if_autocommit: bool = True,
) -> Transaction:
"""Check if the requirements are met, and applies the updates to the metadata."""
for requirement in requirements:
requirement.validate(self.table_metadata)
Expand All @@ -289,7 +294,7 @@ def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequ
if type(new_requirement) not in existing_requirements:
self._requirements = self._requirements + (new_requirement,)

if self._autocommit:
if self._autocommit and commit_transaction_if_autocommit:
self.commit_transaction()

return self
Expand Down
78 changes: 78 additions & 0 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
Snapshot,
SnapshotSummaryCollector,
Summary,
ancestors_of,
update_snapshot_summaries,
)
from pyiceberg.table.update import (
Expand Down Expand Up @@ -843,6 +844,13 @@ def _commit(self) -> UpdatesAndRequirements:
"""Apply the pending changes and commit."""
return self._updates, self._requirements

def _commit_if_ref_updates_exist(self) -> None:
"""Commit any pending ref updates to the transaction."""
if self._updates:
self._transaction._apply(*self._commit(), commit_transaction_if_autocommit=False)
self._updates = ()
self._requirements = ()

def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
"""Remove a snapshot ref.

Expand Down Expand Up @@ -941,6 +949,76 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots:
"""
return self._remove_ref_snapshot(ref_name=branch_name)

def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | None = None) -> ManageSnapshots:
"""Set the current snapshot to a specific snapshot ID or ref.

Args:
snapshot_id: The ID of the snapshot to set as current.
ref_name: The snapshot reference (branch or tag) to set as current.

Returns:
This for method chaining.

Raises:
ValueError: If neither or both arguments are provided, or if the snapshot/ref does not exist.
"""
self._commit_if_ref_updates_exist()

if (snapshot_id is None) == (ref_name is None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a silly nit but why do this? This looks almost java-esque and could be

Suggested change
if (snapshot_id is None) == (ref_name is None):
if snapshot_id is None and ref_name is None:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah lol I just cleaned up the original in a pythonish way to say one must be set. However, seems like this is confusing

Original: https://github.com/apache/iceberg-python/pull/758/changes#diff-23e8153e0fd497a9212215bd2067068f3b56fa071770c7ef326db3d3d03cee9bR2092

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see hahahaha

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could probably do some fanciness with any and all but not gonna harp on it

raise ValueError("Either snapshot_id or ref_name must be provided, not both")

target_snapshot_id: int
if snapshot_id is not None:
target_snapshot_id = snapshot_id
else:
if ref_name not in self._transaction.table_metadata.refs:
raise ValueError(f"Cannot find matching snapshot ID for ref: {ref_name}")
target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id

if self._transaction.table_metadata.snapshot_by_id(target_snapshot_id) is None:
raise ValueError(f"Cannot set current snapshot to unknown snapshot id: {target_snapshot_id}")

update, requirement = self._transaction._set_ref_snapshot(
snapshot_id=target_snapshot_id,
ref_name=MAIN_BRANCH,
type="branch",
)
self._updates += update
self._requirements += requirement
return self

def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
"""Rollback the table to the given snapshot id.

The snapshot needs to be an ancestor of the current table state.

Args:
snapshot_id (int): rollback to this snapshot_id that used to be current.
Returns:
This for method chaining
Raises:
ValueError: If the snapshot does not exist or is not an ancestor of the current table state.
"""
if not self._transaction.table_metadata.snapshot_by_id(snapshot_id):
raise ValueError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")

if not self._is_current_ancestor(snapshot_id):
raise ValueError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")

return self.set_current_snapshot(snapshot_id=snapshot_id)

def _is_current_ancestor(self, snapshot_id: int) -> bool:
return snapshot_id in self._current_ancestors()

def _current_ancestors(self) -> set[int]:
return {
a.snapshot_id
for a in ancestors_of(
self._transaction._table.current_snapshot(),
self._transaction.table_metadata,
)
}


class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
"""Expire snapshots by ID.
Expand Down
114 changes: 114 additions & 0 deletions tests/integration/test_snapshot_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,117 @@ def test_remove_branch(catalog: Catalog) -> None:
# now, remove the branch
tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit()
assert tbl.metadata.refs.get(branch_name, None) is None


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot(catalog: Catalog) -> None:
identifier = "default.test_table_snapshot_operations"
tbl = catalog.load_table(identifier)
assert len(tbl.history()) > 2

# first get the current snapshot and an older one
current_snapshot_id = tbl.history()[-1].snapshot_id
older_snapshot_id = tbl.history()[-2].snapshot_id

# set the current snapshot to the older one
tbl.manage_snapshots().set_current_snapshot(snapshot_id=older_snapshot_id).commit()

tbl = catalog.load_table(identifier)
updated_snapshot = tbl.current_snapshot()
assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id

# restore table
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
tbl = catalog.load_table(identifier)
restored_snapshot = tbl.current_snapshot()
assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot_by_ref(catalog: Catalog) -> None:
identifier = "default.test_table_snapshot_operations"
tbl = catalog.load_table(identifier)
assert len(tbl.history()) > 2

# first get the current snapshot and an older one
current_snapshot_id = tbl.history()[-1].snapshot_id
older_snapshot_id = tbl.history()[-2].snapshot_id
assert older_snapshot_id != current_snapshot_id

# create a tag pointing to the older snapshot
tag_name = "my-tag"
tbl.manage_snapshots().create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name).commit()

# set current snapshot using the tag name
tbl = catalog.load_table(identifier)
tbl.manage_snapshots().set_current_snapshot(ref_name=tag_name).commit()

tbl = catalog.load_table(identifier)
updated_snapshot = tbl.current_snapshot()
assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id

# restore table
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
tbl = catalog.load_table(identifier)
tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
assert tbl.metadata.refs.get(tag_name, None) is None


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
identifier = "default.test_table_snapshot_operations"
tbl = catalog.load_table(identifier)
assert len(tbl.history()) > 2

current_snapshot_id = tbl.history()[-1].snapshot_id
older_snapshot_id = tbl.history()[-2].snapshot_id
assert older_snapshot_id != current_snapshot_id

# create a tag and use it to set current snapshot
tag_name = "my-tag"
(
tbl.manage_snapshots()
.create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name)
.set_current_snapshot(ref_name=tag_name)
.commit()
)

tbl = catalog.load_table(identifier)
updated_snapshot = tbl.current_snapshot()
assert updated_snapshot
assert updated_snapshot.snapshot_id == older_snapshot_id

# restore table
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
tbl = catalog.load_table(identifier)
tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
assert tbl.metadata.refs.get(tag_name, None) is None


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_rollback_to_snapshot(catalog: Catalog) -> None:
identifier = "default.test_table_snapshot_operations"
tbl = catalog.load_table(identifier)
assert len(tbl.history()) > 2

# get the current snapshot and an ancestor
current_snapshot_id = tbl.history()[-1].snapshot_id
ancestor_snapshot_id = tbl.history()[-2].snapshot_id
assert ancestor_snapshot_id != current_snapshot_id

# rollback to the ancestor snapshot
tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()

tbl = catalog.load_table(identifier)
updated_snapshot = tbl.current_snapshot()
assert updated_snapshot and updated_snapshot.snapshot_id == ancestor_snapshot_id

# restore table
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
tbl = catalog.load_table(identifier)
restored_snapshot = tbl.current_snapshot()
assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id
Loading