diff --git a/arangoasync/auth.py b/arangoasync/auth.py index 96e9b1b..a4df28f 100644 --- a/arangoasync/auth.py +++ b/arangoasync/auth.py @@ -20,8 +20,8 @@ class Auth: encoding (str): Encoding for the password (default: utf-8) """ - username: str - password: str + username: str = "" + password: str = "" encoding: str = "utf-8" diff --git a/arangoasync/client.py b/arangoasync/client.py index 235cfae..b2eed10 100644 --- a/arangoasync/client.py +++ b/arangoasync/client.py @@ -147,7 +147,7 @@ async def db( self, name: str, auth_method: str = "basic", - auth: Optional[Auth] = None, + auth: Optional[Auth | str] = None, token: Optional[JwtToken] = None, verify: bool = False, compression: Optional[CompressionManager] = None, @@ -169,7 +169,8 @@ async def db( and client are synchronized. - "superuser": Superuser JWT authentication. The `token` parameter is required. The `auth` parameter is ignored. - auth (Auth | None): Login information. + auth (Auth | None): Login information (username and password) or + access token. token (JwtToken | None): JWT token. verify (bool): Verify the connection by sending a test request. compression (CompressionManager | None): If set, supersedes the @@ -188,6 +189,9 @@ async def db( """ connection: Connection + if isinstance(auth, str): + auth = Auth(password=auth) + if auth_method == "basic": if auth is None: raise ValueError("Basic authentication requires the `auth` parameter") diff --git a/arangoasync/database.py b/arangoasync/database.py index a28fa43..2cbbc68 100644 --- a/arangoasync/database.py +++ b/arangoasync/database.py @@ -17,6 +17,9 @@ from arangoasync.connection import Connection from arangoasync.errno import HTTP_FORBIDDEN, HTTP_NOT_FOUND from arangoasync.exceptions import ( + AccessTokenCreateError, + AccessTokenDeleteError, + AccessTokenListError, AnalyzerCreateError, AnalyzerDeleteError, AnalyzerGetError, @@ -107,6 +110,7 @@ from arangoasync.result import Result from arangoasync.serialization import Deserializer, Serializer from arangoasync.typings import ( + AccessToken, CollectionInfo, CollectionType, DatabaseProperties, @@ -2130,6 +2134,96 @@ def response_handler(resp: Response) -> Json: return await self._executor.execute(request, response_handler) + async def create_access_token( + self, + user: str, + name: str, + valid_until: int, + ) -> Result[AccessToken]: + """Create an access token for the given user. + + Args: + user (str): The name of the user. + name (str): A name for the access token to make identification easier, + like a short description. + valid_until (int): A Unix timestamp in seconds to set the expiration date and time. + + Returns: + AccessToken: Information about the created access token, including the token itself. + + Raises: + AccessTokenCreateError: If the operation fails. + + References: + - `create-an-access-token `__ + """ # noqa: E501 + data: Json = { + "name": name, + "valid_until": valid_until, + } + + request = Request( + method=Method.POST, + endpoint=f"/_api/token/{user}", + data=self.serializer.dumps(data), + ) + + def response_handler(resp: Response) -> AccessToken: + if not resp.is_success: + raise AccessTokenCreateError(resp, request) + result: Json = self.deserializer.loads(resp.raw_body) + return AccessToken(result) + + return await self._executor.execute(request, response_handler) + + async def delete_access_token(self, user: str, token_id: int) -> None: + """List all access tokens for the given user. + + Args: + user (str): The name of the user. + token_id (int): The ID of the access token to delete. + + Raises: + AccessTokenDeleteError: If the operation fails. + + References: + - `delete-an-access-token `__ + """ # noqa: E501 + request = Request( + method=Method.DELETE, endpoint=f"/_api/token/{user}/{token_id}" + ) + + def response_handler(resp: Response) -> None: + if not resp.is_success: + raise AccessTokenDeleteError(resp, request) + + await self._executor.execute(request, response_handler) + + async def list_access_tokens(self, user: str) -> Result[Jsons]: + """List all access tokens for the given user. + + Args: + user (str): The name of the user. + + Returns: + list: List of access tokens for the user. + + Raises: + AccessTokenListError: If the operation fails. + + References: + - `list-all-access-tokens `__ + """ # noqa: E501 + request = Request(method=Method.GET, endpoint=f"/_api/token/{user}") + + def response_handler(resp: Response) -> Jsons: + if not resp.is_success: + raise AccessTokenListError(resp, request) + result: Json = self.deserializer.loads(resp.raw_body) + return cast(Jsons, result["tokens"]) + + return await self._executor.execute(request, response_handler) + async def tls(self) -> Result[Json]: """Return TLS data (keyfile, clientCA). diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py index a940e1b..58a9505 100644 --- a/arangoasync/exceptions.py +++ b/arangoasync/exceptions.py @@ -139,6 +139,18 @@ class AQLQueryValidateError(ArangoServerError): """Failed to parse and validate query.""" +class AccessTokenCreateError(ArangoServerError): + """Failed to create an access token.""" + + +class AccessTokenDeleteError(ArangoServerError): + """Failed to delete an access token.""" + + +class AccessTokenListError(ArangoServerError): + """Failed to retrieve access tokens.""" + + class AnalyzerCreateError(ArangoServerError): """Failed to create analyzer.""" diff --git a/arangoasync/typings.py b/arangoasync/typings.py index d49411d..0d85035 100644 --- a/arangoasync/typings.py +++ b/arangoasync/typings.py @@ -2024,3 +2024,55 @@ def __init__( @property def satellites(self) -> Optional[List[str]]: return cast(Optional[List[str]], self._data.get("satellites")) + + +class AccessToken(JsonWrapper): + """User access token. + + Example: + .. code-block:: json + + { + "id" : 1, + "name" : "Token for Service A", + "valid_until" : 1782864000, + "created_at" : 1765543306, + "fingerprint" : "v1...71227d", + "active" : true, + "token" : "v1.7b2265223a3137471227d" + } + + References: + - `create-an-access-token `__ + """ # noqa: E501 + + def __init__(self, data: Json) -> None: + super().__init__(data) + + @property + def active(self) -> bool: + return cast(bool, self._data["active"]) + + @property + def created_at(self) -> int: + return cast(int, self._data["created_at"]) + + @property + def fingerprint(self) -> str: + return cast(str, self._data["fingerprint"]) + + @property + def id(self) -> int: + return cast(int, self._data["id"]) + + @property + def name(self) -> str: + return cast(str, self._data["name"]) + + @property + def token(self) -> str: + return cast(str, self._data["token"]) + + @property + def valid_until(self) -> int: + return cast(int, self._data["valid_until"]) diff --git a/tests/helpers.py b/tests/helpers.py index 0e6e8a8..2bc04a5 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -89,3 +89,12 @@ def generate_service_mount(): str: Random service name. """ return f"/test_{uuid4().hex}" + + +def generate_token_name(): + """Generate and return a random token name. + + Returns: + str: Random token name. + """ + return f"test_token_{uuid4().hex}" diff --git a/tests/test_client.py b/tests/test_client.py index cbd96d4..2218384 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,12 +1,20 @@ +import time + import pytest from arangoasync.auth import JwtToken from arangoasync.client import ArangoClient from arangoasync.compression import DefaultCompressionManager -from arangoasync.exceptions import ServerEncryptionError +from arangoasync.exceptions import ( + AccessTokenCreateError, + AccessTokenDeleteError, + AccessTokenListError, + ServerEncryptionError, +) from arangoasync.http import DefaultHTTPClient from arangoasync.resolver import DefaultHostResolver, RoundRobinHostResolver from arangoasync.version import __version__ +from tests.helpers import generate_token_name @pytest.mark.asyncio @@ -152,3 +160,49 @@ async def test_client_jwt_superuser_auth( await client.db( sys_db_name, auth_method="superuser", auth=basic_auth_root, verify=True ) + + +@pytest.mark.asyncio +async def test_client_access_token(url, sys_db_name, basic_auth_root, bad_db): + username = basic_auth_root.username + + async with ArangoClient(hosts=url) as client: + # First login with basic auth + db_auth_basic = await client.db( + sys_db_name, + auth_method="basic", + auth=basic_auth_root, + verify=True, + ) + + # Create an access token + token_name = generate_token_name() + token = await db_auth_basic.create_access_token( + user=username, name=token_name, valid_until=int(time.time() + 3600) + ) + assert token.active is True + + # Cannot create a token with the same name + with pytest.raises(AccessTokenCreateError): + await db_auth_basic.create_access_token( + user=username, name=token_name, valid_until=int(time.time() + 3600) + ) + + # Authenticate with the created token + access_token_db = await client.db( + sys_db_name, + auth_method="basic", + auth=token.token, + verify=True, + ) + + # List access tokens + tokens = await access_token_db.list_access_tokens(username) + assert isinstance(tokens, list) + with pytest.raises(AccessTokenListError): + await bad_db.list_access_tokens(username) + + # Clean up - delete the created token + await access_token_db.delete_access_token(username, token.id) + with pytest.raises(AccessTokenDeleteError): + await access_token_db.delete_access_token(username, token.id) diff --git a/tests/test_typings.py b/tests/test_typings.py index 3b4e5e2..48e9eb0 100644 --- a/tests/test_typings.py +++ b/tests/test_typings.py @@ -1,6 +1,7 @@ import pytest from arangoasync.typings import ( + AccessToken, CollectionInfo, CollectionStatistics, CollectionStatus, @@ -446,3 +447,28 @@ def test_CollectionStatistics(): assert stats.key_options["type"] == "traditional" assert stats.computed_values is None assert stats.object_id == "69124" + + +def test_AccessToken(): + data = { + "active": True, + "created_at": 1720000000, + "fingerprint": "abc123fingerprint", + "id": 42, + "name": "ci-token", + "token": "v2.local.eyJhbGciOi...", + "valid_until": 1720003600, + } + + access_token = AccessToken(data) + + assert access_token.active is True + assert access_token.created_at == 1720000000 + assert access_token.fingerprint == "abc123fingerprint" + assert access_token.id == 42 + assert access_token.name == "ci-token" + assert access_token.token == "v2.local.eyJhbGciOi..." + assert access_token.valid_until == 1720003600 + + # JsonWrapper behavior + assert access_token.to_dict() == data