Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions pyiceberg/catalog/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,8 @@
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt

from pyiceberg import __version__
from pyiceberg.catalog import (
BOTOCORE_SESSION,
TOKEN,
URI,
WAREHOUSE_LOCATION,
Catalog,
PropertiesUpdateSummary,
)
from pyiceberg.catalog.rest.auth import AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary
from pyiceberg.catalog.rest.auth import AUTH_MANAGER, AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
from pyiceberg.catalog.rest.response import _handle_non_200_response
from pyiceberg.exceptions import (
AuthorizationExpiredError,
Expand All @@ -49,7 +42,7 @@
TableAlreadyExistsError,
UnauthorizedError,
)
from pyiceberg.io import AWS_ACCESS_KEY_ID, AWS_REGION, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN
from pyiceberg.io import AWS_ACCESS_KEY_ID, AWS_REGION, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN, FileIO, load_file_io
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec, assign_fresh_partition_spec_ids
from pyiceberg.schema import Schema, assign_fresh_schema_ids
from pyiceberg.table import (
Expand Down Expand Up @@ -318,6 +311,7 @@ class ListViewsResponse(IcebergBaseModel):
class RestCatalog(Catalog):
uri: str
_session: Session
_auth_manager: AuthManager | None
_supported_endpoints: set[Endpoint]

def __init__(self, name: str, **properties: str):
Expand All @@ -330,6 +324,7 @@ def __init__(self, name: str, **properties: str):
properties: Properties that are passed along to the configuration.
"""
super().__init__(name, **properties)
self._auth_manager: AuthManager | None = None
self.uri = properties[URI]
self._fetch_config()
self._session = self._create_session()
Expand Down Expand Up @@ -364,16 +359,24 @@ def _create_session(self) -> Session:
if auth_type != CUSTOM and auth_impl:
raise ValueError("auth.impl can only be specified when using custom auth.type")

session.auth = AuthManagerAdapter(AuthManagerFactory.create(auth_impl or auth_type, auth_type_config))
self._auth_manager = AuthManagerFactory.create(auth_impl or auth_type, auth_type_config)
session.auth = AuthManagerAdapter(self._auth_manager)
else:
session.auth = AuthManagerAdapter(self._create_legacy_oauth2_auth_manager(session))
self._auth_manager = self._create_legacy_oauth2_auth_manager(session)
session.auth = AuthManagerAdapter(self._auth_manager)

# Configure SigV4 Request Signing
if property_as_bool(self.properties, SIGV4, False):
self._init_sigv4(session)

return session

def _load_file_io(self, properties: Properties = EMPTY_DICT, location: str | None = None) -> FileIO:
    """Load a FileIO from the catalog's properties plus its auth manager, if any.

    Args:
        properties: Extra properties layered on top of the catalog's own
            properties; on duplicate keys these values win.
        location: Optional location forwarded unchanged to ``load_file_io``.

    Returns:
        The FileIO produced by ``load_file_io`` for the merged properties.
    """
    # Build a fresh dict so neither self.properties nor the caller's mapping is mutated.
    merged_properties = {**self.properties, **properties}
    # Compare against None explicitly: the attribute is Optional, and a manager
    # object must not be dropped merely because it happens to evaluate as falsy.
    if self._auth_manager is not None:
        # Expose the live manager to the FileIO under the AUTH_MANAGER key.
        merged_properties[AUTH_MANAGER] = self._auth_manager
    return load_file_io(merged_properties, location)

def is_rest_scan_planning_enabled(self) -> bool:
"""Check if rest server-side scan planning is enabled.

Expand Down
2 changes: 2 additions & 0 deletions pyiceberg/catalog/rest/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from pyiceberg.catalog.rest.response import TokenResponse, _handle_non_200_response
from pyiceberg.exceptions import OAuthError

# Property key under which an AuthManager instance is stored in a properties mapping.
AUTH_MANAGER = "auth.manager"

COLON = ":"
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

Expand Down
15 changes: 12 additions & 3 deletions pyiceberg/io/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from requests import HTTPError

from pyiceberg.catalog import TOKEN, URI
from pyiceberg.catalog.rest.auth import AUTH_MANAGER
from pyiceberg.exceptions import SignError
from pyiceberg.io import (
ADLS_ACCOUNT_HOST,
Expand Down Expand Up @@ -121,9 +122,17 @@ def __call__(self, request: "AWSRequest", **_: Any) -> None:
signer_url = self.properties.get(S3_SIGNER_URI, self.properties[URI]).rstrip("/") # type: ignore
signer_endpoint = self.properties.get(S3_SIGNER_ENDPOINT, S3_SIGNER_ENDPOINT_DEFAULT)

signer_headers = {}
if token := self.properties.get(TOKEN):
signer_headers = {"Authorization": f"Bearer {token}"}
signer_headers: dict[str, str] = {}

auth_header: str | None = None
if auth_manager := self.properties.get(AUTH_MANAGER):
auth_header = auth_manager.auth_header()
elif token := self.properties.get(TOKEN):
auth_header = f"Bearer {token}"

if auth_header:
signer_headers["Authorization"] = auth_header

signer_headers.update(get_header_properties(self.properties))

signer_body = {
Expand Down
51 changes: 51 additions & 0 deletions tests/io/test_fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from fsspec.spec import AbstractFileSystem
from requests_mock import Mocker

from pyiceberg.catalog.rest.auth import AUTH_MANAGER
from pyiceberg.exceptions import SignError
from pyiceberg.io import fsspec
from pyiceberg.io.fsspec import FsspecFileIO, S3V4RestSigner
Expand Down Expand Up @@ -948,3 +949,53 @@ def test_s3v4_rest_signer_forbidden(requests_mock: Mocker) -> None:
"""Failed to sign request 401: {'method': 'HEAD', 'region': 'us-west-2', 'uri': 'https://bucket/metadata/snap-8048355899640248710-1-a5c8ea2d-aa1f-48e8-89f4-1fa69db8c742.avro', 'headers': {'User-Agent': ['Botocore/1.27.59 Python/3.10.7 Darwin/21.5.0']}}"""
in str(exc_info.value)
)


def test_s3v4_rest_signer_uses_auth_manager(requests_mock: Mocker) -> None:
    """The REST signer sends the Authorization header supplied by the auth manager."""

    class DummyAuthManager:
        """Stand-in auth manager that counts how often a header is requested."""

        def __init__(self) -> None:
            self.calls = 0

        def auth_header(self) -> str:
            self.calls += 1
            return "Bearer via-manager"

    # Canned signer response: the signed request points at a different URI.
    signed_uri = "https://bucket/metadata/snap-signed.avro"
    sign_response = {
        "uri": signed_uri,
        "headers": {
            "Authorization": ["AWS4-HMAC-SHA256 Credential=ASIA.../s3/aws4_request, SignedHeaders=host, Signature=abc"],
            "Host": ["bucket.s3.us-west-2.amazonaws.com"],
        },
        "extensions": {},
    }
    requests_mock.post(f"{TEST_URI}/v1/aws/s3/sign", json=sign_response, status_code=200)

    # Unsigned request that will be handed to the signer.
    aws_request = AWSRequest(
        method="HEAD",
        url="https://bucket/metadata/snap-foo.avro",
        headers={"User-Agent": "Botocore/1.27.59 Python/3.10.7 Darwin/21.5.0"},
        data=b"",
        params={},
        auth_path="/metadata/snap-foo.avro",
    )
    aws_request.context = {
        "client_region": "us-west-2",
        "has_streaming_input": False,
        "auth_type": None,
        "signing": {"bucket": "bucket"},
        "retries": {"attempt": 1, "invocation-id": "75d143fb-0219-439b-872c-18213d1c8d54"},
    }

    manager = DummyAuthManager()
    rest_signer = S3V4RestSigner(properties={AUTH_MANAGER: manager, "uri": TEST_URI})
    rest_signer(aws_request)

    # The manager is consulted exactly once and its header reaches the signer endpoint.
    assert manager.calls == 1
    sent = requests_mock.last_request
    assert sent is not None
    assert sent.headers["Authorization"] == "Bearer via-manager"
    # The signer rewrites the request URL to the one returned by the endpoint.
    assert aws_request.url == signed_uri