Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions pyiceberg/catalog/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Any,
Union,
)
from urllib.parse import quote, unquote

from pydantic import ConfigDict, Field, field_validator
from requests import HTTPError, Session
Expand Down Expand Up @@ -227,7 +228,8 @@ class IdentifierKind(Enum):
VIEW_ENDPOINTS_SUPPORTED = "view-endpoints-supported"
VIEW_ENDPOINTS_SUPPORTED_DEFAULT = False

NAMESPACE_SEPARATOR = b"\x1f".decode(UTF8)
NAMESPACE_SEPARATOR_PROPERTY = "namespace-separator"
DEFAULT_NAMESPACE_SEPARATOR = b"\x1f".decode(UTF8)


def _retry_hook(retry_state: RetryCallState) -> None:
Expand Down Expand Up @@ -318,7 +320,7 @@ class ListViewsResponse(IcebergBaseModel):
class RestCatalog(Catalog):
uri: str
_session: Session
_supported_endpoints: set[Endpoint]
_namespace_separator: str

def __init__(self, name: str, **properties: str):
"""Rest Catalog.
Expand Down Expand Up @@ -478,6 +480,16 @@ def _extract_optional_oauth_params(self) -> dict[str, str]:

return optional_oauth_param

def _encode_namespace_path(self, namespace: Identifier) -> str:
"""
Encode a namespace for use as a path parameter in a URL.

Each part of the namespace is URL-encoded using `urllib.parse.quote`
(ensuring characters like '/' are encoded) and then joined by the
configured namespace separator.
"""
return self._namespace_separator.join(quote(part, safe="") for part in namespace)

def _fetch_config(self) -> None:
params = {}
if warehouse_location := self.properties.get(WAREHOUSE_LOCATION):
Expand Down Expand Up @@ -510,6 +522,11 @@ def _fetch_config(self) -> None:
if property_as_bool(self.properties, VIEW_ENDPOINTS_SUPPORTED, VIEW_ENDPOINTS_SUPPORTED_DEFAULT):
self._supported_endpoints.update(VIEW_ENDPOINTS)

separator_from_properties = self.properties.get(NAMESPACE_SEPARATOR_PROPERTY, DEFAULT_NAMESPACE_SEPARATOR)
if not separator_from_properties:
raise ValueError("Namespace separator cannot be an empty string")
self._namespace_separator = unquote(separator_from_properties)

def _identifier_to_validated_tuple(self, identifier: str | Identifier) -> Identifier:
identifier_tuple = self.identifier_to_tuple(identifier)
if len(identifier_tuple) <= 1:
Expand All @@ -520,10 +537,17 @@ def _split_identifier_for_path(
self, identifier: str | Identifier | TableIdentifier, kind: IdentifierKind = IdentifierKind.TABLE
) -> Properties:
if isinstance(identifier, TableIdentifier):
return {"namespace": NAMESPACE_SEPARATOR.join(identifier.namespace.root), kind.value: identifier.name}
return {
"namespace": self._encode_namespace_path(tuple(identifier.namespace.root)),
kind.value: quote(identifier.name, safe=""),
}
identifier_tuple = self._identifier_to_validated_tuple(identifier)

return {"namespace": NAMESPACE_SEPARATOR.join(identifier_tuple[:-1]), kind.value: identifier_tuple[-1]}
# Use quote to ensure that '/' aren't treated as path separators.
return {
"namespace": self._encode_namespace_path(identifier_tuple[:-1]),
kind.value: quote(identifier_tuple[-1], safe=""),
}

def _split_identifier_for_json(self, identifier: str | Identifier) -> dict[str, Identifier | str]:
identifier_tuple = self._identifier_to_validated_tuple(identifier)
Expand Down Expand Up @@ -741,7 +765,7 @@ def register_table(self, identifier: str | Identifier, metadata_location: str) -
def list_tables(self, namespace: str | Identifier) -> list[Identifier]:
self._check_endpoint(Capability.V1_LIST_TABLES)
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
namespace_concat = self._encode_namespace_path(namespace_tuple)
response = self._session.get(self.url(Endpoints.list_tables, namespace=namespace_concat))
try:
response.raise_for_status()
Expand Down Expand Up @@ -827,7 +851,7 @@ def list_views(self, namespace: str | Identifier) -> list[Identifier]:
if Capability.V1_LIST_VIEWS not in self._supported_endpoints:
return []
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
namespace_concat = self._encode_namespace_path(namespace_tuple)
response = self._session.get(self.url(Endpoints.list_views, namespace=namespace_concat))
try:
response.raise_for_status()
Expand Down Expand Up @@ -897,7 +921,7 @@ def create_namespace(self, namespace: str | Identifier, properties: Properties =
def drop_namespace(self, namespace: str | Identifier) -> None:
self._check_endpoint(Capability.V1_DELETE_NAMESPACE)
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
namespace = self._encode_namespace_path(namespace_tuple)
response = self._session.delete(self.url(Endpoints.drop_namespace, namespace=namespace))
try:
response.raise_for_status()
Expand All @@ -910,7 +934,7 @@ def list_namespaces(self, namespace: str | Identifier = ()) -> list[Identifier]:
namespace_tuple = self.identifier_to_tuple(namespace)
response = self._session.get(
self.url(
f"{Endpoints.list_namespaces}?parent={NAMESPACE_SEPARATOR.join(namespace_tuple)}"
f"{Endpoints.list_namespaces}?parent={self._encode_namespace_path(namespace_tuple)}"
if namespace_tuple
else Endpoints.list_namespaces
),
Expand All @@ -926,7 +950,7 @@ def list_namespaces(self, namespace: str | Identifier = ()) -> list[Identifier]:
def load_namespace_properties(self, namespace: str | Identifier) -> Properties:
self._check_endpoint(Capability.V1_LOAD_NAMESPACE)
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
namespace = self._encode_namespace_path(namespace_tuple)
response = self._session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace))
try:
response.raise_for_status()
Expand All @@ -941,7 +965,7 @@ def update_namespace_properties(
) -> PropertiesUpdateSummary:
self._check_endpoint(Capability.V1_UPDATE_NAMESPACE)
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
namespace = self._encode_namespace_path(namespace_tuple)
payload = {"removals": list(removals or []), "updates": updates}
response = self._session.post(self.url(Endpoints.update_namespace_properties, namespace=namespace), json=payload)
try:
Expand All @@ -958,14 +982,7 @@ def update_namespace_properties(
@retry(**_RETRY_ARGS)
def namespace_exists(self, namespace: str | Identifier) -> bool:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
# fallback in order to work with older rest catalog implementations
if Capability.V1_NAMESPACE_EXISTS not in self._supported_endpoints:
try:
self.load_namespace_properties(namespace_tuple)
return True
except NoSuchNamespaceError:
return False
namespace = self._encode_namespace_path(namespace_tuple)

response = self._session.head(self.url(Endpoints.namespace_exists, namespace=namespace))

Expand Down
25 changes: 25 additions & 0 deletions tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1922,6 +1922,31 @@ def test_rest_catalog_with_google_credentials_path(
assert actual_headers["Authorization"] == expected_auth_header


def test_custom_namespace_separator(rest_mock: Mocker) -> None:
    """Check that a `namespace-separator` catalog property is used to join namespace levels in request URLs."""
    separator = "-"
    levels = ("some", "namespace")
    # The separator is expected to appear literally between the levels in the request path.
    namespace_url = f"{TEST_URI}v1/namespaces/{separator.join(levels)}"

    rest_mock.get(
        f"{TEST_URI}v1/config",
        json={"defaults": {}, "overrides": {}},
        status_code=200,
    )
    rest_mock.get(
        namespace_url,
        json={"namespace": list(levels), "properties": {"prop": "yes"}},
        status_code=200,
        request_headers=TEST_HEADERS,
    )

    catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, **{"namespace-separator": separator})
    catalog.load_namespace_properties(levels)

    assert rest_mock.last_request.url == namespace_url


@pytest.mark.filterwarnings(
"ignore:Deprecated in 0.8.0, will be removed in 1.0.0. Iceberg REST client is missing the OAuth2 server URI:DeprecationWarning"
)
Expand Down
34 changes: 34 additions & 0 deletions tests/integration/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import os
import uuid
from collections.abc import Generator
from pathlib import Path, PosixPath

Expand Down Expand Up @@ -601,3 +602,36 @@ def test_register_table_existing(test_catalog: Catalog, table_schema_nested: Sch
# Assert that registering the table again raises TableAlreadyExistsError
with pytest.raises(TableAlreadyExistsError):
test_catalog.register_table(identifier, metadata_location=table.metadata_location)


@pytest.mark.integration
def test_rest_custom_namespace_separator(rest_catalog: RestCatalog, table_schema_simple: Schema) -> None:
    """
    Tests nested namespaces and tables against a REST catalog that uses a custom namespace separator.

    The REST Catalog is configured with a '.' namespace separator, which the client
    is expected to pick up from the config endpoint.
    """
    assert rest_catalog._namespace_separator == "."

    # A random suffix keeps repeated runs against the same catalog from colliding.
    parent = (f"test_parent_{uuid.uuid4().hex}",)
    nested = (*parent, "child")
    table_identifier = (*nested, "my_table")

    # The parent namespace is created before the nested one.
    for level in (parent, nested):
        rest_catalog.create_namespace(namespace=level)

    assert nested in rest_catalog.list_namespaces(parent)

    # Test with a table
    created = rest_catalog.create_table(identifier=table_identifier, schema=table_schema_simple)
    assert created.name() == table_identifier

    assert table_identifier in rest_catalog.list_tables(nested)

    assert rest_catalog.load_table(identifier=table_identifier).name() == table_identifier