From fed8cbf3421546294478489b1cab3328d0dfe4da Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Thu, 5 Sep 2024 14:34:32 +0200 Subject: [PATCH 01/14] WIP --- workos/session.py | 9 +++++++++ workos/types/user_management/session.py | 25 +++++++++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 workos/session.py create mode 100644 workos/types/user_management/session.py diff --git a/workos/session.py b/workos/session.py new file mode 100644 index 00000000..48f38e51 --- /dev/null +++ b/workos/session.py @@ -0,0 +1,9 @@ + +from typing import Protocol, Union + +from workos.types.user_management.session import AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse + +class SessionModule(Protocol): + + def authenticate(self) -> Union[AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse]: + ... \ No newline at end of file diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py new file mode 100644 index 00000000..7e1a50df --- /dev/null +++ b/workos/types/user_management/session.py @@ -0,0 +1,25 @@ +from typing import List, Optional +from enum import Enum + +from workos.types.user_management.impersonator import Impersonator +from workos.types.user_management.user import User +from workos.types.workos_model import WorkOSModel + +class AuthenticateWithSessionCookieFailureReason(Enum): + INVALID_JWT = 'invalid_jwt' + INVALID_SESSION_COOKIE = 'invalid_session_cookie' + NO_SESSION_COOKIE_PROVIDED = 'no_session_cookie_provided' + +class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): + authenticated: bool = True + sessionId: str + organizationId: Optional[str] = None + role: Optional[str] = None + permissions: Optional[List[str]] = None + user: User + impersonator: Optional[Impersonator] = None + +class AuthenticateWithSessionCookieErrorResponse(WorkOSModel): + authenticated: bool = False + reason: AuthenticateWithSessionCookieFailureReason + From c199df348c417b3eadfb078676442109bd1496fc Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Wed, 13 Nov 2024 16:37:15 +0100 Subject: [PATCH 02/14] WIP --- LICENSE | 2 +- requirements-dev.txt | 9 + requirements.txt | 4 + setup.py | 19 +-- workos/session.py | 156 +++++++++++++++++- .../authenticate_with_common.py | 2 + .../authentication_response.py | 1 + workos/types/user_management/session.py | 5 +- workos/user_management.py | 25 +++ 9 files changed, 204 insertions(+), 19 deletions(-) create mode 100644 requirements-dev.txt create mode 100644 requirements.txt diff --git a/LICENSE b/LICENSE index 656ceca8..67f8e417 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2021 WorkOS +Copyright (c) 2024 WorkOS Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..478712f7 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,9 @@ +flake8 +pytest==8.3.2 +pytest-asyncio==0.23.8 +pytest-cov==5.0.0 +six==1.16.0 +black==24.4.2 +twine==5.1.1 +mypy==1.12.0 +httpx>=0.27.0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..beaf1927 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +httpx>=0.27.0 +pydantic==2.9.2 +PyJWT==2.9.0 +cryptography==43.0.3 diff --git a/setup.py b/setup.py index e86e5a92..bfd5dbcb 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,11 @@ with open(os.path.join(base_dir, "workos", "__about__.py")) as f: exec(f.read(), about) +def read_requirements(filename): + with open(filename) as f: + return [line.strip() for line in f + if line.strip() and not line.startswith('#')] + setup( name=about["__package_name__"], version=about["__version__"], @@ -27,19 +32,9 @@ ), zip_safe=False, license=about["__license__"], - install_requires=["httpx>=0.27.0", "pydantic==2.9.2"], + install_requires=read_requirements("requirements.txt"), extras_require={ - "dev": [ - "flake8", - "pytest==8.3.2", - "pytest-asyncio==0.23.8", - "pytest-cov==5.0.0", - "six==1.16.0", - "black==24.4.2", - "twine==5.1.1", - "mypy==1.12.0", - "httpx>=0.27.0", - ], + "dev": read_requirements("requirements-dev.txt"), ":python_version<'3.4'": ["enum34"], }, classifiers=[ diff --git a/workos/session.py b/workos/session.py index 48f38e51..47765e04 100644 --- a/workos/session.py +++ b/workos/session.py @@ -1,9 +1,155 @@ +import json +from typing import Any, Dict, List, Optional, Union +import jwt +from jwt import PyJWKClient +from cryptography.fernet import Fernet -from typing import Protocol, Union +from workos.types.user_management.session import ( + AuthenticateWithSessionCookieFailureReason, + AuthenticateWithSessionCookieSuccessResponse, + AuthenticateWithSessionCookieErrorResponse, +) -from workos.types.user_management.session import AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse +class SessionModule: + def __init__( + self, + *, + user_management: Any, + client_id: str, + session_data: str, + cookie_password: str + ) -> None: + # If the cookie password is not provided, throw an error + if cookie_password is None or cookie_password == "": + raise ValueError("cookie_password is required") -class SessionModule(Protocol): + self.user_management = user_management + self.client_id = client_id + self.session_data = session_data + self.cookie_password = cookie_password - def authenticate(self) -> Union[AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse]: - ... \ No newline at end of file + self.jwks = self.create_remote_jwk_set( + self.user_management.get_jwks_url() + ) + self.jwk_algorithms = [str(key.Algorithm) for key in self.jwks] + + for key in self.jwks: + print("Key properties:", dir(key)) # This will show all available attributes + print("Algorithm:", key.Algorithm) + print("Key type:", key.key_type) + + def authenticate( + self, + ) -> Union[ + AuthenticateWithSessionCookieSuccessResponse, + AuthenticateWithSessionCookieErrorResponse, + ]: + if self.session_data is None: + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED + ) + + try: + session = self.unseal_data(self.session_data, self.cookie_password) + except Exception: + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + ) + + if not session["access_token"]: + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + ) + + if not self.is_valid_jwt(session["access_token"]): + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT + ) + + decoded = jwt.decode( + session["access_token"], self.jwks, algorithms=self.jwk_algorithms + ) + + return AuthenticateWithSessionCookieSuccessResponse( + authenticated=True, + session_id=decoded["sid"], + organization_id=decoded["org_id"], + role=decoded["role"], + permissions=decoded["permissions"], + entitlements=decoded["entitlements"], + user=session["user"], + impersonator=session["impersonator"], + reason=None, + ) + + def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[ + AuthenticateWithSessionCookieSuccessResponse, + AuthenticateWithSessionCookieErrorResponse, + ]: + cookie_password = options.get("cookie_password", self.cookie_password) + organization_id = options.get("organization_id", None) + + try: + session = self.unseal_data(self.session_data, cookie_password) + except Exception: + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + ) + + if not session["refresh_token"] or not session["user"]: + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + ) + + try: + auth_response = self.user_management.authenticate_with_refresh_token( + refresh_token=session["refresh_token"], + organization_id=organization_id, + ) + + self.session_data = auth_response.sealed_session + self.cookie_password = cookie_password + + return AuthenticateWithSessionCookieSuccessResponse( + authenticated=True, + sealed_session=auth_response.sealed_session, + session=auth_response, + reason=None, + ) + except Exception as e: + return AuthenticateWithSessionCookieErrorResponse( + authenticated=False, reason=str(e) + ) + + def get_logout_url(self) -> str: + auth_response = self.authenticate() + + if not auth_response["authenticated"]: + raise ValueError(auth_response["reason"]) + + return self.user_management.get_logout_url( + session_id=auth_response["session_id"] + ) + + def create_remote_jwk_set(self, url: str) -> List[Dict[str, Any]]: + jwks_client = PyJWKClient(url) + return jwks_client.get_signing_keys() + + def is_valid_jwt(self, token: str) -> bool: + try: + jwt.decode(token, self.jwks, algorithms=self.jwk_algorithms) + return True + except jwt.exceptions.InvalidTokenError as error: + print("invalid token", error) + return False + + @staticmethod + def seal_data(data: Dict[str, Any], key: str) -> str: + fernet = Fernet(key) + # take the data and encrypt it with the key using fernet + return fernet.encrypt(json.dumps(data).encode()) + + @staticmethod + def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]: + fernet = Fernet(key) + return json.loads(fernet.decrypt(sealed_data).decode()) diff --git a/workos/types/user_management/authenticate_with_common.py b/workos/types/user_management/authenticate_with_common.py index af423e18..73de96ac 100644 --- a/workos/types/user_management/authenticate_with_common.py +++ b/workos/types/user_management/authenticate_with_common.py @@ -1,5 +1,6 @@ from typing import Literal, Union from typing_extensions import TypedDict +from workos.types.user_management.session import SessionConfig class AuthenticateWithBaseParameters(TypedDict): @@ -17,6 +18,7 @@ class AuthenticateWithCodeParameters(AuthenticateWithBaseParameters): code: str code_verifier: Union[str, None] grant_type: Literal["authorization_code"] + session: Union[SessionConfig, None] class AuthenticateWithMagicAuthParameters(AuthenticateWithBaseParameters): diff --git a/workos/types/user_management/authentication_response.py b/workos/types/user_management/authentication_response.py index 999234a9..6a5f4449 100644 --- a/workos/types/user_management/authentication_response.py +++ b/workos/types/user_management/authentication_response.py @@ -29,6 +29,7 @@ class AuthenticationResponse(_AuthenticationResponseBase): impersonator: Optional[Impersonator] = None organization_id: Optional[str] = None user: User + sealed_session: Optional[str] = None class AuthKitAuthenticationResponse(AuthenticationResponse): diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py index 7e1a50df..01581ec1 100644 --- a/workos/types/user_management/session.py +++ b/workos/types/user_management/session.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, TypedDict from enum import Enum from workos.types.user_management.impersonator import Impersonator @@ -23,3 +23,6 @@ class AuthenticateWithSessionCookieErrorResponse(WorkOSModel): authenticated: bool = False reason: AuthenticateWithSessionCookieFailureReason +class SessionConfig(TypedDict): + seal_session: bool + cookie_password: str \ No newline at end of file diff --git a/workos/user_management.py b/workos/user_management.py index dc444a01..66232e3a 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -1,5 +1,6 @@ from typing import Optional, Protocol, Sequence, Set, Type from workos._client_configuration import ClientConfiguration +from workos.session import SessionModule from workos.types.list_resource import ( ListArgs, ListMetadata, @@ -43,6 +44,7 @@ UsersListFilters, ) from workos.types.user_management.password_hash_type import PasswordHashType +from workos.types.user_management.session import SessionConfig from workos.types.user_management.user_management_provider_type import ( UserManagementProviderType, ) @@ -109,6 +111,18 @@ class UserManagementModule(Protocol): _client_configuration: ClientConfiguration + def load_sealed_session(self, *, sealed_session: str, cookie_password: str) -> SyncOrAsync[SessionModule]: + """Load a sealed session and return the session data. + + Args: + sealed_session (str): The sealed session data to load. + cookie_password (str): The cookie password to use to decrypt the session data. + + Returns: + SessionModule: The session module. + """ + ... + def get_user(self, user_id: str) -> SyncOrAsync[User]: """Get the details of an existing user. @@ -804,6 +818,9 @@ def __init__( self._client_configuration = client_configuration self._http_client = http_client + def load_sealed_session(self, *, session_data: str, cookie_password: str) -> SessionModule: + return SessionModule(user_management=self, client_id=self._http_client.client_id, session_data=session_data, cookie_password=cookie_password) + def get_user(self, user_id: str) -> User: response = self._http_client.request( USER_DETAIL_PATH.format(user_id), method=REQUEST_METHOD_GET @@ -1013,6 +1030,9 @@ def _authenticate_with( json=json, ) + if payload["session"] is not None and payload["session"].get("seal_session") is True: + response["sealed_session"] = SessionModule.seal_data(response, payload["session"]["cookie_password"]) + return response_model.model_validate(response) def authenticate_with_password( @@ -1037,16 +1057,21 @@ def authenticate_with_code( self, *, code: str, + session: Optional[SessionConfig] = None, code_verifier: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> AuthKitAuthenticationResponse: + if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""): + raise ValueError("cookie_password is required when sealing session") + payload: AuthenticateWithCodeParameters = { "code": code, "grant_type": "authorization_code", "ip_address": ip_address, "user_agent": user_agent, "code_verifier": code_verifier, + "session": session, } return self._authenticate_with( From 8f1ec0cc5251592af2ccf556d24413a5c1eeef97 Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Wed, 13 Nov 2024 17:44:41 +0100 Subject: [PATCH 03/14] lol camel case --- workos/session.py | 30 +++++++++---------------- workos/types/user_management/session.py | 4 ++-- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/workos/session.py b/workos/session.py index 47765e04..d2b9cf01 100644 --- a/workos/session.py +++ b/workos/session.py @@ -28,15 +28,9 @@ def __init__( self.session_data = session_data self.cookie_password = cookie_password - self.jwks = self.create_remote_jwk_set( - self.user_management.get_jwks_url() - ) - self.jwk_algorithms = [str(key.Algorithm) for key in self.jwks] + self.jwks = PyJWKClient(self.user_management.get_jwks_url()) - for key in self.jwks: - print("Key properties:", dir(key)) # This will show all available attributes - print("Algorithm:", key.Algorithm) - print("Key type:", key.key_type) + self.jwk_algorithms = ['RS256'] def authenticate( self, @@ -66,19 +60,20 @@ def authenticate( authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT ) + signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"]) decoded = jwt.decode( - session["access_token"], self.jwks, algorithms=self.jwk_algorithms + session["access_token"], signing_key.key, algorithms=self.jwk_algorithms ) return AuthenticateWithSessionCookieSuccessResponse( authenticated=True, session_id=decoded["sid"], - organization_id=decoded["org_id"], - role=decoded["role"], - permissions=decoded["permissions"], - entitlements=decoded["entitlements"], + organization_id=decoded.get("org_id", None), + role=decoded.get("role", None), + permissions=decoded.get("permissions", None), + entitlements=decoded.get("entitlements", None), user=session["user"], - impersonator=session["impersonator"], + impersonator=session.get("impersonator", None), reason=None, ) @@ -131,13 +126,10 @@ def get_logout_url(self) -> str: session_id=auth_response["session_id"] ) - def create_remote_jwk_set(self, url: str) -> List[Dict[str, Any]]: - jwks_client = PyJWKClient(url) - return jwks_client.get_signing_keys() - def is_valid_jwt(self, token: str) -> bool: try: - jwt.decode(token, self.jwks, algorithms=self.jwk_algorithms) + signing_key = self.jwks.get_signing_key_from_jwt(token) + jwt.decode(token, signing_key.key, algorithms=self.jwk_algorithms) return True except jwt.exceptions.InvalidTokenError as error: print("invalid token", error) diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py index 01581ec1..fdcd8146 100644 --- a/workos/types/user_management/session.py +++ b/workos/types/user_management/session.py @@ -12,8 +12,8 @@ class AuthenticateWithSessionCookieFailureReason(Enum): class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): authenticated: bool = True - sessionId: str - organizationId: Optional[str] = None + session_id: str + organization_id: Optional[str] = None role: Optional[str] = None permissions: Optional[List[str]] = None user: User From 29f4c7367a09750001284f4cb305f2d02d893f96 Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Thu, 14 Nov 2024 16:14:35 +0100 Subject: [PATCH 04/14] Fix log out url part --- workos/session.py | 10 +++++----- workos/types/user_management/session.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/workos/session.py b/workos/session.py index d2b9cf01..64537fb6 100644 --- a/workos/session.py +++ b/workos/session.py @@ -30,6 +30,7 @@ def __init__( self.jwks = PyJWKClient(self.user_management.get_jwks_url()) + # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm self.jwk_algorithms = ['RS256'] def authenticate( @@ -119,11 +120,11 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[ def get_logout_url(self) -> str: auth_response = self.authenticate() - if not auth_response["authenticated"]: - raise ValueError(auth_response["reason"]) + if not auth_response.authenticated: + raise ValueError(f"Failed to extract session ID for logout URL: {auth_response.reason}") return self.user_management.get_logout_url( - session_id=auth_response["session_id"] + session_id=auth_response.session_id ) def is_valid_jwt(self, token: str) -> bool: @@ -131,8 +132,7 @@ def is_valid_jwt(self, token: str) -> bool: signing_key = self.jwks.get_signing_key_from_jwt(token) jwt.decode(token, signing_key.key, algorithms=self.jwk_algorithms) return True - except jwt.exceptions.InvalidTokenError as error: - print("invalid token", error) + except jwt.exceptions.InvalidTokenError: return False @staticmethod diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py index fdcd8146..5e543cb4 100644 --- a/workos/types/user_management/session.py +++ b/workos/types/user_management/session.py @@ -18,6 +18,7 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): permissions: Optional[List[str]] = None user: User impersonator: Optional[Impersonator] = None + entitlements: Optional[List[str]] = None class AuthenticateWithSessionCookieErrorResponse(WorkOSModel): authenticated: bool = False From 87fadcb92b09fea861588b3bf1ae2048de3c9bdf Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Fri, 22 Nov 2024 14:53:52 +0100 Subject: [PATCH 05/14] Add tests --- tests/test_session.py | 333 ++++++++++++++++++ tests/test_user_management.py | 3 + workos/session.py | 47 ++- .../authenticate_with_common.py | 1 + .../authentication_response.py | 2 +- workos/types/user_management/session.py | 12 +- workos/user_management.py | 26 +- 7 files changed, 405 insertions(+), 19 deletions(-) create mode 100644 tests/test_session.py diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 00000000..a898949c --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,333 @@ +import pytest +from unittest.mock import Mock, patch +import jwt +from jwt import PyJWKClient +from datetime import datetime, timezone + +from workos.session import SessionModule +from workos.types.user_management.authentication_response import RefreshTokenAuthenticationResponse +from workos.types.user_management.session import ( + AuthenticateWithSessionCookieFailureReason, + AuthenticateWithSessionCookieSuccessResponse, + RefreshWithSessionCookieErrorResponse, + RefreshWithSessionCookieSuccessResponse, +) +from workos.types.user_management.user import User + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + +@pytest.fixture(scope="session") +def TEST_CONSTANTS(): + # Generate RSA key pair for testing + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048 + ) + + public_key = private_key.public_key() + + # Get the private key in PEM format + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ) + + return { + "COOKIE_PASSWORD": "pfSqwTFXUTGEBBD1RQh2kt/oNJYxBgaoZan4Z8sMrKU=", + "SESSION_DATA": "session_data", + "CLIENT_ID": "client_123", + "USER_ID": "user_123", + "SESSION_ID": "session_123", + "ORGANIZATION_ID": "organization_123", + "CURRENT_TIMESTAMP": str(datetime.now(timezone.utc)), + "PRIVATE_KEY": private_pem, + "PUBLIC_KEY": public_key, + "TEST_TOKEN": jwt.encode( + { + "sid": "session_123", + "org_id": "organization_123", + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + "exp": int(datetime.now(timezone.utc).timestamp()) + 3600, + "iat": int(datetime.now(timezone.utc).timestamp()), + }, + private_pem, + algorithm="RS256" + ) + } + +@pytest.fixture +def mock_user_management(TEST_CONSTANTS): + mock_jwks = Mock(spec=PyJWKClient) + mock_jwk_set = Mock() + mock_jwk_set.keys = [Mock(Algorithm="RS256")] + mock_jwks.get_jwk_set.return_value = mock_jwk_set + + + + mock = Mock() + mock.get_jwks_url.return_value = "https://api.workos.com/user_management/sso/jwks/client_123" + mock.authenticate_with_refresh_token.return_value = RefreshTokenAuthenticationResponse( + **{ + "access_token": "access_token_123", + "refresh_token": "refresh_token_123", + "user": { + "object": "user", + "id": TEST_CONSTANTS["USER_ID"], + "email": "user@example.com", + "first_name": "Test", + "last_name": "User", + "email_verified": True, + "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], + "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], + } + } + ) + return mock + +def test_initialize_session_module(TEST_CONSTANTS, mock_user_management): + session = SessionModule( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=TEST_CONSTANTS["SESSION_DATA"], + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + ) + + assert session.client_id == TEST_CONSTANTS["CLIENT_ID"] + assert session.cookie_password is not None + +def test_initialize_without_cookie_password(mock_user_management): + with pytest.raises(ValueError, match="cookie_password is required"): + SessionModule( + user_management=mock_user_management, + client_id="client_123", + session_data="session_data", + cookie_password="" + ) + +def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management): + session = SessionModule( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=None, + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + ) + + response = session.authenticate() + + assert response.reason == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED + +def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): + session = SessionModule( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data="invalid_session_data", + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + ) + + response = session.authenticate() + + assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + +def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management): + invalid_session_data = SessionModule.seal_data({ "access_token": "invalid_session_data" }, TEST_CONSTANTS["COOKIE_PASSWORD"]) + session = SessionModule( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=invalid_session_data, + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + ) + + response = session.authenticate() + + assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + +def test_authenticate_success(TEST_CONSTANTS, mock_user_management): + session = SessionModule( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=TEST_CONSTANTS["SESSION_DATA"], + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + ) + + # Mock the session data that would be unsealed + mock_session = { + "access_token": jwt.encode( + { + "sid": TEST_CONSTANTS["SESSION_ID"], + "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + "exp": int(datetime.now(timezone.utc).timestamp()) + 3600, + "iat": int(datetime.now(timezone.utc).timestamp()), + }, + TEST_CONSTANTS["PRIVATE_KEY"], + algorithm="RS256" + ), + "user": { + "object": "user", + "id": TEST_CONSTANTS["USER_ID"], + "email": "user@example.com", + "email_verified": True, + "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], + "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], + }, + "impersonator": None + } + + # Mock the JWT payload that would be decoded + mock_jwt_payload = { + "sid": TEST_CONSTANTS["SESSION_ID"], + "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"] + } + + with ( + # Mock unsealing the session data + patch.object( + SessionModule, + "unseal_data", + return_value=mock_session + ), + # Mock JWT validation + patch.object( + session, + "is_valid_jwt", + return_value=True + ), + # Mock JWT decoding + patch( + "jwt.decode", + return_value=mock_jwt_payload + ), + # Mock JWT signing key retrieval + patch.object( + session.jwks, + "get_signing_key_from_jwt", + return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]) + ) + ): + response = session.authenticate() + + assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.session_id == TEST_CONSTANTS["SESSION_ID"] + assert response.organization_id == TEST_CONSTANTS["ORGANIZATION_ID"] + assert response.role == "admin" + assert response.permissions == ["read"] + assert response.entitlements == ["feature_1"] + assert response.user.id == TEST_CONSTANTS["USER_ID"] + assert response.impersonator is None + +def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): + session = SessionModule( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data="invalid_session_data", + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + ) + + response = session.refresh() + + assert isinstance(response, RefreshWithSessionCookieErrorResponse) + assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + +def test_refresh_success(TEST_CONSTANTS, mock_user_management): + # Create mock JWKS client + mock_jwks = Mock(spec=PyJWKClient) + mock_signing_key = Mock() + mock_signing_key.key = TEST_CONSTANTS["PUBLIC_KEY"] + mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key + + test_user = { + "object": "user", + "id": TEST_CONSTANTS["USER_ID"], + "email": "user@example.com", + "first_name": "Test", + "last_name": "User", + "email_verified": True, + "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], + "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], + } + + session_data = SessionModule.seal_data({ + "refresh_token": "refresh_token_12345", + "user": test_user + }, TEST_CONSTANTS["COOKIE_PASSWORD"]) + + mock_response = { + "access_token": TEST_CONSTANTS["TEST_TOKEN"], + "refresh_token": "refresh_token_123", + "sealed_session": session_data, + "user": test_user + } + + mock_user_management.authenticate_with_refresh_token.return_value = RefreshTokenAuthenticationResponse( + **mock_response + ) + + with ( + patch( + 'workos.session.PyJWKClient', + return_value=mock_jwks + ), + ): + session = SessionModule( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=session_data, + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + ) + + with ( + patch.object( + session, + "is_valid_jwt", + return_value=True + ), + patch( + "jwt.decode", + return_value={ + "sid": TEST_CONSTANTS["SESSION_ID"], + "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"] + } + ) + ): + response = session.refresh() + + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.user.id == test_user["id"] + + # Verify the refresh token was used correctly + mock_user_management.authenticate_with_refresh_token.assert_called_once_with( + refresh_token="refresh_token_12345", + organization_id=None, + session={ + "seal_session": True, + "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"] + } + ) + + +def test_seal_data(TEST_CONSTANTS): + test_data = {"test": "data"} + sealed = SessionModule.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"]) + assert isinstance(sealed, str) + + # Test unsealing + unsealed = SessionModule.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"]) + assert unsealed == test_data + +def test_unseal_invalid_data(TEST_CONSTANTS): + with pytest.raises(Exception): # Adjust exception type based on your implementation + SessionModule.unseal_data("invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"]) diff --git a/tests/test_user_management.py b/tests/test_user_management.py index 3935f74d..3df9e417 100644 --- a/tests/test_user_management.py +++ b/tests/test_user_management.py @@ -60,9 +60,12 @@ def base_authentication_params(self): @pytest.fixture def mock_auth_refresh_token_response(self): + user = MockUser("user_01H7ZGXFP5C6BBQY6Z7277ZCT0").dict() + return { "access_token": "access_token_12345", "refresh_token": "refresh_token_12345", + "user": user, } @pytest.fixture diff --git a/workos/session.py b/workos/session.py index 64537fb6..91590b23 100644 --- a/workos/session.py +++ b/workos/session.py @@ -8,6 +8,8 @@ AuthenticateWithSessionCookieFailureReason, AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse, + RefreshWithSessionCookieErrorResponse, + RefreshWithSessionCookieSuccessResponse, ) class SessionModule: @@ -75,25 +77,24 @@ def authenticate( entitlements=decoded.get("entitlements", None), user=session["user"], impersonator=session.get("impersonator", None), - reason=None, ) def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[ - AuthenticateWithSessionCookieSuccessResponse, - AuthenticateWithSessionCookieErrorResponse, + RefreshWithSessionCookieSuccessResponse, + RefreshWithSessionCookieErrorResponse, ]: - cookie_password = options.get("cookie_password", self.cookie_password) - organization_id = options.get("organization_id", None) + cookie_password = self.cookie_password if options is None else options.get("cookie_password") + organization_id = None if options is None else options.get("organization_id") try: session = self.unseal_data(self.session_data, cookie_password) except Exception: - return AuthenticateWithSessionCookieErrorResponse( + return RefreshWithSessionCookieErrorResponse( authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE ) if not session["refresh_token"] or not session["user"]: - return AuthenticateWithSessionCookieErrorResponse( + return RefreshWithSessionCookieErrorResponse( authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE ) @@ -101,19 +102,34 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[ auth_response = self.user_management.authenticate_with_refresh_token( refresh_token=session["refresh_token"], organization_id=organization_id, + session={ + "seal_session": True, + "cookie_password": cookie_password + } ) self.session_data = auth_response.sealed_session self.cookie_password = cookie_password - return AuthenticateWithSessionCookieSuccessResponse( + signing_key = self.jwks.get_signing_key_from_jwt(auth_response.access_token) + + decoded = jwt.decode( + auth_response.access_token, signing_key.key, algorithms=self.jwk_algorithms + ) + + return RefreshWithSessionCookieSuccessResponse( authenticated=True, sealed_session=auth_response.sealed_session, - session=auth_response, - reason=None, + session_id=decoded["sid"], + organization_id=decoded.get("org_id", None), + role=decoded.get("role", None), + permissions=decoded.get("permissions", None), + entitlements=decoded.get("entitlements", None), + user=auth_response.user, + impersonator=auth_response.impersonator, ) except Exception as e: - return AuthenticateWithSessionCookieErrorResponse( + return RefreshWithSessionCookieErrorResponse( authenticated=False, reason=str(e) ) @@ -138,10 +154,13 @@ def is_valid_jwt(self, token: str) -> bool: @staticmethod def seal_data(data: Dict[str, Any], key: str) -> str: fernet = Fernet(key) - # take the data and encrypt it with the key using fernet - return fernet.encrypt(json.dumps(data).encode()) + # Encrypt and convert bytes to string + encrypted_bytes = fernet.encrypt(json.dumps(data).encode()) + return encrypted_bytes.decode('utf-8') @staticmethod def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]: fernet = Fernet(key) - return json.loads(fernet.decrypt(sealed_data).decode()) + # Convert string back to bytes before decryption + encrypted_bytes = sealed_data.encode('utf-8') + return json.loads(fernet.decrypt(encrypted_bytes).decode()) diff --git a/workos/types/user_management/authenticate_with_common.py b/workos/types/user_management/authenticate_with_common.py index 73de96ac..8adc9ee6 100644 --- a/workos/types/user_management/authenticate_with_common.py +++ b/workos/types/user_management/authenticate_with_common.py @@ -51,6 +51,7 @@ class AuthenticateWithRefreshTokenParameters(AuthenticateWithBaseParameters): refresh_token: str organization_id: Union[str, None] grant_type: Literal["refresh_token"] + session: Union[SessionConfig, None] AuthenticateWithParameters = Union[ diff --git a/workos/types/user_management/authentication_response.py b/workos/types/user_management/authentication_response.py index 6a5f4449..96973fd7 100644 --- a/workos/types/user_management/authentication_response.py +++ b/workos/types/user_management/authentication_response.py @@ -39,7 +39,7 @@ class AuthKitAuthenticationResponse(AuthenticationResponse): oauth_tokens: Optional[OAuthTokens] = None -class RefreshTokenAuthenticationResponse(_AuthenticationResponseBase): +class RefreshTokenAuthenticationResponse(AuthenticationResponse): """Representation of a WorkOS refresh token authentication response.""" pass diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py index 5e543cb4..1c8cc6a7 100644 --- a/workos/types/user_management/session.py +++ b/workos/types/user_management/session.py @@ -1,4 +1,4 @@ -from typing import List, Optional, TypedDict +from typing import List, Optional, TypedDict, Union from enum import Enum from workos.types.user_management.impersonator import Impersonator @@ -22,7 +22,15 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): class AuthenticateWithSessionCookieErrorResponse(WorkOSModel): authenticated: bool = False - reason: AuthenticateWithSessionCookieFailureReason + reason: Union[AuthenticateWithSessionCookieFailureReason, str] + +class RefreshWithSessionCookieSuccessResponse(AuthenticateWithSessionCookieSuccessResponse): + sealed_session: str + + +class RefreshWithSessionCookieErrorResponse(WorkOSModel): + authenticated: bool = False + reason: Union[AuthenticateWithSessionCookieFailureReason, str] class SessionConfig(TypedDict): seal_session: bool diff --git a/workos/user_management.py b/workos/user_management.py index 66232e3a..976d36af 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -431,6 +431,7 @@ def authenticate_with_code( self, *, code: str, + session: Optional[SessionConfig] = None, code_verifier: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, @@ -439,6 +440,7 @@ def authenticate_with_code( Kwargs: code (str): The authorization value which was passed back as a query parameter in the callback to the Redirect URI. + session (SessionConfig): Configuration for the session. (Optional) code_verifier (str): The randomly generated string used to derive the code challenge that was passed to the authorization url as part of the PKCE flow. This parameter is required when the client secret is not present. (Optional) ip_address (str): The IP address of the request from the user who is attempting to authenticate. (Optional) @@ -542,6 +544,7 @@ def authenticate_with_refresh_token( self, *, refresh_token: str, + session: Optional[SessionConfig] = None, organization_id: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, @@ -550,6 +553,7 @@ def authenticate_with_refresh_token( Kwargs: refresh_token (str): The token associated to the user. + session (SessionConfig): Configuration for the session. (Optional) organization_id (str): The organization to issue the new access token for. (Optional) ip_address (str): The IP address of the request from the user who is attempting to authenticate. (Optional) user_agent (str): The user agent of the request from the user who is attempting to authenticate. (Optional) @@ -1030,8 +1034,8 @@ def _authenticate_with( json=json, ) - if payload["session"] is not None and payload["session"].get("seal_session") is True: - response["sealed_session"] = SessionModule.seal_data(response, payload["session"]["cookie_password"]) + if payload.get("session") is not None and payload.get("session").get("seal_session") is True: + response["sealed_session"] = SessionModule.seal_data(response, payload.get("session").get("cookie_password")) return response_model.model_validate(response) @@ -1158,16 +1162,21 @@ def authenticate_with_refresh_token( self, *, refresh_token: str, + session: Optional[SessionConfig] = None, organization_id: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> RefreshTokenAuthenticationResponse: + if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""): + raise ValueError("cookie_password is required when sealing session") + payload: AuthenticateWithRefreshTokenParameters = { "refresh_token": refresh_token, "organization_id": organization_id, "grant_type": "refresh_token", "ip_address": ip_address, "user_agent": user_agent, + "session": session, } return self._authenticate_with( @@ -1614,6 +1623,9 @@ async def _authenticate_with( json=json, ) + if payload.get("session") is not None and payload.get("session").get("seal_session") is True: + response["sealed_session"] = SessionModule.seal_data(response, payload.get("session").get("cookie_password")) + return response_model.model_validate(response) async def authenticate_with_password( @@ -1640,16 +1652,21 @@ async def authenticate_with_code( self, *, code: str, + session: Optional[SessionConfig] = None, code_verifier: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> AuthKitAuthenticationResponse: + if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""): + raise ValueError("cookie_password is required when sealing session") + payload: AuthenticateWithCodeParameters = { "code": code, "grant_type": "authorization_code", "ip_address": ip_address, "user_agent": user_agent, "code_verifier": code_verifier, + "session": session, } return await self._authenticate_with( @@ -1744,16 +1761,21 @@ async def authenticate_with_refresh_token( self, *, refresh_token: str, + session: Optional[SessionConfig] = None, organization_id: Optional[str] = None, ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> RefreshTokenAuthenticationResponse: + if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""): + raise ValueError("cookie_password is required when sealing session") + payload: AuthenticateWithRefreshTokenParameters = { "refresh_token": refresh_token, "organization_id": organization_id, "grant_type": "refresh_token", "ip_address": ip_address, "user_agent": user_agent, + "session": session, } return await self._authenticate_with( From 13bb07362711db32c63a62af2cf26784c538bef5 Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Fri, 22 Nov 2024 15:08:13 +0100 Subject: [PATCH 06/14] Mock JWKS call to speed up tests --- tests/conftest.py | 17 ++++++ tests/test_session.py | 123 +++++++++++++++++------------------------- 2 files changed, 66 insertions(+), 74 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 81ef0ca8..1bd77d72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,9 @@ from workos.utils.http_client import AsyncHTTPClient, HTTPClient, SyncHTTPClient from workos.utils.request_helper import DEFAULT_LIST_RESPONSE_LIMIT +from jwt import PyJWKClient +from unittest.mock import Mock, patch +from functools import wraps def _get_test_client_setup( http_client_class_name: str, @@ -302,3 +305,17 @@ def inner( assert request_kwargs["params"][param] == params[param] return inner + +def with_jwks_mock(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Create mock JWKS client + mock_jwks = Mock(spec=PyJWKClient) + mock_signing_key = Mock() + mock_signing_key.key = kwargs['TEST_CONSTANTS']["PUBLIC_KEY"] + mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key + + # Apply the mock + with patch('workos.session.PyJWKClient', return_value=mock_jwks): + return func(*args, **kwargs) + return wrapper \ No newline at end of file diff --git a/tests/test_session.py b/tests/test_session.py index a898949c..9e3cc021 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -4,6 +4,7 @@ from jwt import PyJWKClient from datetime import datetime, timezone +from tests.conftest import with_jwks_mock from workos.session import SessionModule from workos.types.user_management.authentication_response import RefreshTokenAuthenticationResponse from workos.types.user_management.session import ( @@ -60,34 +61,13 @@ def TEST_CONSTANTS(): } @pytest.fixture -def mock_user_management(TEST_CONSTANTS): - mock_jwks = Mock(spec=PyJWKClient) - mock_jwk_set = Mock() - mock_jwk_set.keys = [Mock(Algorithm="RS256")] - mock_jwks.get_jwk_set.return_value = mock_jwk_set - - - +def mock_user_management(): mock = Mock() mock.get_jwks_url.return_value = "https://api.workos.com/user_management/sso/jwks/client_123" - mock.authenticate_with_refresh_token.return_value = RefreshTokenAuthenticationResponse( - **{ - "access_token": "access_token_123", - "refresh_token": "refresh_token_123", - "user": { - "object": "user", - "id": TEST_CONSTANTS["USER_ID"], - "email": "user@example.com", - "first_name": "Test", - "last_name": "User", - "email_verified": True, - "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], - "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], - } - } - ) + return mock +@with_jwks_mock def test_initialize_session_module(TEST_CONSTANTS, mock_user_management): session = SessionModule( user_management=mock_user_management, @@ -99,15 +79,17 @@ def test_initialize_session_module(TEST_CONSTANTS, mock_user_management): assert session.client_id == TEST_CONSTANTS["CLIENT_ID"] assert session.cookie_password is not None -def test_initialize_without_cookie_password(mock_user_management): +@with_jwks_mock +def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management): with pytest.raises(ValueError, match="cookie_password is required"): SessionModule( user_management=mock_user_management, - client_id="client_123", - session_data="session_data", + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=TEST_CONSTANTS["SESSION_DATA"], cookie_password="" ) +@with_jwks_mock def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management): session = SessionModule( user_management=mock_user_management, @@ -120,6 +102,7 @@ def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_manag assert response.reason == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED +@with_jwks_mock def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): session = SessionModule( user_management=mock_user_management, @@ -132,6 +115,7 @@ def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_managemen assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE +@with_jwks_mock def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management): invalid_session_data = SessionModule.seal_data({ "access_token": "invalid_session_data" }, TEST_CONSTANTS["COOKIE_PASSWORD"]) session = SessionModule( @@ -145,6 +129,7 @@ def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management): assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT +@with_jwks_mock def test_authenticate_success(TEST_CONSTANTS, mock_user_management): session = SessionModule( user_management=mock_user_management, @@ -225,6 +210,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management): assert response.user.id == TEST_CONSTANTS["USER_ID"] assert response.impersonator is None +@with_jwks_mock def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): session = SessionModule( user_management=mock_user_management, @@ -238,13 +224,8 @@ def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): assert isinstance(response, RefreshWithSessionCookieErrorResponse) assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE +@with_jwks_mock def test_refresh_success(TEST_CONSTANTS, mock_user_management): - # Create mock JWKS client - mock_jwks = Mock(spec=PyJWKClient) - mock_signing_key = Mock() - mock_signing_key.key = TEST_CONSTANTS["PUBLIC_KEY"] - mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key - test_user = { "object": "user", "id": TEST_CONSTANTS["USER_ID"], @@ -272,51 +253,45 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): **mock_response ) + session = SessionModule( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=session_data, + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + ) + with ( - patch( - 'workos.session.PyJWKClient', - return_value=mock_jwks + patch.object( + session, + "is_valid_jwt", + return_value=True ), - ): - session = SessionModule( - user_management=mock_user_management, - client_id=TEST_CONSTANTS["CLIENT_ID"], - session_data=session_data, - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] - ) - - with ( - patch.object( - session, - "is_valid_jwt", - return_value=True - ), - patch( - "jwt.decode", - return_value={ - "sid": TEST_CONSTANTS["SESSION_ID"], - "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], - "role": "admin", - "permissions": ["read"], - "entitlements": ["feature_1"] - } - ) - ): - response = session.refresh() - - assert isinstance(response, RefreshWithSessionCookieSuccessResponse) - assert response.authenticated is True - assert response.user.id == test_user["id"] - - # Verify the refresh token was used correctly - mock_user_management.authenticate_with_refresh_token.assert_called_once_with( - refresh_token="refresh_token_12345", - organization_id=None, - session={ - "seal_session": True, - "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"] + patch( + "jwt.decode", + return_value={ + "sid": TEST_CONSTANTS["SESSION_ID"], + "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"] } ) + ): + response = session.refresh() + + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.user.id == test_user["id"] + + # Verify the refresh token was used correctly + mock_user_management.authenticate_with_refresh_token.assert_called_once_with( + refresh_token="refresh_token_12345", + organization_id=None, + session={ + "seal_session": True, + "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"] + } + ) def test_seal_data(TEST_CONSTANTS): From 38fa6c9eb0b2d5b2f23f6dcf98e7902d4b23c9de Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Fri, 22 Nov 2024 15:22:04 +0100 Subject: [PATCH 07/14] linting --- setup.py | 5 +- tests/conftest.py | 9 +- tests/test_session.py | 126 +++++++++++++----------- workos/session.py | 48 +++++---- workos/types/user_management/session.py | 17 +++- workos/user_management.py | 57 ++++++++--- 6 files changed, 163 insertions(+), 99 deletions(-) diff --git a/setup.py b/setup.py index bfd5dbcb..103be545 100644 --- a/setup.py +++ b/setup.py @@ -10,10 +10,11 @@ with open(os.path.join(base_dir, "workos", "__about__.py")) as f: exec(f.read(), about) + def read_requirements(filename): with open(filename) as f: - return [line.strip() for line in f - if line.strip() and not line.startswith('#')] + return [line.strip() for line in f if line.strip() and not line.startswith("#")] + setup( name=about["__package_name__"], diff --git a/tests/conftest.py b/tests/conftest.py index 1bd77d72..7f9f058a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,7 @@ from unittest.mock import Mock, patch from functools import wraps + def _get_test_client_setup( http_client_class_name: str, ) -> Tuple[Literal["async", "sync"], ClientConfiguration, HTTPClient]: @@ -306,16 +307,18 @@ def inner( return inner + def with_jwks_mock(func): @wraps(func) def wrapper(*args, **kwargs): # Create mock JWKS client mock_jwks = Mock(spec=PyJWKClient) mock_signing_key = Mock() - mock_signing_key.key = kwargs['TEST_CONSTANTS']["PUBLIC_KEY"] + mock_signing_key.key = kwargs["TEST_CONSTANTS"]["PUBLIC_KEY"] mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key # Apply the mock - with patch('workos.session.PyJWKClient', return_value=mock_jwks): + with patch("workos.session.PyJWKClient", return_value=mock_jwks): return func(*args, **kwargs) - return wrapper \ No newline at end of file + + return wrapper diff --git a/tests/test_session.py b/tests/test_session.py index 9e3cc021..ee65f331 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -6,7 +6,9 @@ from tests.conftest import with_jwks_mock from workos.session import SessionModule -from workos.types.user_management.authentication_response import RefreshTokenAuthenticationResponse +from workos.types.user_management.authentication_response import ( + RefreshTokenAuthenticationResponse, +) from workos.types.user_management.session import ( AuthenticateWithSessionCookieFailureReason, AuthenticateWithSessionCookieSuccessResponse, @@ -18,13 +20,11 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa + @pytest.fixture(scope="session") def TEST_CONSTANTS(): # Generate RSA key pair for testing - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048 - ) + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) public_key = private_key.public_key() @@ -32,7 +32,7 @@ def TEST_CONSTANTS(): private_pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) return { @@ -56,29 +56,34 @@ def TEST_CONSTANTS(): "iat": int(datetime.now(timezone.utc).timestamp()), }, private_pem, - algorithm="RS256" - ) + algorithm="RS256", + ), } + @pytest.fixture def mock_user_management(): mock = Mock() - mock.get_jwks_url.return_value = "https://api.workos.com/user_management/sso/jwks/client_123" + mock.get_jwks_url.return_value = ( + "https://api.workos.com/user_management/sso/jwks/client_123" + ) return mock + @with_jwks_mock def test_initialize_session_module(TEST_CONSTANTS, mock_user_management): session = SessionModule( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=TEST_CONSTANTS["SESSION_DATA"], - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], ) assert session.client_id == TEST_CONSTANTS["CLIENT_ID"] assert session.cookie_password is not None + @with_jwks_mock def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management): with pytest.raises(ValueError, match="cookie_password is required"): @@ -86,21 +91,26 @@ def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=TEST_CONSTANTS["SESSION_DATA"], - cookie_password="" + cookie_password="", ) + @with_jwks_mock def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management): session = SessionModule( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=None, - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], ) response = session.authenticate() - assert response.reason == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED + assert ( + response.reason + == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED + ) + @with_jwks_mock def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): @@ -108,34 +118,41 @@ def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_managemen user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data="invalid_session_data", - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], ) response = session.authenticate() - assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + assert ( + response.reason + == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + ) + @with_jwks_mock def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management): - invalid_session_data = SessionModule.seal_data({ "access_token": "invalid_session_data" }, TEST_CONSTANTS["COOKIE_PASSWORD"]) + invalid_session_data = SessionModule.seal_data( + {"access_token": "invalid_session_data"}, TEST_CONSTANTS["COOKIE_PASSWORD"] + ) session = SessionModule( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=invalid_session_data, - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], ) response = session.authenticate() assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + @with_jwks_mock def test_authenticate_success(TEST_CONSTANTS, mock_user_management): session = SessionModule( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=TEST_CONSTANTS["SESSION_DATA"], - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], ) # Mock the session data that would be unsealed @@ -151,7 +168,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management): "iat": int(datetime.now(timezone.utc).timestamp()), }, TEST_CONSTANTS["PRIVATE_KEY"], - algorithm="RS256" + algorithm="RS256", ), "user": { "object": "user", @@ -161,7 +178,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management): "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], }, - "impersonator": None + "impersonator": None, } # Mock the JWT payload that would be decoded @@ -170,33 +187,22 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management): "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], "role": "admin", "permissions": ["read"], - "entitlements": ["feature_1"] + "entitlements": ["feature_1"], } with ( # Mock unsealing the session data - patch.object( - SessionModule, - "unseal_data", - return_value=mock_session - ), + patch.object(SessionModule, "unseal_data", return_value=mock_session), # Mock JWT validation - patch.object( - session, - "is_valid_jwt", - return_value=True - ), + patch.object(session, "is_valid_jwt", return_value=True), # Mock JWT decoding - patch( - "jwt.decode", - return_value=mock_jwt_payload - ), + patch("jwt.decode", return_value=mock_jwt_payload), # Mock JWT signing key retrieval patch.object( session.jwks, "get_signing_key_from_jwt", - return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]) - ) + return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]), + ), ): response = session.authenticate() @@ -210,19 +216,24 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management): assert response.user.id == TEST_CONSTANTS["USER_ID"] assert response.impersonator is None + @with_jwks_mock def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): session = SessionModule( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data="invalid_session_data", - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], ) response = session.refresh() assert isinstance(response, RefreshWithSessionCookieErrorResponse) - assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + assert ( + response.reason + == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + ) + @with_jwks_mock def test_refresh_success(TEST_CONSTANTS, mock_user_management): @@ -237,35 +248,31 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], } - session_data = SessionModule.seal_data({ - "refresh_token": "refresh_token_12345", - "user": test_user - }, TEST_CONSTANTS["COOKIE_PASSWORD"]) + session_data = SessionModule.seal_data( + {"refresh_token": "refresh_token_12345", "user": test_user}, + TEST_CONSTANTS["COOKIE_PASSWORD"], + ) mock_response = { "access_token": TEST_CONSTANTS["TEST_TOKEN"], "refresh_token": "refresh_token_123", "sealed_session": session_data, - "user": test_user + "user": test_user, } - mock_user_management.authenticate_with_refresh_token.return_value = RefreshTokenAuthenticationResponse( - **mock_response + mock_user_management.authenticate_with_refresh_token.return_value = ( + RefreshTokenAuthenticationResponse(**mock_response) ) session = SessionModule( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=session_data, - cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"] + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], ) with ( - patch.object( - session, - "is_valid_jwt", - return_value=True - ), + patch.object(session, "is_valid_jwt", return_value=True), patch( "jwt.decode", return_value={ @@ -273,9 +280,9 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): "org_id": TEST_CONSTANTS["ORGANIZATION_ID"], "role": "admin", "permissions": ["read"], - "entitlements": ["feature_1"] - } - ) + "entitlements": ["feature_1"], + }, + ), ): response = session.refresh() @@ -289,8 +296,8 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): organization_id=None, session={ "seal_session": True, - "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"] - } + "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"], + }, ) @@ -303,6 +310,9 @@ def test_seal_data(TEST_CONSTANTS): unsealed = SessionModule.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"]) assert unsealed == test_data + def test_unseal_invalid_data(TEST_CONSTANTS): with pytest.raises(Exception): # Adjust exception type based on your implementation - SessionModule.unseal_data("invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"]) + SessionModule.unseal_data( + "invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"] + ) diff --git a/workos/session.py b/workos/session.py index 91590b23..dc4bf2d1 100644 --- a/workos/session.py +++ b/workos/session.py @@ -12,6 +12,7 @@ RefreshWithSessionCookieSuccessResponse, ) + class SessionModule: def __init__( self, @@ -19,7 +20,7 @@ def __init__( user_management: Any, client_id: str, session_data: str, - cookie_password: str + cookie_password: str, ) -> None: # If the cookie password is not provided, throw an error if cookie_password is None or cookie_password == "": @@ -33,7 +34,7 @@ def __init__( self.jwks = PyJWKClient(self.user_management.get_jwks_url()) # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm - self.jwk_algorithms = ['RS256'] + self.jwk_algorithms = ["RS256"] def authenticate( self, @@ -43,24 +44,28 @@ def authenticate( ]: if self.session_data is None: return AuthenticateWithSessionCookieErrorResponse( - authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED, ) try: session = self.unseal_data(self.session_data, self.cookie_password) except Exception: return AuthenticateWithSessionCookieErrorResponse( - authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, ) if not session["access_token"]: return AuthenticateWithSessionCookieErrorResponse( - authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, ) if not self.is_valid_jwt(session["access_token"]): return AuthenticateWithSessionCookieErrorResponse( - authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT, ) signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"]) @@ -83,29 +88,30 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[ RefreshWithSessionCookieSuccessResponse, RefreshWithSessionCookieErrorResponse, ]: - cookie_password = self.cookie_password if options is None else options.get("cookie_password") + cookie_password = ( + self.cookie_password if options is None else options.get("cookie_password") + ) organization_id = None if options is None else options.get("organization_id") try: session = self.unseal_data(self.session_data, cookie_password) except Exception: return RefreshWithSessionCookieErrorResponse( - authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, ) if not session["refresh_token"] or not session["user"]: return RefreshWithSessionCookieErrorResponse( - authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE + authenticated=False, + reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, ) try: auth_response = self.user_management.authenticate_with_refresh_token( refresh_token=session["refresh_token"], organization_id=organization_id, - session={ - "seal_session": True, - "cookie_password": cookie_password - } + session={"seal_session": True, "cookie_password": cookie_password}, ) self.session_data = auth_response.sealed_session @@ -114,7 +120,9 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[ signing_key = self.jwks.get_signing_key_from_jwt(auth_response.access_token) decoded = jwt.decode( - auth_response.access_token, signing_key.key, algorithms=self.jwk_algorithms + auth_response.access_token, + signing_key.key, + algorithms=self.jwk_algorithms, ) return RefreshWithSessionCookieSuccessResponse( @@ -137,11 +145,11 @@ def get_logout_url(self) -> str: auth_response = self.authenticate() if not auth_response.authenticated: - raise ValueError(f"Failed to extract session ID for logout URL: {auth_response.reason}") + raise ValueError( + f"Failed to extract session ID for logout URL: {auth_response.reason}" + ) - return self.user_management.get_logout_url( - session_id=auth_response.session_id - ) + return self.user_management.get_logout_url(session_id=auth_response.session_id) def is_valid_jwt(self, token: str) -> bool: try: @@ -156,11 +164,11 @@ def seal_data(data: Dict[str, Any], key: str) -> str: fernet = Fernet(key) # Encrypt and convert bytes to string encrypted_bytes = fernet.encrypt(json.dumps(data).encode()) - return encrypted_bytes.decode('utf-8') + return encrypted_bytes.decode("utf-8") @staticmethod def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]: fernet = Fernet(key) # Convert string back to bytes before decryption - encrypted_bytes = sealed_data.encode('utf-8') + encrypted_bytes = sealed_data.encode("utf-8") return json.loads(fernet.decrypt(encrypted_bytes).decode()) diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py index 1c8cc6a7..d6c859e5 100644 --- a/workos/types/user_management/session.py +++ b/workos/types/user_management/session.py @@ -5,10 +5,12 @@ from workos.types.user_management.user import User from workos.types.workos_model import WorkOSModel + class AuthenticateWithSessionCookieFailureReason(Enum): - INVALID_JWT = 'invalid_jwt' - INVALID_SESSION_COOKIE = 'invalid_session_cookie' - NO_SESSION_COOKIE_PROVIDED = 'no_session_cookie_provided' + INVALID_JWT = "invalid_jwt" + INVALID_SESSION_COOKIE = "invalid_session_cookie" + NO_SESSION_COOKIE_PROVIDED = "no_session_cookie_provided" + class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): authenticated: bool = True @@ -20,11 +22,15 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): impersonator: Optional[Impersonator] = None entitlements: Optional[List[str]] = None + class AuthenticateWithSessionCookieErrorResponse(WorkOSModel): authenticated: bool = False reason: Union[AuthenticateWithSessionCookieFailureReason, str] -class RefreshWithSessionCookieSuccessResponse(AuthenticateWithSessionCookieSuccessResponse): + +class RefreshWithSessionCookieSuccessResponse( + AuthenticateWithSessionCookieSuccessResponse +): sealed_session: str @@ -32,6 +38,7 @@ class RefreshWithSessionCookieErrorResponse(WorkOSModel): authenticated: bool = False reason: Union[AuthenticateWithSessionCookieFailureReason, str] + class SessionConfig(TypedDict): seal_session: bool - cookie_password: str \ No newline at end of file + cookie_password: str diff --git a/workos/user_management.py b/workos/user_management.py index 976d36af..cdde4006 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -111,7 +111,9 @@ class UserManagementModule(Protocol): _client_configuration: ClientConfiguration - def load_sealed_session(self, *, sealed_session: str, cookie_password: str) -> SyncOrAsync[SessionModule]: + def load_sealed_session( + self, *, sealed_session: str, cookie_password: str + ) -> SyncOrAsync[SessionModule]: """Load a sealed session and return the session data. Args: @@ -822,8 +824,15 @@ def __init__( self._client_configuration = client_configuration self._http_client = http_client - def load_sealed_session(self, *, session_data: str, cookie_password: str) -> SessionModule: - return SessionModule(user_management=self, client_id=self._http_client.client_id, session_data=session_data, cookie_password=cookie_password) + def load_sealed_session( + self, *, session_data: str, cookie_password: str + ) -> SessionModule: + return SessionModule( + user_management=self, + client_id=self._http_client.client_id, + session_data=session_data, + cookie_password=cookie_password, + ) def get_user(self, user_id: str) -> User: response = self._http_client.request( @@ -1034,8 +1043,13 @@ def _authenticate_with( json=json, ) - if payload.get("session") is not None and payload.get("session").get("seal_session") is True: - response["sealed_session"] = SessionModule.seal_data(response, payload.get("session").get("cookie_password")) + if ( + payload.get("session") is not None + and payload.get("session").get("seal_session") is True + ): + response["sealed_session"] = SessionModule.seal_data( + response, payload.get("session").get("cookie_password") + ) return response_model.model_validate(response) @@ -1066,7 +1080,11 @@ def authenticate_with_code( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> AuthKitAuthenticationResponse: - if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""): + if session is not None and ( + session.get("seal_session") is True + and session.get("cookie_password") is None + or "" + ): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithCodeParameters = { @@ -1167,7 +1185,11 @@ def authenticate_with_refresh_token( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> RefreshTokenAuthenticationResponse: - if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""): + if session is not None and ( + session.get("seal_session") is True + and session.get("cookie_password") is None + or "" + ): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithRefreshTokenParameters = { @@ -1623,8 +1645,13 @@ async def _authenticate_with( json=json, ) - if payload.get("session") is not None and payload.get("session").get("seal_session") is True: - response["sealed_session"] = SessionModule.seal_data(response, payload.get("session").get("cookie_password")) + if ( + payload.get("session") is not None + and payload.get("session").get("seal_session") is True + ): + response["sealed_session"] = SessionModule.seal_data( + response, payload.get("session").get("cookie_password") + ) return response_model.model_validate(response) @@ -1657,7 +1684,11 @@ async def authenticate_with_code( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> AuthKitAuthenticationResponse: - if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""): + if session is not None and ( + session.get("seal_session") is True + and session.get("cookie_password") is None + or "" + ): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithCodeParameters = { @@ -1766,7 +1797,11 @@ async def authenticate_with_refresh_token( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> RefreshTokenAuthenticationResponse: - if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""): + if session is not None and ( + session.get("seal_session") is True + and session.get("cookie_password") is None + or "" + ): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithRefreshTokenParameters = { From 0b2e91ca2869f2094a25a91be9a726363a2e2b96 Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Wed, 27 Nov 2024 13:52:07 +0100 Subject: [PATCH 08/14] Make tests and mypy happy --- tests/test_session.py | 35 ++++++----- workos/session.py | 30 ++++++---- workos/types/user_management/session.py | 10 ++-- workos/user_management.py | 80 +++++++++++-------------- 4 files changed, 76 insertions(+), 79 deletions(-) diff --git a/tests/test_session.py b/tests/test_session.py index ee65f331..9c38372a 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -5,7 +5,7 @@ from datetime import datetime, timezone from tests.conftest import with_jwks_mock -from workos.session import SessionModule +from workos.session import Session from workos.types.user_management.authentication_response import ( RefreshTokenAuthenticationResponse, ) @@ -73,7 +73,7 @@ def mock_user_management(): @with_jwks_mock def test_initialize_session_module(TEST_CONSTANTS, mock_user_management): - session = SessionModule( + session = Session( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=TEST_CONSTANTS["SESSION_DATA"], @@ -87,7 +87,7 @@ def test_initialize_session_module(TEST_CONSTANTS, mock_user_management): @with_jwks_mock def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management): with pytest.raises(ValueError, match="cookie_password is required"): - SessionModule( + Session( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=TEST_CONSTANTS["SESSION_DATA"], @@ -97,7 +97,7 @@ def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management @with_jwks_mock def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management): - session = SessionModule( + session = Session( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=None, @@ -114,7 +114,7 @@ def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_manag @with_jwks_mock def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): - session = SessionModule( + session = Session( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data="invalid_session_data", @@ -131,10 +131,10 @@ def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_managemen @with_jwks_mock def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management): - invalid_session_data = SessionModule.seal_data( + invalid_session_data = Session.seal_data( {"access_token": "invalid_session_data"}, TEST_CONSTANTS["COOKIE_PASSWORD"] ) - session = SessionModule( + session = Session( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=invalid_session_data, @@ -143,12 +143,14 @@ def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management): response = session.authenticate() + print(response) + assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT @with_jwks_mock def test_authenticate_success(TEST_CONSTANTS, mock_user_management): - session = SessionModule( + session = Session( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=TEST_CONSTANTS["SESSION_DATA"], @@ -192,7 +194,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management): with ( # Mock unsealing the session data - patch.object(SessionModule, "unseal_data", return_value=mock_session), + patch.object(Session, "unseal_data", return_value=mock_session), # Mock JWT validation patch.object(session, "is_valid_jwt", return_value=True), # Mock JWT decoding @@ -219,7 +221,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management): @with_jwks_mock def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): - session = SessionModule( + session = Session( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data="invalid_session_data", @@ -248,7 +250,7 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], } - session_data = SessionModule.seal_data( + session_data = Session.seal_data( {"refresh_token": "refresh_token_12345", "user": test_user}, TEST_CONSTANTS["COOKIE_PASSWORD"], ) @@ -264,7 +266,7 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): RefreshTokenAuthenticationResponse(**mock_response) ) - session = SessionModule( + session = Session( user_management=mock_user_management, client_id=TEST_CONSTANTS["CLIENT_ID"], session_data=session_data, @@ -303,16 +305,19 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): def test_seal_data(TEST_CONSTANTS): test_data = {"test": "data"} - sealed = SessionModule.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"]) + sealed = Session.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"]) assert isinstance(sealed, str) # Test unsealing - unsealed = SessionModule.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"]) + unsealed = Session.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"]) + print(test_data) + print(unsealed) + assert unsealed == test_data def test_unseal_invalid_data(TEST_CONSTANTS): with pytest.raises(Exception): # Adjust exception type based on your implementation - SessionModule.unseal_data( + Session.unseal_data( "invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"] ) diff --git a/workos/session.py b/workos/session.py index dc4bf2d1..be48a0ff 100644 --- a/workos/session.py +++ b/workos/session.py @@ -1,5 +1,6 @@ +from __future__ import annotations import json -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union, cast import jwt from jwt import PyJWKClient from cryptography.fernet import Fernet @@ -12,12 +13,11 @@ RefreshWithSessionCookieSuccessResponse, ) - -class SessionModule: +class Session: def __init__( self, *, - user_management: Any, + user_management: "UserManagementModule", # type: ignore client_id: str, session_data: str, cookie_password: str, @@ -42,7 +42,7 @@ def authenticate( AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse, ]: - if self.session_data is None: + if self.session_data is None or self.session_data == "": return AuthenticateWithSessionCookieErrorResponse( authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED, @@ -56,7 +56,7 @@ def authenticate( reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, ) - if not session["access_token"]: + if not session.get("access_token", None): return AuthenticateWithSessionCookieErrorResponse( authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, @@ -84,14 +84,15 @@ def authenticate( impersonator=session.get("impersonator", None), ) - def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[ + def refresh( + self, *, organization_id: Optional[str] = None, cookie_password: Optional[str] = None + ) -> Union[ RefreshWithSessionCookieSuccessResponse, RefreshWithSessionCookieErrorResponse, ]: cookie_password = ( - self.cookie_password if options is None else options.get("cookie_password") + self.cookie_password if cookie_password is None else cookie_password ) - organization_id = None if options is None else options.get("organization_id") try: session = self.unseal_data(self.session_data, cookie_password) @@ -101,7 +102,7 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[ reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, ) - if not session["refresh_token"] or not session["user"]: + if not session.get("refresh_token", None) or not session.get("user", None): return RefreshWithSessionCookieErrorResponse( authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, @@ -144,12 +145,14 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[ def get_logout_url(self) -> str: auth_response = self.authenticate() - if not auth_response.authenticated: + if isinstance(auth_response, AuthenticateWithSessionCookieErrorResponse): raise ValueError( f"Failed to extract session ID for logout URL: {auth_response.reason}" ) - return self.user_management.get_logout_url(session_id=auth_response.session_id) + result = self.user_management.get_logout_url(session_id=auth_response.session_id) + return str(result) + def is_valid_jwt(self, token: str) -> bool: try: @@ -171,4 +174,5 @@ def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]: fernet = Fernet(key) # Convert string back to bytes before decryption encrypted_bytes = sealed_data.encode("utf-8") - return json.loads(fernet.decrypt(encrypted_bytes).decode()) + decrypted_str = fernet.decrypt(encrypted_bytes).decode() + return cast(Dict[str, Any], json.loads(decrypted_str)) diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py index d6c859e5..7c81a0c3 100644 --- a/workos/types/user_management/session.py +++ b/workos/types/user_management/session.py @@ -1,6 +1,6 @@ from typing import List, Optional, TypedDict, Union from enum import Enum - +from typing_extensions import Literal from workos.types.user_management.impersonator import Impersonator from workos.types.user_management.user import User from workos.types.workos_model import WorkOSModel @@ -13,7 +13,7 @@ class AuthenticateWithSessionCookieFailureReason(Enum): class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): - authenticated: bool = True + authenticated: Literal[True] session_id: str organization_id: Optional[str] = None role: Optional[str] = None @@ -24,7 +24,7 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): class AuthenticateWithSessionCookieErrorResponse(WorkOSModel): - authenticated: bool = False + authenticated: Literal[False] reason: Union[AuthenticateWithSessionCookieFailureReason, str] @@ -35,10 +35,10 @@ class RefreshWithSessionCookieSuccessResponse( class RefreshWithSessionCookieErrorResponse(WorkOSModel): - authenticated: bool = False + authenticated: Literal[False] reason: Union[AuthenticateWithSessionCookieFailureReason, str] -class SessionConfig(TypedDict): +class SessionConfig(TypedDict, total=False): seal_session: bool cookie_password: str diff --git a/workos/user_management.py b/workos/user_management.py index cdde4006..1d2e02c1 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -1,6 +1,6 @@ -from typing import Optional, Protocol, Sequence, Set, Type +from typing import Optional, Protocol, Sequence, Set, Type, cast from workos._client_configuration import ClientConfiguration -from workos.session import SessionModule +from workos.session import Session from workos.types.list_resource import ( ListArgs, ListMetadata, @@ -113,7 +113,7 @@ class UserManagementModule(Protocol): def load_sealed_session( self, *, sealed_session: str, cookie_password: str - ) -> SyncOrAsync[SessionModule]: + ) -> SyncOrAsync[Session]: """Load a sealed session and return the session data. Args: @@ -121,7 +121,7 @@ def load_sealed_session( cookie_password (str): The cookie password to use to decrypt the session data. Returns: - SessionModule: The session module. + Session: The session module. """ ... @@ -825,12 +825,12 @@ def __init__( self._http_client = http_client def load_sealed_session( - self, *, session_data: str, cookie_password: str - ) -> SessionModule: - return SessionModule( + self, *, sealed_session: str, cookie_password: str + ) -> Session: + return Session( user_management=self, client_id=self._http_client.client_id, - session_data=session_data, + session_data=sealed_session, cookie_password=cookie_password, ) @@ -1043,15 +1043,16 @@ def _authenticate_with( json=json, ) - if ( - payload.get("session") is not None - and payload.get("session").get("seal_session") is True - ): - response["sealed_session"] = SessionModule.seal_data( - response, payload.get("session").get("cookie_password") + response_data = dict(response) + + session = cast(Optional[SessionConfig], payload.get("session", None)) + + if session is not None and session.get("seal_session") is True: + response_data["sealed_session"] = Session.seal_data( + response_data, str(session.get("cookie_password")) ) - return response_model.model_validate(response) + return response_model.model_validate(response_data) def authenticate_with_password( self, @@ -1080,11 +1081,7 @@ def authenticate_with_code( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> AuthKitAuthenticationResponse: - if session is not None and ( - session.get("seal_session") is True - and session.get("cookie_password") is None - or "" - ): + if session is not None and session.get("seal_session") and not session.get("cookie_password"): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithCodeParameters = { @@ -1185,11 +1182,7 @@ def authenticate_with_refresh_token( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> RefreshTokenAuthenticationResponse: - if session is not None and ( - session.get("seal_session") is True - and session.get("cookie_password") is None - or "" - ): + if session is not None and session.get("seal_session") and not session.get("cookie_password"): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithRefreshTokenParameters = { @@ -1273,10 +1266,7 @@ def get_magic_auth(self, magic_auth_id: str) -> MagicAuth: return MagicAuth.model_validate(response) def create_magic_auth( - self, - *, - email: str, - invitation_token: Optional[str] = None, + self, *, email: str, invitation_token: Optional[str] = None ) -> MagicAuth: json = { "email": email, @@ -1435,6 +1425,11 @@ def __init__( self._client_configuration = client_configuration self._http_client = http_client + async def load_sealed_session( + self, *, sealed_session: str, cookie_password: str + ) -> Session: + raise NotImplementedError("Async load_sealed_session not implemented") + async def get_user(self, user_id: str) -> User: response = await self._http_client.request( USER_DETAIL_PATH.format(user_id), method=REQUEST_METHOD_GET @@ -1645,15 +1640,16 @@ async def _authenticate_with( json=json, ) - if ( - payload.get("session") is not None - and payload.get("session").get("seal_session") is True - ): - response["sealed_session"] = SessionModule.seal_data( - response, payload.get("session").get("cookie_password") + response_data = dict(response) + + session = cast(Optional[SessionConfig], payload.get("session", None)) + + if session is not None and session.get("seal_session") is True: + response_data["sealed_session"] = Session.seal_data( + response_data, str(session.get("cookie_password")) ) - return response_model.model_validate(response) + return response_model.model_validate(response_data) async def authenticate_with_password( self, @@ -1684,11 +1680,7 @@ async def authenticate_with_code( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> AuthKitAuthenticationResponse: - if session is not None and ( - session.get("seal_session") is True - and session.get("cookie_password") is None - or "" - ): + if session is not None and session.get("seal_session") and not session.get("cookie_password"): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithCodeParameters = { @@ -1797,11 +1789,7 @@ async def authenticate_with_refresh_token( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> RefreshTokenAuthenticationResponse: - if session is not None and ( - session.get("seal_session") is True - and session.get("cookie_password") is None - or "" - ): + if session is not None and session.get("seal_session") and not session.get("cookie_password"): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithRefreshTokenParameters = { From 0410381657f0c75f3450b3d8c92c975f191505e1 Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Wed, 27 Nov 2024 13:54:32 +0100 Subject: [PATCH 09/14] make black happy too --- tests/test_session.py | 4 +--- workos/session.py | 13 +++++++++---- workos/user_management.py | 24 ++++++++++++++++++++---- 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/tests/test_session.py b/tests/test_session.py index 9c38372a..4646fcd0 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -318,6 +318,4 @@ def test_seal_data(TEST_CONSTANTS): def test_unseal_invalid_data(TEST_CONSTANTS): with pytest.raises(Exception): # Adjust exception type based on your implementation - Session.unseal_data( - "invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"] - ) + Session.unseal_data("invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"]) diff --git a/workos/session.py b/workos/session.py index be48a0ff..5a649734 100644 --- a/workos/session.py +++ b/workos/session.py @@ -13,11 +13,12 @@ RefreshWithSessionCookieSuccessResponse, ) + class Session: def __init__( self, *, - user_management: "UserManagementModule", # type: ignore + user_management: "UserManagementModule", # type: ignore client_id: str, session_data: str, cookie_password: str, @@ -85,7 +86,10 @@ def authenticate( ) def refresh( - self, *, organization_id: Optional[str] = None, cookie_password: Optional[str] = None + self, + *, + organization_id: Optional[str] = None, + cookie_password: Optional[str] = None, ) -> Union[ RefreshWithSessionCookieSuccessResponse, RefreshWithSessionCookieErrorResponse, @@ -150,10 +154,11 @@ def get_logout_url(self) -> str: f"Failed to extract session ID for logout URL: {auth_response.reason}" ) - result = self.user_management.get_logout_url(session_id=auth_response.session_id) + result = self.user_management.get_logout_url( + session_id=auth_response.session_id + ) return str(result) - def is_valid_jwt(self, token: str) -> bool: try: signing_key = self.jwks.get_signing_key_from_jwt(token) diff --git a/workos/user_management.py b/workos/user_management.py index 1d2e02c1..3534c678 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -1081,7 +1081,11 @@ def authenticate_with_code( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> AuthKitAuthenticationResponse: - if session is not None and session.get("seal_session") and not session.get("cookie_password"): + if ( + session is not None + and session.get("seal_session") + and not session.get("cookie_password") + ): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithCodeParameters = { @@ -1182,7 +1186,11 @@ def authenticate_with_refresh_token( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> RefreshTokenAuthenticationResponse: - if session is not None and session.get("seal_session") and not session.get("cookie_password"): + if ( + session is not None + and session.get("seal_session") + and not session.get("cookie_password") + ): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithRefreshTokenParameters = { @@ -1680,7 +1688,11 @@ async def authenticate_with_code( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> AuthKitAuthenticationResponse: - if session is not None and session.get("seal_session") and not session.get("cookie_password"): + if ( + session is not None + and session.get("seal_session") + and not session.get("cookie_password") + ): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithCodeParameters = { @@ -1789,7 +1801,11 @@ async def authenticate_with_refresh_token( ip_address: Optional[str] = None, user_agent: Optional[str] = None, ) -> RefreshTokenAuthenticationResponse: - if session is not None and session.get("seal_session") and not session.get("cookie_password"): + if ( + session is not None + and session.get("seal_session") + and not session.get("cookie_password") + ): raise ValueError("cookie_password is required when sealing session") payload: AuthenticateWithRefreshTokenParameters = { From 25e6e5afbf7bbb9429ded355abfa8e9dfe28ac19 Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Wed, 27 Nov 2024 13:57:37 +0100 Subject: [PATCH 10/14] Forgot import --- workos/session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/workos/session.py b/workos/session.py index 5a649734..55cca562 100644 --- a/workos/session.py +++ b/workos/session.py @@ -12,13 +12,14 @@ RefreshWithSessionCookieErrorResponse, RefreshWithSessionCookieSuccessResponse, ) +from workos.user_management import UserManagementModule class Session: def __init__( self, *, - user_management: "UserManagementModule", # type: ignore + user_management: "UserManagementModule", client_id: str, session_data: str, cookie_password: str, From 28c213037b839f2e6889e2998727aae7e0ebb508 Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Wed, 27 Nov 2024 14:29:31 +0100 Subject: [PATCH 11/14] Satisfy type checker --- workos/session.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/workos/session.py b/workos/session.py index 55cca562..df8970c1 100644 --- a/workos/session.py +++ b/workos/session.py @@ -1,10 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import json from typing import Any, Dict, Optional, Union, cast import jwt from jwt import PyJWKClient from cryptography.fernet import Fernet +from workos.types.user_management.authentication_response import ( + RefreshTokenAuthenticationResponse, +) from workos.types.user_management.session import ( AuthenticateWithSessionCookieFailureReason, AuthenticateWithSessionCookieSuccessResponse, @@ -12,7 +17,9 @@ RefreshWithSessionCookieErrorResponse, RefreshWithSessionCookieSuccessResponse, ) -from workos.user_management import UserManagementModule + +if TYPE_CHECKING: + from workos.user_management import UserManagementModule class Session: @@ -114,14 +121,19 @@ def refresh( ) try: - auth_response = self.user_management.authenticate_with_refresh_token( - refresh_token=session["refresh_token"], - organization_id=organization_id, - session={"seal_session": True, "cookie_password": cookie_password}, + auth_response = cast( + RefreshTokenAuthenticationResponse, + self.user_management.authenticate_with_refresh_token( + refresh_token=session["refresh_token"], + organization_id=organization_id, + session={"seal_session": True, "cookie_password": cookie_password}, + ), ) - self.session_data = auth_response.sealed_session - self.cookie_password = cookie_password + self.session_data = str(auth_response.sealed_session) + self.cookie_password = ( + cookie_password if cookie_password is not None else self.cookie_password + ) signing_key = self.jwks.get_signing_key_from_jwt(auth_response.access_token) @@ -133,7 +145,7 @@ def refresh( return RefreshWithSessionCookieSuccessResponse( authenticated=True, - sealed_session=auth_response.sealed_session, + sealed_session=str(auth_response.sealed_session), session_id=decoded["sid"], organization_id=decoded.get("org_id", None), role=decoded.get("role", None), From 3cbc6a602471d62d9933e374928fffb7a65be297 Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Wed, 27 Nov 2024 14:40:21 +0100 Subject: [PATCH 12/14] 3.8 compatibility and remove print statements --- tests/test_session.py | 39 +++++++++++++-------------------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/tests/test_session.py b/tests/test_session.py index 4646fcd0..0ff10baf 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -143,8 +143,6 @@ def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management): response = session.authenticate() - print(response) - assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT @@ -192,19 +190,12 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management): "entitlements": ["feature_1"], } - with ( - # Mock unsealing the session data - patch.object(Session, "unseal_data", return_value=mock_session), - # Mock JWT validation - patch.object(session, "is_valid_jwt", return_value=True), - # Mock JWT decoding - patch("jwt.decode", return_value=mock_jwt_payload), - # Mock JWT signing key retrieval - patch.object( - session.jwks, - "get_signing_key_from_jwt", - return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]), - ), + with patch.object(Session, "unseal_data", return_value=mock_session), patch.object( + session, "is_valid_jwt", return_value=True + ), patch("jwt.decode", return_value=mock_jwt_payload), patch.object( + session.jwks, + "get_signing_key_from_jwt", + return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]), ): response = session.authenticate() @@ -273,9 +264,8 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], ) - with ( - patch.object(session, "is_valid_jwt", return_value=True), - patch( + with patch.object(session, "is_valid_jwt", return_value=True) as _: + with patch( "jwt.decode", return_value={ "sid": TEST_CONSTANTS["SESSION_ID"], @@ -284,13 +274,12 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): "permissions": ["read"], "entitlements": ["feature_1"], }, - ), - ): - response = session.refresh() + ): + response = session.refresh() - assert isinstance(response, RefreshWithSessionCookieSuccessResponse) - assert response.authenticated is True - assert response.user.id == test_user["id"] + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.user.id == test_user["id"] # Verify the refresh token was used correctly mock_user_management.authenticate_with_refresh_token.assert_called_once_with( @@ -310,8 +299,6 @@ def test_seal_data(TEST_CONSTANTS): # Test unsealing unsealed = Session.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"]) - print(test_data) - print(unsealed) assert unsealed == test_data From 6b6c66af41859b5d72107cb27b727dd9ba77fa48 Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Mon, 2 Dec 2024 15:35:06 +0100 Subject: [PATCH 13/14] Use sequence instead of list --- workos/types/user_management/session.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py index 7c81a0c3..76739f9d 100644 --- a/workos/types/user_management/session.py +++ b/workos/types/user_management/session.py @@ -1,4 +1,4 @@ -from typing import List, Optional, TypedDict, Union +from typing import Optional, Sequence, TypedDict, Union from enum import Enum from typing_extensions import Literal from workos.types.user_management.impersonator import Impersonator @@ -17,10 +17,10 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): session_id: str organization_id: Optional[str] = None role: Optional[str] = None - permissions: Optional[List[str]] = None + permissions: Optional[Sequence[str]] = None user: User impersonator: Optional[Impersonator] = None - entitlements: Optional[List[str]] = None + entitlements: Optional[Sequence[str]] = None class AuthenticateWithSessionCookieErrorResponse(WorkOSModel): From c9daffce40eacd3e79a3a72e8f2b3af76a83ad5a Mon Sep 17 00:00:00 2001 From: Paul Asjes Date: Mon, 2 Dec 2024 15:47:13 +0100 Subject: [PATCH 14/14] Make is_valid_jwt private --- tests/test_session.py | 4 ++-- workos/session.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_session.py b/tests/test_session.py index 0ff10baf..fbb82717 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -191,7 +191,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management): } with patch.object(Session, "unseal_data", return_value=mock_session), patch.object( - session, "is_valid_jwt", return_value=True + session, "_is_valid_jwt", return_value=True ), patch("jwt.decode", return_value=mock_jwt_payload), patch.object( session.jwks, "get_signing_key_from_jwt", @@ -264,7 +264,7 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], ) - with patch.object(session, "is_valid_jwt", return_value=True) as _: + with patch.object(session, "_is_valid_jwt", return_value=True) as _: with patch( "jwt.decode", return_value={ diff --git a/workos/session.py b/workos/session.py index df8970c1..fea062ca 100644 --- a/workos/session.py +++ b/workos/session.py @@ -71,7 +71,7 @@ def authenticate( reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE, ) - if not self.is_valid_jwt(session["access_token"]): + if not self._is_valid_jwt(session["access_token"]): return AuthenticateWithSessionCookieErrorResponse( authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT, @@ -172,7 +172,7 @@ def get_logout_url(self) -> str: ) return str(result) - def is_valid_jwt(self, token: str) -> bool: + def _is_valid_jwt(self, token: str) -> bool: try: signing_key = self.jwks.get_signing_key_from_jwt(token) jwt.decode(token, signing_key.key, algorithms=self.jwk_algorithms)