Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,12 @@ def __exit__(self, exctype: type[BaseException] | None, excinst: BaseException |
if exctype is None and excinst is None and exctb is None:
self.commit_transaction()

def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequirement, ...] = ()) -> Transaction:
def _apply(
self,
updates: tuple[TableUpdate, ...],
requirements: tuple[TableRequirement, ...] = (),
commit_transaction_if_autocommit: bool = True,
) -> Transaction:
"""Check if the requirements are met, and applies the updates to the metadata."""
for requirement in requirements:
requirement.validate(self.table_metadata)
Expand All @@ -289,7 +294,7 @@ def _apply(self, updates: tuple[TableUpdate, ...], requirements: tuple[TableRequ
if type(new_requirement) not in existing_requirements:
self._requirements = self._requirements + (new_requirement,)

if self._autocommit:
if self._autocommit and commit_transaction_if_autocommit:
self.commit_transaction()

return self
Expand Down
21 changes: 21 additions & 0 deletions pyiceberg/table/snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,24 @@ def ancestors_between(from_snapshot: Snapshot | None, to_snapshot: Snapshot, tab
break
else:
yield from ancestors_of(to_snapshot, table_metadata)


def latest_ancestor_before_timestamp(table_metadata: TableMetadata, timestamp_ms: int) -> Snapshot | None:
"""Find the latest ancestor snapshot whose timestamp is before the provided timestamp.

Args:
table_metadata: The table metadata for a table
timestamp_ms: lookup snapshots before this timestamp

Returns:
The latest ancestor snapshot older than the timestamp, or None if not found.
"""
result: Snapshot | None = None
result_timestamp: int = 0

for ancestor in ancestors_of(table_metadata.current_snapshot(), table_metadata):
if timestamp_ms > ancestor.timestamp_ms > result_timestamp:
result = ancestor
result_timestamp = ancestor.timestamp_ms

return result
106 changes: 106 additions & 0 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
Snapshot,
SnapshotSummaryCollector,
Summary,
ancestors_of,
latest_ancestor_before_timestamp,
update_snapshot_summaries,
)
from pyiceberg.table.update import (
Expand Down Expand Up @@ -843,6 +845,13 @@ def _commit(self) -> UpdatesAndRequirements:
"""Apply the pending changes and commit."""
return self._updates, self._requirements

def _commit_if_ref_updates_exist(self) -> None:
"""Commit any pending ref updates to the transaction."""
if self._updates:
self._transaction._apply(*self._commit(), commit_transaction_if_autocommit=False)
self._updates = ()
self._requirements = ()

def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
"""Remove a snapshot ref.

Expand Down Expand Up @@ -941,6 +950,103 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots:
"""
return self._remove_ref_snapshot(ref_name=branch_name)

def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | None = None) -> ManageSnapshots:
"""Set the current snapshot to a specific snapshot ID or ref.

Args:
snapshot_id: The ID of the snapshot to set as current.
ref_name: The snapshot reference (branch or tag) to set as current.

Returns:
This for method chaining.

Raises:
ValueError: If neither or both arguments are provided, or if the snapshot/ref does not exist.
"""
self._commit_if_ref_updates_exist()

if (snapshot_id is None) == (ref_name is None):
raise ValueError("Either snapshot_id or ref_name must be provided, not both")

target_snapshot_id: int
if snapshot_id is not None:
target_snapshot_id = snapshot_id
else:
if ref_name not in self._transaction.table_metadata.refs:
raise ValueError(f"Cannot find matching snapshot ID for ref: {ref_name}")
target_snapshot_id = self._transaction.table_metadata.refs[ref_name].snapshot_id

if self._transaction.table_metadata.snapshot_by_id(target_snapshot_id) is None:
raise ValueError(f"Cannot set current snapshot to unknown snapshot id: {target_snapshot_id}")

update, requirement = self._transaction._set_ref_snapshot(
snapshot_id=target_snapshot_id,
ref_name=MAIN_BRANCH,
type="branch",
)
self._updates += update
self._requirements += requirement
return self

def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
"""Rollback the table to the given snapshot id.

The snapshot needs to be an ancestor of the current table state.

Args:
snapshot_id (int): rollback to this snapshot_id that used to be current.
Returns:
This for method chaining
Raises:
ValueError: If the snapshot does not exist or is not an ancestor of the current table state.
"""
if not self._transaction.table_metadata.snapshot_by_id(snapshot_id):
raise ValueError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")

if not self._is_current_ancestor(snapshot_id):
raise ValueError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")

return self.set_current_snapshot(snapshot_id=snapshot_id)

def rollback_to_timestamp(self, timestamp_ms: int) -> ManageSnapshots:
"""Rollback the table to the latest snapshot before the given timestamp.

Finds the latest ancestor snapshot whose timestamp is before the given timestamp and rolls back to it.

Args:
timestamp_ms: Rollback to the latest snapshot before this timestamp in milliseconds.
Returns:
This for method chaining
Raises:
ValueError: If no valid snapshot exists older than the given timestamp.
"""
self._commit_if_ref_updates_exist()

snapshot = latest_ancestor_before_timestamp(self._transaction.table_metadata, timestamp_ms)
if snapshot is None:
raise ValueError(f"Cannot roll back, no valid snapshot older than: {timestamp_ms}")

update, requirement = self._transaction._set_ref_snapshot(
snapshot_id=snapshot.snapshot_id,
ref_name=MAIN_BRANCH,
type=str(SnapshotRefType.BRANCH),
)
self._updates += update
self._requirements += requirement
return self

def _is_current_ancestor(self, snapshot_id: int) -> bool:
return snapshot_id in self._current_ancestors()

def _current_ancestors(self) -> set[int]:
return {
a.snapshot_id
for a in ancestors_of(
self._transaction._table.current_snapshot(),
self._transaction.table_metadata,
)
}


class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
"""Expire snapshots by ID.
Expand Down
143 changes: 143 additions & 0 deletions tests/integration/test_snapshot_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,146 @@ def test_remove_branch(catalog: Catalog) -> None:
# now, remove the branch
tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit()
assert tbl.metadata.refs.get(branch_name, None) is None


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot(catalog: Catalog) -> None:
identifier = "default.test_table_snapshot_operations"
tbl = catalog.load_table(identifier)
assert len(tbl.history()) > 2

# first get the current snapshot and an older one
current_snapshot_id = tbl.history()[-1].snapshot_id
older_snapshot_id = tbl.history()[-2].snapshot_id

# set the current snapshot to the older one
tbl.manage_snapshots().set_current_snapshot(snapshot_id=older_snapshot_id).commit()

tbl = catalog.load_table(identifier)
updated_snapshot = tbl.current_snapshot()
assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id

# restore table
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
tbl = catalog.load_table(identifier)
restored_snapshot = tbl.current_snapshot()
assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot_by_ref(catalog: Catalog) -> None:
identifier = "default.test_table_snapshot_operations"
tbl = catalog.load_table(identifier)
assert len(tbl.history()) > 2

# first get the current snapshot and an older one
current_snapshot_id = tbl.history()[-1].snapshot_id
older_snapshot_id = tbl.history()[-2].snapshot_id
assert older_snapshot_id != current_snapshot_id

# create a tag pointing to the older snapshot
tag_name = "my-tag"
tbl.manage_snapshots().create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name).commit()

# set current snapshot using the tag name
tbl = catalog.load_table(identifier)
tbl.manage_snapshots().set_current_snapshot(ref_name=tag_name).commit()

tbl = catalog.load_table(identifier)
updated_snapshot = tbl.current_snapshot()
assert updated_snapshot and updated_snapshot.snapshot_id == older_snapshot_id

# restore table
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
tbl = catalog.load_table(identifier)
tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
assert tbl.metadata.refs.get(tag_name, None) is None


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
identifier = "default.test_table_snapshot_operations"
tbl = catalog.load_table(identifier)
assert len(tbl.history()) > 2

current_snapshot_id = tbl.history()[-1].snapshot_id
older_snapshot_id = tbl.history()[-2].snapshot_id
assert older_snapshot_id != current_snapshot_id

# create a tag and use it to set current snapshot
tag_name = "my-tag"
(
tbl.manage_snapshots()
.create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name)
.set_current_snapshot(ref_name=tag_name)
.commit()
)

tbl = catalog.load_table(identifier)
updated_snapshot = tbl.current_snapshot()
assert updated_snapshot
assert updated_snapshot.snapshot_id == older_snapshot_id

# restore table
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
tbl = catalog.load_table(identifier)
tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
assert tbl.metadata.refs.get(tag_name, None) is None


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_rollback_to_snapshot(catalog: Catalog) -> None:
identifier = "default.test_table_snapshot_operations"
tbl = catalog.load_table(identifier)
assert len(tbl.history()) > 2

# get the current snapshot and an ancestor
current_snapshot_id = tbl.history()[-1].snapshot_id
ancestor_snapshot_id = tbl.history()[-2].snapshot_id
assert ancestor_snapshot_id != current_snapshot_id

# rollback to the ancestor snapshot
tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()

tbl = catalog.load_table(identifier)
updated_snapshot = tbl.current_snapshot()
assert updated_snapshot and updated_snapshot.snapshot_id == ancestor_snapshot_id

# restore table
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
tbl = catalog.load_table(identifier)
restored_snapshot = tbl.current_snapshot()
assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_rollback_to_timestamp(catalog: Catalog) -> None:
identifier = "default.test_table_snapshot_operations"
tbl = catalog.load_table(identifier)

current_snapshot = tbl.current_snapshot()
assert current_snapshot is not None
assert current_snapshot.parent_snapshot_id is not None

parent_snapshot = tbl.metadata.snapshot_by_id(current_snapshot.parent_snapshot_id)
assert parent_snapshot is not None

# rollback_to_timestamp finds the latest ancestor with timestamp less than given timestamp
tbl.manage_snapshots().rollback_to_timestamp(timestamp_ms=current_snapshot.timestamp_ms).commit()

tbl = catalog.load_table(identifier)
updated_snapshot = tbl.current_snapshot()
assert updated_snapshot is not None
assert updated_snapshot.snapshot_id == parent_snapshot.snapshot_id

# restore table
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot.snapshot_id).commit()
tbl = catalog.load_table(identifier)
restored_snapshot = tbl.current_snapshot()
assert restored_snapshot is not None
assert restored_snapshot.snapshot_id == current_snapshot.snapshot_id
Loading
Loading