6 changes: 3 additions & 3 deletions paimon-python/pypaimon/api/rest_util.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import Dict
from typing import Dict, Optional
from urllib.parse import unquote

from pypaimon.common.options import Options
@@ -46,8 +46,8 @@ def extract_prefix_map(

@staticmethod
def merge(
base_properties: Dict[str, str],
override_properties: Dict[str, str]) -> Dict[str, str]:
base_properties: Optional[Dict[str, str]],
override_properties: Optional[Dict[str, str]]) -> Dict[str, str]:
if override_properties is None:
override_properties = {}
if base_properties is None:
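With the widened `Optional` signature, `RESTUtil.merge` can now be called with `None` on either side. A minimal sketch of the resulting contract, assuming the conventional merge semantics where `override_properties` wins on key collisions:

```python
from pypaimon.api.rest_util import RESTUtil

# Either argument may now be None; the result is always a plain dict.
assert RESTUtil.merge(None, {"k": "v"}) == {"k": "v"}
assert RESTUtil.merge({"k": "v"}, None) == {"k": "v"}
assert RESTUtil.merge({"k": "base"}, {"k": "override"}) == {"k": "override"}
```

This is what lets `RESTTokenFileIO.file_io()` below pass the token map and catalog options straight through without pre-checking either for `None`.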
128 changes: 95 additions & 33 deletions paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
@@ -18,7 +18,7 @@
import logging
import threading
import time
from typing import Optional
from typing import Optional, Union

from cachetools import TTLCache

@@ -41,40 +41,55 @@ class RESTTokenFileIO(FileIO):
_FILE_IO_CACHE_MAXSIZE = 1000
_FILE_IO_CACHE_TTL = 36000 # 10 hours in seconds

_FILE_IO_CACHE: TTLCache = None
_FILE_IO_CACHE_LOCK = threading.Lock()

_TOKEN_CACHE: dict = {}
_TOKEN_LOCKS: dict = {}
_TOKEN_LOCKS_LOCK = threading.Lock()

@classmethod
def _get_file_io_cache(cls) -> TTLCache:
if cls._FILE_IO_CACHE is None:
with cls._FILE_IO_CACHE_LOCK:
if cls._FILE_IO_CACHE is None:
cls._FILE_IO_CACHE = TTLCache(
maxsize=cls._FILE_IO_CACHE_MAXSIZE,
ttl=cls._FILE_IO_CACHE_TTL
)
return cls._FILE_IO_CACHE

def __init__(self, identifier: Identifier, path: str,
catalog_options: Optional[Options] = None):
catalog_options: Optional[Union[dict, Options]] = None):
self.identifier = identifier
self.path = path
self.catalog_options = catalog_options
self.properties = catalog_options or Options({}) # For compatibility with refresh_token()
if catalog_options is None:
self.catalog_options = None
elif isinstance(catalog_options, dict):
self.catalog_options = Options(catalog_options)
else:
# Assume it's already an Options object
self.catalog_options = catalog_options
self.properties = self.catalog_options or Options({}) # For compatibility with refresh_token()
self.token: Optional[RESTToken] = None
self.api_instance: Optional[RESTApi] = None
self.lock = threading.Lock()
self.log = logging.getLogger(__name__)
self._uri_reader_factory_cache: Optional[UriReaderFactory] = None
self._file_io_cache: TTLCache = TTLCache(
maxsize=self._FILE_IO_CACHE_MAXSIZE,
ttl=self._FILE_IO_CACHE_TTL
)

def __getstate__(self):
state = self.__dict__.copy()
# Remove non-serializable objects
state.pop('lock', None)
state.pop('api_instance', None)
state.pop('_file_io_cache', None)
state.pop('_uri_reader_factory_cache', None)
# token can be serialized, but we'll refresh it on deserialization
return state

def __setstate__(self, state):
self.__dict__.update(state)
# Recreate lock and cache after deserialization
# Recreate lock after deserialization
self.lock = threading.Lock()
self._file_io_cache = TTLCache(
maxsize=self._FILE_IO_CACHE_MAXSIZE,
ttl=self._FILE_IO_CACHE_TTL
)
self._uri_reader_factory_cache = None
# api_instance will be recreated when needed
self.api_instance = None
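The `__getstate__`/`__setstate__` pair above is what makes the FileIO safe to pickle for multiprocess readers: the lock, REST client, and reader-factory cache are stripped before serialization and rebuilt afterwards. A quick round-trip sketch, assuming `fio` is an already-constructed `RESTTokenFileIO`:

```python
import pickle

# Non-picklable members (lock, api_instance, _uri_reader_factory_cache)
# are dropped by __getstate__, so this round-trip should not raise.
restored = pickle.loads(pickle.dumps(fio))

# __setstate__ recreates the lock; the REST client is rebuilt on demand.
assert restored.api_instance is None
assert restored.lock is not None
```

Note that the class-level `_FILE_IO_CACHE` and `_TOKEN_CACHE` never travel with the instance, so each process keeps its own caches.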
@@ -86,25 +101,36 @@ def file_io(self) -> FileIO:
return FileIO.get(self.path, self.catalog_options or Options({}))

cache_key = self.token
cache = self._get_file_io_cache()

file_io = self._file_io_cache.get(cache_key)
file_io = cache.get(cache_key)
if file_io is not None:
return file_io

with self.lock:
file_io = self._file_io_cache.get(cache_key)
with self._FILE_IO_CACHE_LOCK:
self.try_to_refresh_token()

if self.token is None:
return FileIO.get(self.path, self.catalog_options or Options({}))

cache_key = self.token
cache = self._get_file_io_cache()
file_io = cache.get(cache_key)
if file_io is not None:
return file_io

merged_token = self._merge_token_with_catalog_options(self.token.token)
merged_properties = RESTUtil.merge(
self.catalog_options.to_map() if self.catalog_options else {},
merged_token
self.token.token
)
if self.catalog_options:
dlf_oss_endpoint = self.catalog_options.get(CatalogOptions.DLF_OSS_ENDPOINT)
if dlf_oss_endpoint and dlf_oss_endpoint.strip():
merged_properties[OssOptions.OSS_ENDPOINT.key()] = dlf_oss_endpoint
merged_options = Options(merged_properties)

file_io = PyArrowFileIO(self.path, merged_options)
self._file_io_cache[cache_key] = file_io
cache[cache_key] = file_io
return file_io

def _merge_token_with_catalog_options(self, token: dict) -> dict:
@@ -180,16 +206,55 @@ def filesystem(self):
return self.file_io().filesystem

def try_to_refresh_token(self):
if self.should_refresh():
with self.lock:
if self.should_refresh():
self.refresh_token()
identifier_str = str(self.identifier)

if self.token is not None and not self._is_token_expired(self.token):
return

cached_token = self._get_cached_token(identifier_str)
if cached_token and not self._is_token_expired(cached_token):
self.token = cached_token
return

global_lock = self._get_global_token_lock(identifier_str)

with global_lock:
cached_token = self._get_cached_token(identifier_str)
if cached_token and not self._is_token_expired(cached_token):
self.token = cached_token
return

token_to_check = cached_token if cached_token else self.token
if token_to_check is None or self._is_token_expired(token_to_check):
self.refresh_token()
self._set_cached_token(identifier_str, self.token)

def _get_cached_token(self, identifier_str: str) -> Optional[RESTToken]:
with self._TOKEN_LOCKS_LOCK:
return self._TOKEN_CACHE.get(identifier_str)

def _set_cached_token(self, identifier_str: str, token: RESTToken):
with self._TOKEN_LOCKS_LOCK:
self._TOKEN_CACHE[identifier_str] = token

def _is_token_expired(self, token: Optional[RESTToken]) -> bool:
if token is None:
return True
current_time = int(time.time() * 1000)
return (token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS

def _get_global_token_lock(self, identifier_str: str) -> threading.Lock:
with self._TOKEN_LOCKS_LOCK:
if identifier_str not in self._TOKEN_LOCKS:
self._TOKEN_LOCKS[identifier_str] = threading.Lock()
return self._TOKEN_LOCKS[identifier_str]

def should_refresh(self):
if self.token is None:
return True
current_time = int(time.time() * 1000)
return (self.token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
time_until_expiry = self.token.expire_at_millis - current_time
return time_until_expiry < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS

def refresh_token(self):
self.log.info(f"begin refresh data token for identifier [{self.identifier}]")
@@ -200,17 +265,14 @@ def refresh_token(self):
self.log.info(
f"end refresh data token for identifier [{self.identifier}] expiresAtMillis [{response.expires_at_millis}]"
)
self.token = RESTToken(response.token, response.expires_at_millis)

merged_token_dict = self._merge_token_with_catalog_options(response.token)
new_token = RESTToken(merged_token_dict, response.expires_at_millis)
self.token = new_token

def valid_token(self):
self.try_to_refresh_token()
return self.token

def close(self):
with self.lock:
for file_io in self._file_io_cache.values():
try:
file_io.close()
except Exception as e:
self.log.warning(f"Error closing cached FileIO: {e}")
self._file_io_cache.clear()
pass
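The refresh path above follows a per-identifier double-checked locking pattern: a cheap cache probe first, then a re-check under a per-key lock, so that N concurrent readers of the same table trigger at most one REST round trip. A stripped-down sketch of the same pattern, with hypothetical `fetch_token`/`is_expired` callables standing in for the REST call and expiry check:

```python
import threading
from typing import Callable, Dict

_TOKENS: Dict[str, object] = {}
_LOCKS: Dict[str, threading.Lock] = {}
_LOCKS_GUARD = threading.Lock()

def get_token(key: str,
              fetch_token: Callable[[str], object],
              is_expired: Callable[[object], bool]) -> object:
    token = _TOKENS.get(key)
    if token is not None and not is_expired(token):
        return token  # fast path: valid cached token, no per-key lock taken
    with _LOCKS_GUARD:  # create exactly one lock object per key
        lock = _LOCKS.setdefault(key, threading.Lock())
    with lock:
        token = _TOKENS.get(key)  # re-check: another thread may have refreshed
        if token is None or is_expired(token):
            token = fetch_token(key)  # only one thread per key gets here
            _TOKENS[key] = token
        return token
```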
55 changes: 35 additions & 20 deletions paimon-python/pypaimon/read/reader/lance_utils.py
@@ -26,12 +26,24 @@

def to_lance_specified(file_io: FileIO, file_path: str) -> Tuple[str, Optional[Dict[str, str]]]:
"""Convert path and extract storage options for Lance format."""
# For RESTTokenFileIO, unwrap to the underlying FileIO, which already carries
# the latest token. This mirrors Java's ((RESTTokenFileIO) fileIO).fileIO():
# file_io() refreshes the token and merges it with the catalog options.
if hasattr(file_io, 'file_io'):
file_io = file_io.file_io()

# Now get properties from the underlying FileIO (which has latest token)
if hasattr(file_io, 'get_merged_properties'):
properties = file_io.get_merged_properties()
else:
properties = file_io.properties if hasattr(file_io, 'properties') and file_io.properties else None

scheme, _, _ = file_io.parse_location(file_path)
storage_options = None
file_path_for_lance = file_io.to_filesystem_path(file_path)

storage_options = None

if scheme in {'file', None} or not scheme:
if not os.path.isabs(file_path_for_lance):
@@ -40,37 +52,40 @@ def to_lance_specified(file_io: FileIO, file_path: str) -> Tuple[str, Optional[D
file_path_for_lance = file_path

if scheme == 'oss':
storage_options = {}
if hasattr(file_io, 'properties'):
for key, value in file_io.properties.data.items():
parsed = urlparse(file_path)
bucket = parsed.netloc
path = parsed.path.lstrip('/')

if properties:
storage_options = {}
for key, value in properties.to_map().items():
if str(key).startswith('fs.'):
storage_options[key] = value

parsed = urlparse(file_path)
bucket = parsed.netloc
path = parsed.path.lstrip('/')

endpoint = file_io.properties.get(OssOptions.OSS_ENDPOINT)
endpoint = properties.get(OssOptions.OSS_ENDPOINT)
if endpoint:
endpoint_clean = endpoint.replace('http://', '').replace('https://', '')
storage_options['endpoint'] = f"https://{bucket}.{endpoint_clean}"

if file_io.properties.contains(OssOptions.OSS_ACCESS_KEY_ID):
storage_options['access_key_id'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_ID)
storage_options['oss_access_key_id'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_ID)
if file_io.properties.contains(OssOptions.OSS_ACCESS_KEY_SECRET):
storage_options['secret_access_key'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
storage_options['oss_secret_access_key'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
if file_io.properties.contains(OssOptions.OSS_SECURITY_TOKEN):
storage_options['session_token'] = file_io.properties.get(OssOptions.OSS_SECURITY_TOKEN)
storage_options['oss_session_token'] = file_io.properties.get(OssOptions.OSS_SECURITY_TOKEN)
if file_io.properties.contains(OssOptions.OSS_ENDPOINT):
storage_options['oss_endpoint'] = file_io.properties.get(OssOptions.OSS_ENDPOINT)
if properties.contains(OssOptions.OSS_ACCESS_KEY_ID):
storage_options['access_key_id'] = properties.get(OssOptions.OSS_ACCESS_KEY_ID)
storage_options['oss_access_key_id'] = properties.get(OssOptions.OSS_ACCESS_KEY_ID)
if properties.contains(OssOptions.OSS_ACCESS_KEY_SECRET):
storage_options['secret_access_key'] = properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
storage_options['oss_secret_access_key'] = properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
if properties.contains(OssOptions.OSS_SECURITY_TOKEN):
storage_options['session_token'] = properties.get(OssOptions.OSS_SECURITY_TOKEN)
storage_options['oss_session_token'] = properties.get(OssOptions.OSS_SECURITY_TOKEN)
if properties.contains(OssOptions.OSS_ENDPOINT):
storage_options['oss_endpoint'] = properties.get(OssOptions.OSS_ENDPOINT)

storage_options['virtual_hosted_style_request'] = 'true'

if bucket and path:
file_path_for_lance = f"oss://{bucket}/{path}"
elif bucket:
file_path_for_lance = f"oss://{bucket}"
else:
storage_options = None

return file_path_for_lance, storage_options
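For context, a sketch of how the returned pair is meant to be consumed; the `lance.dataset` call with `storage_options` is the assumed integration point (pylance API), not something this diff itself exercises:

```python
import lance  # pylance; assumed available in the reader environment

from pypaimon.read.reader.lance_utils import to_lance_specified

# `file_io` is a FileIO (possibly a RESTTokenFileIO) obtained from an open
# Paimon table; setup elided here.
path, storage_options = to_lance_specified(file_io, "oss://bucket/table/data.lance")

# For local paths storage_options is None; for oss:// paths it carries the
# fs.* properties plus the oss_* credential aliases populated above.
ds = lance.dataset(path, storage_options=storage_options)
```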
48 changes: 46 additions & 2 deletions paimon-python/pypaimon/tests/rest/rest_server.py
@@ -24,9 +24,12 @@
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from urllib.parse import urlparse

if TYPE_CHECKING:
from pypaimon.catalog.rest.rest_token import RESTToken

from pypaimon.api.api_request import (AlterTableRequest, CreateDatabaseRequest,
CreateTableRequest, RenameTableRequest)
from pypaimon.api.api_response import (ConfigResponse, GetDatabaseResponse,
@@ -213,6 +216,7 @@ def __init__(self, data_path: str, auth_provider, config: ConfigResponse, wareho
self.table_partitions_store: Dict[str, List] = {}
self.no_permission_databases: List[str] = []
self.no_permission_tables: List[str] = []
self.table_token_store: Dict[str, "RESTToken"] = {}

# Initialize mock catalog (simplified)
self.data_path = data_path
@@ -469,10 +473,12 @@ def _handle_table_resource(self, method: str, path_parts: List[str],
# Basic table operations (GET, DELETE, etc.)
return self._table_handle(method, data, lookup_identifier)
elif len(path_parts) == 4:
# Extended operations (e.g., commit)
# Extended operations (e.g., commit, token)
operation = path_parts[3]
if operation == "commit":
return self._table_commit_handle(method, data, lookup_identifier, branch_part)
elif operation == "token":
return self._table_token_handle(method, lookup_identifier)
else:
return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
@@ -574,6 +580,44 @@ def _table_handle(self, method: str, data: str, identifier: Identifier) -> Tuple

return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)

def _table_token_handle(self, method: str, identifier: Identifier) -> Tuple[str, int]:
if method != "GET":
return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)

if identifier.get_full_name() not in self.table_metadata_store:
raise TableNotExistException(identifier)

from pypaimon.api.api_response import GetTableTokenResponse

token_key = identifier.get_full_name()
if token_key in self.table_token_store:
rest_token = self.table_token_store[token_key]
response = GetTableTokenResponse(
token=rest_token.token,
expires_at_millis=rest_token.expire_at_millis
)
else:
default_token = {
"akId": "akId" + str(int(time.time() * 1000)),
"akSecret": "akSecret" + str(int(time.time() * 1000))
}
response = GetTableTokenResponse(
token=default_token,
expires_at_millis=int(time.time() * 1000) + 3600_000 # 1 hour from now
)

return self._mock_response(response, 200)

def set_table_token(self, identifier: Identifier, token: "RESTToken") -> None:
self.table_token_store[identifier.get_full_name()] = token

def get_table_token(self, identifier: Identifier) -> Optional["RESTToken"]:
return self.table_token_store.get(identifier.get_full_name())

def reset_table_token(self, identifier: Identifier) -> None:
if identifier.get_full_name() in self.table_token_store:
del self.table_token_store[identifier.get_full_name()]

def _table_commit_handle(self, method: str, data: str, identifier: Identifier,
branch: str = None) -> Tuple[str, int]:
"""Handle table commit operations"""
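For completeness, a hedged example of driving the new token endpoint from a test. The `server` fixture and the `Identifier` construction are hypothetical; only `set_table_token`, `get_table_token`, `reset_table_token`, and the `RESTToken(dict, expire_at_millis)` shape come from this diff:

```python
import time

from pypaimon.catalog.rest.rest_token import RESTToken

# Hypothetical test setup: `server` wraps the mock REST server above, and
# `ident` is an Identifier for an existing table (factory call elided).
short_lived = RESTToken(
    {"akId": "test-ak", "akSecret": "test-sk"},
    int(time.time() * 1000) + 1_000,  # expires in ~1 second
)

server.set_table_token(ident, short_lived)  # pin a near-expiry token
assert server.get_table_token(ident) is short_lived

# Clients hitting GET .../token now receive short_lived, which forces the
# RESTTokenFileIO refresh path; resetting falls back to the generated
# default token with a 1 hour TTL.
server.reset_table_token(ident)
assert server.get_table_token(ident) is None
```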