diff --git a/paimon-python/pypaimon/api/rest_util.py b/paimon-python/pypaimon/api/rest_util.py
index 97a709ecc34c..fd4d1da040a2 100644
--- a/paimon-python/pypaimon/api/rest_util.py
+++ b/paimon-python/pypaimon/api/rest_util.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Dict
+from typing import Dict, Optional
 from urllib.parse import unquote
 
 from pypaimon.common.options import Options
@@ -46,8 +46,8 @@ def extract_prefix_map(
 
     @staticmethod
     def merge(
-            base_properties: Dict[str, str],
-            override_properties: Dict[str, str]) -> Dict[str, str]:
+            base_properties: Optional[Dict[str, str]],
+            override_properties: Optional[Dict[str, str]]) -> Dict[str, str]:
         if override_properties is None:
             override_properties = {}
         if base_properties is None:
diff --git a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
index f686dc66ea37..7769ba639b28 100644
--- a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
+++ b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
@@ -18,7 +18,7 @@
 import logging
 import threading
 import time
-from typing import Optional
+from typing import Optional, Union
 
 from cachetools import TTLCache
 
@@ -41,40 +41,55 @@ class RESTTokenFileIO(FileIO):
 
     _FILE_IO_CACHE_MAXSIZE = 1000
     _FILE_IO_CACHE_TTL = 36000  # 10 hours in seconds
+    _FILE_IO_CACHE: TTLCache = None
+    _FILE_IO_CACHE_LOCK = threading.Lock()
+
+    _TOKEN_CACHE: dict = {}
+    _TOKEN_LOCKS: dict = {}
+    _TOKEN_LOCKS_LOCK = threading.Lock()
+
+    @classmethod
+    def _get_file_io_cache(cls) -> TTLCache:
+        if cls._FILE_IO_CACHE is None:
+            with cls._FILE_IO_CACHE_LOCK:
+                if cls._FILE_IO_CACHE is None:
+                    cls._FILE_IO_CACHE = TTLCache(
+                        maxsize=cls._FILE_IO_CACHE_MAXSIZE,
+                        ttl=cls._FILE_IO_CACHE_TTL
+                    )
+        return cls._FILE_IO_CACHE
+
     def __init__(self, identifier: Identifier, path: str,
-                 catalog_options: Optional[Options] = None):
+                 catalog_options: Optional[Union[dict, Options]] = None):
         self.identifier = identifier
         self.path = path
-        self.catalog_options = catalog_options
-        self.properties = catalog_options or Options({})  # For compatibility with refresh_token()
+        if catalog_options is None:
+            self.catalog_options = None
+        elif isinstance(catalog_options, dict):
+            self.catalog_options = Options(catalog_options)
+        else:
+            # Assume it's already an Options object
+            self.catalog_options = catalog_options
+        self.properties = self.catalog_options or Options({})  # For compatibility with refresh_token()
         self.token: Optional[RESTToken] = None
         self.api_instance: Optional[RESTApi] = None
         self.lock = threading.Lock()
         self.log = logging.getLogger(__name__)
         self._uri_reader_factory_cache: Optional[UriReaderFactory] = None
-        self._file_io_cache: TTLCache = TTLCache(
-            maxsize=self._FILE_IO_CACHE_MAXSIZE,
-            ttl=self._FILE_IO_CACHE_TTL
-        )
 
     def __getstate__(self):
         state = self.__dict__.copy()
         # Remove non-serializable objects
         state.pop('lock', None)
         state.pop('api_instance', None)
-        state.pop('_file_io_cache', None)
         state.pop('_uri_reader_factory_cache', None)
         # token can be serialized, but we'll refresh it on deserialization
         return state
 
     def __setstate__(self, state):
         self.__dict__.update(state)
-        # Recreate lock and cache after deserialization
+        # Recreate lock after deserialization
         self.lock = threading.Lock()
-        self._file_io_cache = TTLCache(
-            maxsize=self._FILE_IO_CACHE_MAXSIZE,
-            ttl=self._FILE_IO_CACHE_TTL
-        )
         self._uri_reader_factory_cache = None
         # api_instance will be recreated when needed
         self.api_instance = None
 
@@ -86,25 +101,36 @@ def file_io(self) -> FileIO:
             return FileIO.get(self.path, self.catalog_options or Options({}))
 
         cache_key = self.token
+        cache = self._get_file_io_cache()
 
-        file_io = self._file_io_cache.get(cache_key)
+        file_io = cache.get(cache_key)
         if file_io is not None:
             return file_io
 
-        with self.lock:
-            file_io = self._file_io_cache.get(cache_key)
+        with self._FILE_IO_CACHE_LOCK:
+            self.try_to_refresh_token()
+
+            if self.token is None:
+                return FileIO.get(self.path, self.catalog_options or Options({}))
+
+            cache_key = self.token
+            cache = self._get_file_io_cache()
+            file_io = cache.get(cache_key)
             if file_io is not None:
                 return file_io
 
-            merged_token = self._merge_token_with_catalog_options(self.token.token)
             merged_properties = RESTUtil.merge(
                 self.catalog_options.to_map() if self.catalog_options else {},
-                merged_token
+                self.token.token
             )
+            if self.catalog_options:
+                dlf_oss_endpoint = self.catalog_options.get(CatalogOptions.DLF_OSS_ENDPOINT)
+                if dlf_oss_endpoint and dlf_oss_endpoint.strip():
+                    merged_properties[OssOptions.OSS_ENDPOINT.key()] = dlf_oss_endpoint
             merged_options = Options(merged_properties)
             file_io = PyArrowFileIO(self.path, merged_options)
-            self._file_io_cache[cache_key] = file_io
+            cache[cache_key] = file_io
             return file_io
 
     def _merge_token_with_catalog_options(self, token: dict) -> dict:
@@ -180,16 +206,55 @@ def filesystem(self):
         return self.file_io().filesystem
 
     def try_to_refresh_token(self):
-        if self.should_refresh():
-            with self.lock:
-                if self.should_refresh():
-                    self.refresh_token()
+        identifier_str = str(self.identifier)
+
+        if self.token is not None and not self._is_token_expired(self.token):
+            return
+
+        cached_token = self._get_cached_token(identifier_str)
+        if cached_token and not self._is_token_expired(cached_token):
+            self.token = cached_token
+            return
+
+        global_lock = self._get_global_token_lock(identifier_str)
+
+        with global_lock:
+            cached_token = self._get_cached_token(identifier_str)
+            if cached_token and not self._is_token_expired(cached_token):
+                self.token = cached_token
+                return
+
+            token_to_check = cached_token if cached_token else self.token
+            if token_to_check is None or self._is_token_expired(token_to_check):
+                self.refresh_token()
+                self._set_cached_token(identifier_str, self.token)
+
+    def _get_cached_token(self, identifier_str: str) -> Optional[RESTToken]:
+        with self._TOKEN_LOCKS_LOCK:
+            return self._TOKEN_CACHE.get(identifier_str)
+
+    def _set_cached_token(self, identifier_str: str, token: RESTToken):
+        with self._TOKEN_LOCKS_LOCK:
+            self._TOKEN_CACHE[identifier_str] = token
+
+    def _is_token_expired(self, token: Optional[RESTToken]) -> bool:
+        if token is None:
+            return True
+        current_time = int(time.time() * 1000)
+        return (token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
+
+    def _get_global_token_lock(self, identifier_str: str) -> threading.Lock:
+        with self._TOKEN_LOCKS_LOCK:
+            if identifier_str not in self._TOKEN_LOCKS:
+                self._TOKEN_LOCKS[identifier_str] = threading.Lock()
+            return self._TOKEN_LOCKS[identifier_str]
 
     def should_refresh(self):
         if self.token is None:
             return True
         current_time = int(time.time() * 1000)
-        return (self.token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
+        time_until_expiry = self.token.expire_at_millis - current_time
+        return time_until_expiry < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
 
     def refresh_token(self):
         self.log.info(f"begin refresh data token for identifier [{self.identifier}]")
@@ -200,17 +265,14 @@ def refresh_token(self):
         self.log.info(
             f"end refresh data token for identifier [{self.identifier}] expiresAtMillis [{response.expires_at_millis}]"
         )
-        self.token = RESTToken(response.token, response.expires_at_millis)
+
+        merged_token_dict = self._merge_token_with_catalog_options(response.token)
+        new_token = RESTToken(merged_token_dict, response.expires_at_millis)
+        self.token = new_token
 
     def valid_token(self):
         self.try_to_refresh_token()
         return self.token
 
     def close(self):
-        with self.lock:
-            for file_io in self._file_io_cache.values():
-                try:
-                    file_io.close()
-                except Exception as e:
-                    self.log.warning(f"Error closing cached FileIO: {e}")
-            self._file_io_cache.clear()
+        pass
diff --git a/paimon-python/pypaimon/read/reader/lance_utils.py b/paimon-python/pypaimon/read/reader/lance_utils.py
index c219dc67043f..2e3a331e4b4a 100644
--- a/paimon-python/pypaimon/read/reader/lance_utils.py
+++ b/paimon-python/pypaimon/read/reader/lance_utils.py
@@ -26,12 +26,24 @@
 
 def to_lance_specified(file_io: FileIO, file_path: str) -> Tuple[str, Optional[Dict[str, str]]]:
     """Convert path and extract storage options for Lance format."""
+    # For RESTTokenFileIO, get underlying FileIO which already has latest token merged
+    # This follows Java implementation: ((RESTTokenFileIO) fileIO).fileIO()
+    # The file_io() method will refresh token and return a FileIO with merged token
    if hasattr(file_io, 'file_io'):
+        # Call file_io() to get underlying FileIO with latest token
+        # This ensures token is refreshed and merged with catalog options
         file_io = file_io.file_io()
 
+    # Now get properties from the underlying FileIO (which has latest token)
+    if hasattr(file_io, 'get_merged_properties'):
+        properties = file_io.get_merged_properties()
+    else:
+        properties = file_io.properties if hasattr(file_io, 'properties') and file_io.properties else None
+
     scheme, _, _ = file_io.parse_location(file_path)
-    storage_options = None
     file_path_for_lance = file_io.to_filesystem_path(file_path)
+
+    storage_options = None
 
     if scheme in {'file', None} or not scheme:
         if not os.path.isabs(file_path_for_lance):
@@ -40,37 +52,40 @@ def to_lance_specified(file_io: FileIO, file_path: str) -> Tuple[str, Optional[D
             file_path_for_lance = file_path
 
     if scheme == 'oss':
-        storage_options = {}
-        if hasattr(file_io, 'properties'):
-            for key, value in file_io.properties.data.items():
+        parsed = urlparse(file_path)
+        bucket = parsed.netloc
+        path = parsed.path.lstrip('/')
+
+        if properties:
+            storage_options = {}
+            for key, value in properties.to_map().items():
                 if str(key).startswith('fs.'):
                     storage_options[key] = value
 
-            parsed = urlparse(file_path)
-            bucket = parsed.netloc
-            path = parsed.path.lstrip('/')
-
-            endpoint = file_io.properties.get(OssOptions.OSS_ENDPOINT)
+            endpoint = properties.get(OssOptions.OSS_ENDPOINT)
             if endpoint:
                 endpoint_clean = endpoint.replace('http://', '').replace('https://', '')
                 storage_options['endpoint'] = f"https://{bucket}.{endpoint_clean}"
 
-            if file_io.properties.contains(OssOptions.OSS_ACCESS_KEY_ID):
-                storage_options['access_key_id'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_ID)
-                storage_options['oss_access_key_id'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_ID)
-            if file_io.properties.contains(OssOptions.OSS_ACCESS_KEY_SECRET):
-                storage_options['secret_access_key'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
-                storage_options['oss_secret_access_key'] = file_io.properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
-            if file_io.properties.contains(OssOptions.OSS_SECURITY_TOKEN):
-                storage_options['session_token'] = file_io.properties.get(OssOptions.OSS_SECURITY_TOKEN)
-                storage_options['oss_session_token'] = file_io.properties.get(OssOptions.OSS_SECURITY_TOKEN)
-            if file_io.properties.contains(OssOptions.OSS_ENDPOINT):
-                storage_options['oss_endpoint'] = file_io.properties.get(OssOptions.OSS_ENDPOINT)
+            if properties.contains(OssOptions.OSS_ACCESS_KEY_ID):
+                storage_options['access_key_id'] = properties.get(OssOptions.OSS_ACCESS_KEY_ID)
+                storage_options['oss_access_key_id'] = properties.get(OssOptions.OSS_ACCESS_KEY_ID)
+            if properties.contains(OssOptions.OSS_ACCESS_KEY_SECRET):
+                storage_options['secret_access_key'] = properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
+                storage_options['oss_secret_access_key'] = properties.get(OssOptions.OSS_ACCESS_KEY_SECRET)
+            if properties.contains(OssOptions.OSS_SECURITY_TOKEN):
+                storage_options['session_token'] = properties.get(OssOptions.OSS_SECURITY_TOKEN)
+                storage_options['oss_session_token'] = properties.get(OssOptions.OSS_SECURITY_TOKEN)
+            if properties.contains(OssOptions.OSS_ENDPOINT):
+                storage_options['oss_endpoint'] = properties.get(OssOptions.OSS_ENDPOINT)
+
+            storage_options['virtual_hosted_style_request'] = 'true'
 
         if bucket and path:
             file_path_for_lance = f"oss://{bucket}/{path}"
         elif bucket:
             file_path_for_lance = f"oss://{bucket}"
+        else:
+            storage_options = None
 
     return file_path_for_lance, storage_options
diff --git a/paimon-python/pypaimon/tests/rest/rest_server.py b/paimon-python/pypaimon/tests/rest/rest_server.py
index d556f7f55c99..cb7b32108352 100755
--- a/paimon-python/pypaimon/tests/rest/rest_server.py
+++ b/paimon-python/pypaimon/tests/rest/rest_server.py
@@ -24,9 +24,12 @@
 from dataclasses import dataclass
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
 from urllib.parse import urlparse
 
+if TYPE_CHECKING:
+    from pypaimon.catalog.rest.rest_token import RESTToken
+
 from pypaimon.api.api_request import (AlterTableRequest, CreateDatabaseRequest,
                                       CreateTableRequest, RenameTableRequest)
 from pypaimon.api.api_response import (ConfigResponse, GetDatabaseResponse,
@@ -213,6 +216,7 @@ def __init__(self, data_path: str, auth_provider, config: ConfigResponse, wareho
         self.table_partitions_store: Dict[str, List] = {}
         self.no_permission_databases: List[str] = []
         self.no_permission_tables: List[str] = []
+        self.table_token_store: Dict[str, "RESTToken"] = {}
 
         # Initialize mock catalog (simplified)
         self.data_path = data_path
@@ -469,10 +473,12 @@ def _handle_table_resource(self, method: str, path_parts: List[str],
             # Basic table operations (GET, DELETE, etc.)
             return self._table_handle(method, data, lookup_identifier)
         elif len(path_parts) == 4:
-            # Extended operations (e.g., commit)
+            # Extended operations (e.g., commit, token)
             operation = path_parts[3]
             if operation == "commit":
                 return self._table_commit_handle(method, data, lookup_identifier, branch_part)
+            elif operation == "token":
+                return self._table_token_handle(method, lookup_identifier)
             else:
                 return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
         return self._mock_response(ErrorResponse(None, None, "Not Found", 404), 404)
@@ -574,6 +580,44 @@ def _table_handle(self, method: str, data: str, identifier: Identifier) -> Tuple
 
         return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
 
+    def _table_token_handle(self, method: str, identifier: Identifier) -> Tuple[str, int]:
+        if method != "GET":
+            return self._mock_response(ErrorResponse(None, None, "Method Not Allowed", 405), 405)
+
+        if identifier.get_full_name() not in self.table_metadata_store:
+            raise TableNotExistException(identifier)
+
+        from pypaimon.api.api_response import GetTableTokenResponse
+
+        token_key = identifier.get_full_name()
+        if token_key in self.table_token_store:
+            rest_token = self.table_token_store[token_key]
+            response = GetTableTokenResponse(
+                token=rest_token.token,
+                expires_at_millis=rest_token.expire_at_millis
+            )
+        else:
+            default_token = {
+                "akId": "akId" + str(int(time.time() * 1000)),
+                "akSecret": "akSecret" + str(int(time.time() * 1000))
+            }
+            response = GetTableTokenResponse(
+                token=default_token,
+                expires_at_millis=int(time.time() * 1000) + 3600_000  # 1 hour from now
+            )
+
+        return self._mock_response(response, 200)
+
+    def set_table_token(self, identifier: Identifier, token: "RESTToken") -> None:
+        self.table_token_store[identifier.get_full_name()] = token
+
+    def get_table_token(self, identifier: Identifier) -> Optional["RESTToken"]:
+        return self.table_token_store.get(identifier.get_full_name())
+
+    def reset_table_token(self, identifier: Identifier) -> None:
+        if identifier.get_full_name() in self.table_token_store:
+            del self.table_token_store[identifier.get_full_name()]
+
     def _table_commit_handle(self, method: str, data: str, identifier: Identifier,
                              branch: str = None) -> Tuple[str, int]:
         """Handle table commit operations"""
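
The concurrency shape this patch relies on in try_to_refresh_token() — an unlocked
fast-path read, then a re-check under a per-identifier lock so that at most one
thread per identifier performs the expensive refresh — can be shown standalone.
The sketch below is illustrative only and not part of the patch: ExpiringTokenCache
and its methods are hypothetical names, and it stores plain (token, expire_at_millis)
tuples rather than pypaimon's RESTToken.

    import threading
    import time
    from typing import Callable, Dict, Optional, Tuple


    class ExpiringTokenCache:
        """Per-key token cache with double-checked locking: a fast unlocked
        read, then a re-check under a per-key lock before refreshing."""

        SAFE_TIME_MILLIS = 60_000  # refresh this long before the real expiry

        def __init__(self) -> None:
            # key -> (token, expire_at_millis)
            self._tokens: Dict[str, Tuple[str, int]] = {}
            # key -> lock serializing refreshes for that key only
            self._locks: Dict[str, threading.Lock] = {}
            # guards the two dicts above
            self._meta_lock = threading.Lock()

        def _expired(self, entry: Optional[Tuple[str, int]]) -> bool:
            if entry is None:
                return True
            now_millis = int(time.time() * 1000)
            return (entry[1] - now_millis) < self.SAFE_TIME_MILLIS

        def _lock_for(self, key: str) -> threading.Lock:
            with self._meta_lock:
                return self._locks.setdefault(key, threading.Lock())

        def get(self, key: str, refresh: Callable[[], Tuple[str, int]]) -> str:
            with self._meta_lock:
                entry = self._tokens.get(key)
            if not self._expired(entry):
                return entry[0]  # fast path: no per-key lock taken
            with self._lock_for(key):
                # Double check: another thread may have refreshed while we waited.
                with self._meta_lock:
                    entry = self._tokens.get(key)
                if self._expired(entry):
                    entry = refresh()
                    with self._meta_lock:
                        self._tokens[key] = entry
                return entry[0]

Usage would look like cache.get("db.tbl", fetch), where fetch returns a
(token, expire_at_millis) tuple. The re-check after acquiring the per-key lock
is what prevents a thundering herd: threads that lose the race find the fresh
entry and return without calling refresh(), while refreshes for different
identifiers never block each other.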