From fed8cbf3421546294478489b1cab3328d0dfe4da Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Thu, 5 Sep 2024 14:34:32 +0200
Subject: [PATCH 01/14] WIP
---
workos/session.py | 9 +++++++++
workos/types/user_management/session.py | 25 +++++++++++++++++++++++++
2 files changed, 34 insertions(+)
create mode 100644 workos/session.py
create mode 100644 workos/types/user_management/session.py
diff --git a/workos/session.py b/workos/session.py
new file mode 100644
index 00000000..48f38e51
--- /dev/null
+++ b/workos/session.py
@@ -0,0 +1,9 @@
+
+from typing import Protocol, Union
+
+from workos.types.user_management.session import AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse
+
+class SessionModule(Protocol):
+
+ def authenticate(self) -> Union[AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse]:
+ ...
\ No newline at end of file
diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py
new file mode 100644
index 00000000..7e1a50df
--- /dev/null
+++ b/workos/types/user_management/session.py
@@ -0,0 +1,25 @@
+from typing import List, Optional
+from enum import Enum
+
+from workos.types.user_management.impersonator import Impersonator
+from workos.types.user_management.user import User
+from workos.types.workos_model import WorkOSModel
+
+class AuthenticateWithSessionCookieFailureReason(Enum):
+ INVALID_JWT = 'invalid_jwt'
+ INVALID_SESSION_COOKIE = 'invalid_session_cookie'
+ NO_SESSION_COOKIE_PROVIDED = 'no_session_cookie_provided'
+
+class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
+ authenticated: bool = True
+ sessionId: str
+ organizationId: Optional[str] = None
+ role: Optional[str] = None
+ permissions: Optional[List[str]] = None
+ user: User
+ impersonator: Optional[Impersonator] = None
+
+class AuthenticateWithSessionCookieErrorResponse(WorkOSModel):
+ authenticated: bool = False
+ reason: AuthenticateWithSessionCookieFailureReason
+
From c199df348c417b3eadfb078676442109bd1496fc Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Wed, 13 Nov 2024 16:37:15 +0100
Subject: [PATCH 02/14] WIP
---
LICENSE | 2 +-
requirements-dev.txt | 9 +
requirements.txt | 4 +
setup.py | 19 +--
workos/session.py | 156 +++++++++++++++++-
.../authenticate_with_common.py | 2 +
.../authentication_response.py | 1 +
workos/types/user_management/session.py | 5 +-
workos/user_management.py | 25 +++
9 files changed, 204 insertions(+), 19 deletions(-)
create mode 100644 requirements-dev.txt
create mode 100644 requirements.txt
diff --git a/LICENSE b/LICENSE
index 656ceca8..67f8e417 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
MIT License
-Copyright (c) 2021 WorkOS
+Copyright (c) 2024 WorkOS
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 00000000..478712f7
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,9 @@
+flake8
+pytest==8.3.2
+pytest-asyncio==0.23.8
+pytest-cov==5.0.0
+six==1.16.0
+black==24.4.2
+twine==5.1.1
+mypy==1.12.0
+httpx>=0.27.0
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..beaf1927
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+httpx>=0.27.0
+pydantic==2.9.2
+PyJWT==2.9.0
+cryptography==43.0.3
diff --git a/setup.py b/setup.py
index e86e5a92..bfd5dbcb 100644
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,11 @@
with open(os.path.join(base_dir, "workos", "__about__.py")) as f:
exec(f.read(), about)
+def read_requirements(filename):
+ with open(filename) as f:
+ return [line.strip() for line in f
+ if line.strip() and not line.startswith('#')]
+
setup(
name=about["__package_name__"],
version=about["__version__"],
@@ -27,19 +32,9 @@
),
zip_safe=False,
license=about["__license__"],
- install_requires=["httpx>=0.27.0", "pydantic==2.9.2"],
+ install_requires=read_requirements("requirements.txt"),
extras_require={
- "dev": [
- "flake8",
- "pytest==8.3.2",
- "pytest-asyncio==0.23.8",
- "pytest-cov==5.0.0",
- "six==1.16.0",
- "black==24.4.2",
- "twine==5.1.1",
- "mypy==1.12.0",
- "httpx>=0.27.0",
- ],
+ "dev": read_requirements("requirements-dev.txt"),
":python_version<'3.4'": ["enum34"],
},
classifiers=[
diff --git a/workos/session.py b/workos/session.py
index 48f38e51..47765e04 100644
--- a/workos/session.py
+++ b/workos/session.py
@@ -1,9 +1,155 @@
+import json
+from typing import Any, Dict, List, Optional, Union
+import jwt
+from jwt import PyJWKClient
+from cryptography.fernet import Fernet
-from typing import Protocol, Union
+from workos.types.user_management.session import (
+ AuthenticateWithSessionCookieFailureReason,
+ AuthenticateWithSessionCookieSuccessResponse,
+ AuthenticateWithSessionCookieErrorResponse,
+)
-from workos.types.user_management.session import AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse
+class SessionModule:
+ def __init__(
+ self,
+ *,
+ user_management: Any,
+ client_id: str,
+ session_data: str,
+ cookie_password: str
+ ) -> None:
+ # If the cookie password is not provided, throw an error
+ if cookie_password is None or cookie_password == "":
+ raise ValueError("cookie_password is required")
-class SessionModule(Protocol):
+ self.user_management = user_management
+ self.client_id = client_id
+ self.session_data = session_data
+ self.cookie_password = cookie_password
- def authenticate(self) -> Union[AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse]:
- ...
\ No newline at end of file
+ self.jwks = self.create_remote_jwk_set(
+ self.user_management.get_jwks_url()
+ )
+ self.jwk_algorithms = [str(key.Algorithm) for key in self.jwks]
+
+ for key in self.jwks:
+ print("Key properties:", dir(key)) # This will show all available attributes
+ print("Algorithm:", key.Algorithm)
+ print("Key type:", key.key_type)
+
+ def authenticate(
+ self,
+ ) -> Union[
+ AuthenticateWithSessionCookieSuccessResponse,
+ AuthenticateWithSessionCookieErrorResponse,
+ ]:
+ if self.session_data is None:
+ return AuthenticateWithSessionCookieErrorResponse(
+ authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED
+ )
+
+ try:
+ session = self.unseal_data(self.session_data, self.cookie_password)
+ except Exception:
+ return AuthenticateWithSessionCookieErrorResponse(
+ authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ )
+
+ if not session["access_token"]:
+ return AuthenticateWithSessionCookieErrorResponse(
+ authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ )
+
+ if not self.is_valid_jwt(session["access_token"]):
+ return AuthenticateWithSessionCookieErrorResponse(
+ authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT
+ )
+
+ decoded = jwt.decode(
+ session["access_token"], self.jwks, algorithms=self.jwk_algorithms
+ )
+
+ return AuthenticateWithSessionCookieSuccessResponse(
+ authenticated=True,
+ session_id=decoded["sid"],
+ organization_id=decoded["org_id"],
+ role=decoded["role"],
+ permissions=decoded["permissions"],
+ entitlements=decoded["entitlements"],
+ user=session["user"],
+ impersonator=session["impersonator"],
+ reason=None,
+ )
+
+ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[
+ AuthenticateWithSessionCookieSuccessResponse,
+ AuthenticateWithSessionCookieErrorResponse,
+ ]:
+ cookie_password = options.get("cookie_password", self.cookie_password)
+ organization_id = options.get("organization_id", None)
+
+ try:
+ session = self.unseal_data(self.session_data, cookie_password)
+ except Exception:
+ return AuthenticateWithSessionCookieErrorResponse(
+ authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ )
+
+ if not session["refresh_token"] or not session["user"]:
+ return AuthenticateWithSessionCookieErrorResponse(
+ authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ )
+
+ try:
+ auth_response = self.user_management.authenticate_with_refresh_token(
+ refresh_token=session["refresh_token"],
+ organization_id=organization_id,
+ )
+
+ self.session_data = auth_response.sealed_session
+ self.cookie_password = cookie_password
+
+ return AuthenticateWithSessionCookieSuccessResponse(
+ authenticated=True,
+ sealed_session=auth_response.sealed_session,
+ session=auth_response,
+ reason=None,
+ )
+ except Exception as e:
+ return AuthenticateWithSessionCookieErrorResponse(
+ authenticated=False, reason=str(e)
+ )
+
+ def get_logout_url(self) -> str:
+ auth_response = self.authenticate()
+
+ if not auth_response["authenticated"]:
+ raise ValueError(auth_response["reason"])
+
+ return self.user_management.get_logout_url(
+ session_id=auth_response["session_id"]
+ )
+
+ def create_remote_jwk_set(self, url: str) -> List[Dict[str, Any]]:
+ jwks_client = PyJWKClient(url)
+ return jwks_client.get_signing_keys()
+
+ def is_valid_jwt(self, token: str) -> bool:
+ try:
+ jwt.decode(token, self.jwks, algorithms=self.jwk_algorithms)
+ return True
+ except jwt.exceptions.InvalidTokenError as error:
+ print("invalid token", error)
+ return False
+
+ @staticmethod
+ def seal_data(data: Dict[str, Any], key: str) -> str:
+ fernet = Fernet(key)
+ # take the data and encrypt it with the key using fernet
+ return fernet.encrypt(json.dumps(data).encode())
+
+ @staticmethod
+ def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]:
+ fernet = Fernet(key)
+ return json.loads(fernet.decrypt(sealed_data).decode())
diff --git a/workos/types/user_management/authenticate_with_common.py b/workos/types/user_management/authenticate_with_common.py
index af423e18..73de96ac 100644
--- a/workos/types/user_management/authenticate_with_common.py
+++ b/workos/types/user_management/authenticate_with_common.py
@@ -1,5 +1,6 @@
from typing import Literal, Union
from typing_extensions import TypedDict
+from workos.types.user_management.session import SessionConfig
class AuthenticateWithBaseParameters(TypedDict):
@@ -17,6 +18,7 @@ class AuthenticateWithCodeParameters(AuthenticateWithBaseParameters):
code: str
code_verifier: Union[str, None]
grant_type: Literal["authorization_code"]
+ session: Union[SessionConfig, None]
class AuthenticateWithMagicAuthParameters(AuthenticateWithBaseParameters):
diff --git a/workos/types/user_management/authentication_response.py b/workos/types/user_management/authentication_response.py
index 999234a9..6a5f4449 100644
--- a/workos/types/user_management/authentication_response.py
+++ b/workos/types/user_management/authentication_response.py
@@ -29,6 +29,7 @@ class AuthenticationResponse(_AuthenticationResponseBase):
impersonator: Optional[Impersonator] = None
organization_id: Optional[str] = None
user: User
+ sealed_session: Optional[str] = None
class AuthKitAuthenticationResponse(AuthenticationResponse):
diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py
index 7e1a50df..01581ec1 100644
--- a/workos/types/user_management/session.py
+++ b/workos/types/user_management/session.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, TypedDict
from enum import Enum
from workos.types.user_management.impersonator import Impersonator
@@ -23,3 +23,6 @@ class AuthenticateWithSessionCookieErrorResponse(WorkOSModel):
authenticated: bool = False
reason: AuthenticateWithSessionCookieFailureReason
+class SessionConfig(TypedDict):
+ seal_session: bool
+ cookie_password: str
\ No newline at end of file
diff --git a/workos/user_management.py b/workos/user_management.py
index dc444a01..66232e3a 100644
--- a/workos/user_management.py
+++ b/workos/user_management.py
@@ -1,5 +1,6 @@
from typing import Optional, Protocol, Sequence, Set, Type
from workos._client_configuration import ClientConfiguration
+from workos.session import SessionModule
from workos.types.list_resource import (
ListArgs,
ListMetadata,
@@ -43,6 +44,7 @@
UsersListFilters,
)
from workos.types.user_management.password_hash_type import PasswordHashType
+from workos.types.user_management.session import SessionConfig
from workos.types.user_management.user_management_provider_type import (
UserManagementProviderType,
)
@@ -109,6 +111,18 @@ class UserManagementModule(Protocol):
_client_configuration: ClientConfiguration
+ def load_sealed_session(self, *, sealed_session: str, cookie_password: str) -> SyncOrAsync[SessionModule]:
+ """Load a sealed session and return the session data.
+
+ Args:
+ sealed_session (str): The sealed session data to load.
+ cookie_password (str): The cookie password to use to decrypt the session data.
+
+ Returns:
+ SessionModule: The session module.
+ """
+ ...
+
def get_user(self, user_id: str) -> SyncOrAsync[User]:
"""Get the details of an existing user.
@@ -804,6 +818,9 @@ def __init__(
self._client_configuration = client_configuration
self._http_client = http_client
+ def load_sealed_session(self, *, session_data: str, cookie_password: str) -> SessionModule:
+ return SessionModule(user_management=self, client_id=self._http_client.client_id, session_data=session_data, cookie_password=cookie_password)
+
def get_user(self, user_id: str) -> User:
response = self._http_client.request(
USER_DETAIL_PATH.format(user_id), method=REQUEST_METHOD_GET
@@ -1013,6 +1030,9 @@ def _authenticate_with(
json=json,
)
+ if payload["session"] is not None and payload["session"].get("seal_session") is True:
+ response["sealed_session"] = SessionModule.seal_data(response, payload["session"]["cookie_password"])
+
return response_model.model_validate(response)
def authenticate_with_password(
@@ -1037,16 +1057,21 @@ def authenticate_with_code(
self,
*,
code: str,
+ session: Optional[SessionConfig] = None,
code_verifier: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> AuthKitAuthenticationResponse:
+ if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""):
+ raise ValueError("cookie_password is required when sealing session")
+
payload: AuthenticateWithCodeParameters = {
"code": code,
"grant_type": "authorization_code",
"ip_address": ip_address,
"user_agent": user_agent,
"code_verifier": code_verifier,
+ "session": session,
}
return self._authenticate_with(
From 8f1ec0cc5251592af2ccf556d24413a5c1eeef97 Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Wed, 13 Nov 2024 17:44:41 +0100
Subject: [PATCH 03/14] lol camel case
---
workos/session.py | 30 +++++++++----------------
workos/types/user_management/session.py | 4 ++--
2 files changed, 13 insertions(+), 21 deletions(-)
diff --git a/workos/session.py b/workos/session.py
index 47765e04..d2b9cf01 100644
--- a/workos/session.py
+++ b/workos/session.py
@@ -28,15 +28,9 @@ def __init__(
self.session_data = session_data
self.cookie_password = cookie_password
- self.jwks = self.create_remote_jwk_set(
- self.user_management.get_jwks_url()
- )
- self.jwk_algorithms = [str(key.Algorithm) for key in self.jwks]
+ self.jwks = PyJWKClient(self.user_management.get_jwks_url())
- for key in self.jwks:
- print("Key properties:", dir(key)) # This will show all available attributes
- print("Algorithm:", key.Algorithm)
- print("Key type:", key.key_type)
+ self.jwk_algorithms = ['RS256']
def authenticate(
self,
@@ -66,19 +60,20 @@ def authenticate(
authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT
)
+ signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"])
decoded = jwt.decode(
- session["access_token"], self.jwks, algorithms=self.jwk_algorithms
+ session["access_token"], signing_key.key, algorithms=self.jwk_algorithms
)
return AuthenticateWithSessionCookieSuccessResponse(
authenticated=True,
session_id=decoded["sid"],
- organization_id=decoded["org_id"],
- role=decoded["role"],
- permissions=decoded["permissions"],
- entitlements=decoded["entitlements"],
+ organization_id=decoded.get("org_id", None),
+ role=decoded.get("role", None),
+ permissions=decoded.get("permissions", None),
+ entitlements=decoded.get("entitlements", None),
user=session["user"],
- impersonator=session["impersonator"],
+ impersonator=session.get("impersonator", None),
reason=None,
)
@@ -131,13 +126,10 @@ def get_logout_url(self) -> str:
session_id=auth_response["session_id"]
)
- def create_remote_jwk_set(self, url: str) -> List[Dict[str, Any]]:
- jwks_client = PyJWKClient(url)
- return jwks_client.get_signing_keys()
-
def is_valid_jwt(self, token: str) -> bool:
try:
- jwt.decode(token, self.jwks, algorithms=self.jwk_algorithms)
+ signing_key = self.jwks.get_signing_key_from_jwt(token)
+ jwt.decode(token, signing_key.key, algorithms=self.jwk_algorithms)
return True
except jwt.exceptions.InvalidTokenError as error:
print("invalid token", error)
diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py
index 01581ec1..fdcd8146 100644
--- a/workos/types/user_management/session.py
+++ b/workos/types/user_management/session.py
@@ -12,8 +12,8 @@ class AuthenticateWithSessionCookieFailureReason(Enum):
class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
authenticated: bool = True
- sessionId: str
- organizationId: Optional[str] = None
+ session_id: str
+ organization_id: Optional[str] = None
role: Optional[str] = None
permissions: Optional[List[str]] = None
user: User
From 29f4c7367a09750001284f4cb305f2d02d893f96 Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Thu, 14 Nov 2024 16:14:35 +0100
Subject: [PATCH 04/14] Fix log out url part
---
workos/session.py | 10 +++++-----
workos/types/user_management/session.py | 1 +
2 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/workos/session.py b/workos/session.py
index d2b9cf01..64537fb6 100644
--- a/workos/session.py
+++ b/workos/session.py
@@ -30,6 +30,7 @@ def __init__(
self.jwks = PyJWKClient(self.user_management.get_jwks_url())
+ # Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
self.jwk_algorithms = ['RS256']
def authenticate(
@@ -119,11 +120,11 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[
def get_logout_url(self) -> str:
auth_response = self.authenticate()
- if not auth_response["authenticated"]:
- raise ValueError(auth_response["reason"])
+ if not auth_response.authenticated:
+ raise ValueError(f"Failed to extract session ID for logout URL: {auth_response.reason}")
return self.user_management.get_logout_url(
- session_id=auth_response["session_id"]
+ session_id=auth_response.session_id
)
def is_valid_jwt(self, token: str) -> bool:
@@ -131,8 +132,7 @@ def is_valid_jwt(self, token: str) -> bool:
signing_key = self.jwks.get_signing_key_from_jwt(token)
jwt.decode(token, signing_key.key, algorithms=self.jwk_algorithms)
return True
- except jwt.exceptions.InvalidTokenError as error:
- print("invalid token", error)
+ except jwt.exceptions.InvalidTokenError:
return False
@staticmethod
diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py
index fdcd8146..5e543cb4 100644
--- a/workos/types/user_management/session.py
+++ b/workos/types/user_management/session.py
@@ -18,6 +18,7 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
permissions: Optional[List[str]] = None
user: User
impersonator: Optional[Impersonator] = None
+ entitlements: Optional[List[str]] = None
class AuthenticateWithSessionCookieErrorResponse(WorkOSModel):
authenticated: bool = False
From 87fadcb92b09fea861588b3bf1ae2048de3c9bdf Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Fri, 22 Nov 2024 14:53:52 +0100
Subject: [PATCH 05/14] Add tests
---
tests/test_session.py | 333 ++++++++++++++++++
tests/test_user_management.py | 3 +
workos/session.py | 47 ++-
.../authenticate_with_common.py | 1 +
.../authentication_response.py | 2 +-
workos/types/user_management/session.py | 12 +-
workos/user_management.py | 26 +-
7 files changed, 405 insertions(+), 19 deletions(-)
create mode 100644 tests/test_session.py
diff --git a/tests/test_session.py b/tests/test_session.py
new file mode 100644
index 00000000..a898949c
--- /dev/null
+++ b/tests/test_session.py
@@ -0,0 +1,333 @@
+import pytest
+from unittest.mock import Mock, patch
+import jwt
+from jwt import PyJWKClient
+from datetime import datetime, timezone
+
+from workos.session import SessionModule
+from workos.types.user_management.authentication_response import RefreshTokenAuthenticationResponse
+from workos.types.user_management.session import (
+ AuthenticateWithSessionCookieFailureReason,
+ AuthenticateWithSessionCookieSuccessResponse,
+ RefreshWithSessionCookieErrorResponse,
+ RefreshWithSessionCookieSuccessResponse,
+)
+from workos.types.user_management.user import User
+
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+
+@pytest.fixture(scope="session")
+def TEST_CONSTANTS():
+ # Generate RSA key pair for testing
+ private_key = rsa.generate_private_key(
+ public_exponent=65537,
+ key_size=2048
+ )
+
+ public_key = private_key.public_key()
+
+ # Get the private key in PEM format
+ private_pem = private_key.private_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PrivateFormat.PKCS8,
+ encryption_algorithm=serialization.NoEncryption()
+ )
+
+ return {
+ "COOKIE_PASSWORD": "pfSqwTFXUTGEBBD1RQh2kt/oNJYxBgaoZan4Z8sMrKU=",
+ "SESSION_DATA": "session_data",
+ "CLIENT_ID": "client_123",
+ "USER_ID": "user_123",
+ "SESSION_ID": "session_123",
+ "ORGANIZATION_ID": "organization_123",
+ "CURRENT_TIMESTAMP": str(datetime.now(timezone.utc)),
+ "PRIVATE_KEY": private_pem,
+ "PUBLIC_KEY": public_key,
+ "TEST_TOKEN": jwt.encode(
+ {
+ "sid": "session_123",
+ "org_id": "organization_123",
+ "role": "admin",
+ "permissions": ["read"],
+ "entitlements": ["feature_1"],
+ "exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
+ "iat": int(datetime.now(timezone.utc).timestamp()),
+ },
+ private_pem,
+ algorithm="RS256"
+ )
+ }
+
+@pytest.fixture
+def mock_user_management(TEST_CONSTANTS):
+ mock_jwks = Mock(spec=PyJWKClient)
+ mock_jwk_set = Mock()
+ mock_jwk_set.keys = [Mock(Algorithm="RS256")]
+ mock_jwks.get_jwk_set.return_value = mock_jwk_set
+
+
+
+ mock = Mock()
+ mock.get_jwks_url.return_value = "https://api.workos.com/user_management/sso/jwks/client_123"
+ mock.authenticate_with_refresh_token.return_value = RefreshTokenAuthenticationResponse(
+ **{
+ "access_token": "access_token_123",
+ "refresh_token": "refresh_token_123",
+ "user": {
+ "object": "user",
+ "id": TEST_CONSTANTS["USER_ID"],
+ "email": "user@example.com",
+ "first_name": "Test",
+ "last_name": "User",
+ "email_verified": True,
+ "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
+ "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
+ }
+ }
+ )
+ return mock
+
+def test_initialize_session_module(TEST_CONSTANTS, mock_user_management):
+ session = SessionModule(
+ user_management=mock_user_management,
+ client_id=TEST_CONSTANTS["CLIENT_ID"],
+ session_data=TEST_CONSTANTS["SESSION_DATA"],
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ )
+
+ assert session.client_id == TEST_CONSTANTS["CLIENT_ID"]
+ assert session.cookie_password is not None
+
+def test_initialize_without_cookie_password(mock_user_management):
+ with pytest.raises(ValueError, match="cookie_password is required"):
+ SessionModule(
+ user_management=mock_user_management,
+ client_id="client_123",
+ session_data="session_data",
+ cookie_password=""
+ )
+
+def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management):
+ session = SessionModule(
+ user_management=mock_user_management,
+ client_id=TEST_CONSTANTS["CLIENT_ID"],
+ session_data=None,
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ )
+
+ response = session.authenticate()
+
+ assert response.reason == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED
+
+def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
+ session = SessionModule(
+ user_management=mock_user_management,
+ client_id=TEST_CONSTANTS["CLIENT_ID"],
+ session_data="invalid_session_data",
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ )
+
+ response = session.authenticate()
+
+ assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+
+def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management):
+ invalid_session_data = SessionModule.seal_data({ "access_token": "invalid_session_data" }, TEST_CONSTANTS["COOKIE_PASSWORD"])
+ session = SessionModule(
+ user_management=mock_user_management,
+ client_id=TEST_CONSTANTS["CLIENT_ID"],
+ session_data=invalid_session_data,
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ )
+
+ response = session.authenticate()
+
+ assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT
+
+def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
+ session = SessionModule(
+ user_management=mock_user_management,
+ client_id=TEST_CONSTANTS["CLIENT_ID"],
+ session_data=TEST_CONSTANTS["SESSION_DATA"],
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ )
+
+ # Mock the session data that would be unsealed
+ mock_session = {
+ "access_token": jwt.encode(
+ {
+ "sid": TEST_CONSTANTS["SESSION_ID"],
+ "org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
+ "role": "admin",
+ "permissions": ["read"],
+ "entitlements": ["feature_1"],
+ "exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
+ "iat": int(datetime.now(timezone.utc).timestamp()),
+ },
+ TEST_CONSTANTS["PRIVATE_KEY"],
+ algorithm="RS256"
+ ),
+ "user": {
+ "object": "user",
+ "id": TEST_CONSTANTS["USER_ID"],
+ "email": "user@example.com",
+ "email_verified": True,
+ "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
+ "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
+ },
+ "impersonator": None
+ }
+
+ # Mock the JWT payload that would be decoded
+ mock_jwt_payload = {
+ "sid": TEST_CONSTANTS["SESSION_ID"],
+ "org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
+ "role": "admin",
+ "permissions": ["read"],
+ "entitlements": ["feature_1"]
+ }
+
+ with (
+ # Mock unsealing the session data
+ patch.object(
+ SessionModule,
+ "unseal_data",
+ return_value=mock_session
+ ),
+ # Mock JWT validation
+ patch.object(
+ session,
+ "is_valid_jwt",
+ return_value=True
+ ),
+ # Mock JWT decoding
+ patch(
+ "jwt.decode",
+ return_value=mock_jwt_payload
+ ),
+ # Mock JWT signing key retrieval
+ patch.object(
+ session.jwks,
+ "get_signing_key_from_jwt",
+ return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"])
+ )
+ ):
+ response = session.authenticate()
+
+ assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse)
+ assert response.authenticated is True
+ assert response.session_id == TEST_CONSTANTS["SESSION_ID"]
+ assert response.organization_id == TEST_CONSTANTS["ORGANIZATION_ID"]
+ assert response.role == "admin"
+ assert response.permissions == ["read"]
+ assert response.entitlements == ["feature_1"]
+ assert response.user.id == TEST_CONSTANTS["USER_ID"]
+ assert response.impersonator is None
+
+def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
+ session = SessionModule(
+ user_management=mock_user_management,
+ client_id=TEST_CONSTANTS["CLIENT_ID"],
+ session_data="invalid_session_data",
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ )
+
+ response = session.refresh()
+
+ assert isinstance(response, RefreshWithSessionCookieErrorResponse)
+ assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+
+def test_refresh_success(TEST_CONSTANTS, mock_user_management):
+ # Create mock JWKS client
+ mock_jwks = Mock(spec=PyJWKClient)
+ mock_signing_key = Mock()
+ mock_signing_key.key = TEST_CONSTANTS["PUBLIC_KEY"]
+ mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key
+
+ test_user = {
+ "object": "user",
+ "id": TEST_CONSTANTS["USER_ID"],
+ "email": "user@example.com",
+ "first_name": "Test",
+ "last_name": "User",
+ "email_verified": True,
+ "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
+ "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
+ }
+
+ session_data = SessionModule.seal_data({
+ "refresh_token": "refresh_token_12345",
+ "user": test_user
+ }, TEST_CONSTANTS["COOKIE_PASSWORD"])
+
+ mock_response = {
+ "access_token": TEST_CONSTANTS["TEST_TOKEN"],
+ "refresh_token": "refresh_token_123",
+ "sealed_session": session_data,
+ "user": test_user
+ }
+
+ mock_user_management.authenticate_with_refresh_token.return_value = RefreshTokenAuthenticationResponse(
+ **mock_response
+ )
+
+ with (
+ patch(
+ 'workos.session.PyJWKClient',
+ return_value=mock_jwks
+ ),
+ ):
+ session = SessionModule(
+ user_management=mock_user_management,
+ client_id=TEST_CONSTANTS["CLIENT_ID"],
+ session_data=session_data,
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ )
+
+ with (
+ patch.object(
+ session,
+ "is_valid_jwt",
+ return_value=True
+ ),
+ patch(
+ "jwt.decode",
+ return_value={
+ "sid": TEST_CONSTANTS["SESSION_ID"],
+ "org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
+ "role": "admin",
+ "permissions": ["read"],
+ "entitlements": ["feature_1"]
+ }
+ )
+ ):
+ response = session.refresh()
+
+ assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
+ assert response.authenticated is True
+ assert response.user.id == test_user["id"]
+
+ # Verify the refresh token was used correctly
+ mock_user_management.authenticate_with_refresh_token.assert_called_once_with(
+ refresh_token="refresh_token_12345",
+ organization_id=None,
+ session={
+ "seal_session": True,
+ "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"]
+ }
+ )
+
+
+def test_seal_data(TEST_CONSTANTS):
+ test_data = {"test": "data"}
+ sealed = SessionModule.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"])
+ assert isinstance(sealed, str)
+
+ # Test unsealing
+ unsealed = SessionModule.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"])
+ assert unsealed == test_data
+
+def test_unseal_invalid_data(TEST_CONSTANTS):
+ with pytest.raises(Exception): # Adjust exception type based on your implementation
+ SessionModule.unseal_data("invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"])
diff --git a/tests/test_user_management.py b/tests/test_user_management.py
index 3935f74d..3df9e417 100644
--- a/tests/test_user_management.py
+++ b/tests/test_user_management.py
@@ -60,9 +60,12 @@ def base_authentication_params(self):
@pytest.fixture
def mock_auth_refresh_token_response(self):
+ user = MockUser("user_01H7ZGXFP5C6BBQY6Z7277ZCT0").dict()
+
return {
"access_token": "access_token_12345",
"refresh_token": "refresh_token_12345",
+ "user": user,
}
@pytest.fixture
diff --git a/workos/session.py b/workos/session.py
index 64537fb6..91590b23 100644
--- a/workos/session.py
+++ b/workos/session.py
@@ -8,6 +8,8 @@
AuthenticateWithSessionCookieFailureReason,
AuthenticateWithSessionCookieSuccessResponse,
AuthenticateWithSessionCookieErrorResponse,
+ RefreshWithSessionCookieErrorResponse,
+ RefreshWithSessionCookieSuccessResponse,
)
class SessionModule:
@@ -75,25 +77,24 @@ def authenticate(
entitlements=decoded.get("entitlements", None),
user=session["user"],
impersonator=session.get("impersonator", None),
- reason=None,
)
def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[
- AuthenticateWithSessionCookieSuccessResponse,
- AuthenticateWithSessionCookieErrorResponse,
+ RefreshWithSessionCookieSuccessResponse,
+ RefreshWithSessionCookieErrorResponse,
]:
- cookie_password = options.get("cookie_password", self.cookie_password)
- organization_id = options.get("organization_id", None)
+ cookie_password = self.cookie_password if options is None else options.get("cookie_password")
+ organization_id = None if options is None else options.get("organization_id")
try:
session = self.unseal_data(self.session_data, cookie_password)
except Exception:
- return AuthenticateWithSessionCookieErrorResponse(
+ return RefreshWithSessionCookieErrorResponse(
authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
)
if not session["refresh_token"] or not session["user"]:
- return AuthenticateWithSessionCookieErrorResponse(
+ return RefreshWithSessionCookieErrorResponse(
authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
)
@@ -101,19 +102,34 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[
auth_response = self.user_management.authenticate_with_refresh_token(
refresh_token=session["refresh_token"],
organization_id=organization_id,
+ session={
+ "seal_session": True,
+ "cookie_password": cookie_password
+ }
)
self.session_data = auth_response.sealed_session
self.cookie_password = cookie_password
- return AuthenticateWithSessionCookieSuccessResponse(
+ signing_key = self.jwks.get_signing_key_from_jwt(auth_response.access_token)
+
+ decoded = jwt.decode(
+ auth_response.access_token, signing_key.key, algorithms=self.jwk_algorithms
+ )
+
+ return RefreshWithSessionCookieSuccessResponse(
authenticated=True,
sealed_session=auth_response.sealed_session,
- session=auth_response,
- reason=None,
+ session_id=decoded["sid"],
+ organization_id=decoded.get("org_id", None),
+ role=decoded.get("role", None),
+ permissions=decoded.get("permissions", None),
+ entitlements=decoded.get("entitlements", None),
+ user=auth_response.user,
+ impersonator=auth_response.impersonator,
)
except Exception as e:
- return AuthenticateWithSessionCookieErrorResponse(
+ return RefreshWithSessionCookieErrorResponse(
authenticated=False, reason=str(e)
)
@@ -138,10 +154,13 @@ def is_valid_jwt(self, token: str) -> bool:
@staticmethod
def seal_data(data: Dict[str, Any], key: str) -> str:
fernet = Fernet(key)
- # take the data and encrypt it with the key using fernet
- return fernet.encrypt(json.dumps(data).encode())
+ # Encrypt and convert bytes to string
+ encrypted_bytes = fernet.encrypt(json.dumps(data).encode())
+ return encrypted_bytes.decode('utf-8')
@staticmethod
def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]:
fernet = Fernet(key)
- return json.loads(fernet.decrypt(sealed_data).decode())
+ # Convert string back to bytes before decryption
+ encrypted_bytes = sealed_data.encode('utf-8')
+ return json.loads(fernet.decrypt(encrypted_bytes).decode())
diff --git a/workos/types/user_management/authenticate_with_common.py b/workos/types/user_management/authenticate_with_common.py
index 73de96ac..8adc9ee6 100644
--- a/workos/types/user_management/authenticate_with_common.py
+++ b/workos/types/user_management/authenticate_with_common.py
@@ -51,6 +51,7 @@ class AuthenticateWithRefreshTokenParameters(AuthenticateWithBaseParameters):
refresh_token: str
organization_id: Union[str, None]
grant_type: Literal["refresh_token"]
+ session: Union[SessionConfig, None]
AuthenticateWithParameters = Union[
diff --git a/workos/types/user_management/authentication_response.py b/workos/types/user_management/authentication_response.py
index 6a5f4449..96973fd7 100644
--- a/workos/types/user_management/authentication_response.py
+++ b/workos/types/user_management/authentication_response.py
@@ -39,7 +39,7 @@ class AuthKitAuthenticationResponse(AuthenticationResponse):
oauth_tokens: Optional[OAuthTokens] = None
-class RefreshTokenAuthenticationResponse(_AuthenticationResponseBase):
+class RefreshTokenAuthenticationResponse(AuthenticationResponse):
"""Representation of a WorkOS refresh token authentication response."""
pass
diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py
index 5e543cb4..1c8cc6a7 100644
--- a/workos/types/user_management/session.py
+++ b/workos/types/user_management/session.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, TypedDict
+from typing import List, Optional, TypedDict, Union
from enum import Enum
from workos.types.user_management.impersonator import Impersonator
@@ -22,7 +22,15 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
class AuthenticateWithSessionCookieErrorResponse(WorkOSModel):
authenticated: bool = False
- reason: AuthenticateWithSessionCookieFailureReason
+ reason: Union[AuthenticateWithSessionCookieFailureReason, str]
+
+class RefreshWithSessionCookieSuccessResponse(AuthenticateWithSessionCookieSuccessResponse):
+ sealed_session: str
+
+
+class RefreshWithSessionCookieErrorResponse(WorkOSModel):
+ authenticated: bool = False
+ reason: Union[AuthenticateWithSessionCookieFailureReason, str]
class SessionConfig(TypedDict):
seal_session: bool
diff --git a/workos/user_management.py b/workos/user_management.py
index 66232e3a..976d36af 100644
--- a/workos/user_management.py
+++ b/workos/user_management.py
@@ -431,6 +431,7 @@ def authenticate_with_code(
self,
*,
code: str,
+ session: Optional[SessionConfig] = None,
code_verifier: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
@@ -439,6 +440,7 @@ def authenticate_with_code(
Kwargs:
code (str): The authorization value which was passed back as a query parameter in the callback to the Redirect URI.
+ session (SessionConfig): Configuration for the session. (Optional)
code_verifier (str): The randomly generated string used to derive the code challenge that was passed to the authorization
url as part of the PKCE flow. This parameter is required when the client secret is not present. (Optional)
ip_address (str): The IP address of the request from the user who is attempting to authenticate. (Optional)
@@ -542,6 +544,7 @@ def authenticate_with_refresh_token(
self,
*,
refresh_token: str,
+ session: Optional[SessionConfig] = None,
organization_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
@@ -550,6 +553,7 @@ def authenticate_with_refresh_token(
Kwargs:
refresh_token (str): The token associated to the user.
+ session (SessionConfig): Configuration for the session. (Optional)
organization_id (str): The organization to issue the new access token for. (Optional)
ip_address (str): The IP address of the request from the user who is attempting to authenticate. (Optional)
user_agent (str): The user agent of the request from the user who is attempting to authenticate. (Optional)
@@ -1030,8 +1034,8 @@ def _authenticate_with(
json=json,
)
- if payload["session"] is not None and payload["session"].get("seal_session") is True:
- response["sealed_session"] = SessionModule.seal_data(response, payload["session"]["cookie_password"])
+ if payload.get("session") is not None and payload.get("session").get("seal_session") is True:
+ response["sealed_session"] = SessionModule.seal_data(response, payload.get("session").get("cookie_password"))
return response_model.model_validate(response)
@@ -1158,16 +1162,21 @@ def authenticate_with_refresh_token(
self,
*,
refresh_token: str,
+ session: Optional[SessionConfig] = None,
organization_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> RefreshTokenAuthenticationResponse:
+ if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""):
+ raise ValueError("cookie_password is required when sealing session")
+
payload: AuthenticateWithRefreshTokenParameters = {
"refresh_token": refresh_token,
"organization_id": organization_id,
"grant_type": "refresh_token",
"ip_address": ip_address,
"user_agent": user_agent,
+ "session": session,
}
return self._authenticate_with(
@@ -1614,6 +1623,9 @@ async def _authenticate_with(
json=json,
)
+ if payload.get("session") is not None and payload.get("session").get("seal_session") is True:
+ response["sealed_session"] = SessionModule.seal_data(response, payload.get("session").get("cookie_password"))
+
return response_model.model_validate(response)
async def authenticate_with_password(
@@ -1640,16 +1652,21 @@ async def authenticate_with_code(
self,
*,
code: str,
+ session: Optional[SessionConfig] = None,
code_verifier: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> AuthKitAuthenticationResponse:
+ if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""):
+ raise ValueError("cookie_password is required when sealing session")
+
payload: AuthenticateWithCodeParameters = {
"code": code,
"grant_type": "authorization_code",
"ip_address": ip_address,
"user_agent": user_agent,
"code_verifier": code_verifier,
+ "session": session,
}
return await self._authenticate_with(
@@ -1744,16 +1761,21 @@ async def authenticate_with_refresh_token(
self,
*,
refresh_token: str,
+ session: Optional[SessionConfig] = None,
organization_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> RefreshTokenAuthenticationResponse:
+ if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""):
+ raise ValueError("cookie_password is required when sealing session")
+
payload: AuthenticateWithRefreshTokenParameters = {
"refresh_token": refresh_token,
"organization_id": organization_id,
"grant_type": "refresh_token",
"ip_address": ip_address,
"user_agent": user_agent,
+ "session": session,
}
return await self._authenticate_with(
From 13bb07362711db32c63a62af2cf26784c538bef5 Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Fri, 22 Nov 2024 15:08:13 +0100
Subject: [PATCH 06/14] Mock JWKS call to speed up tests
---
tests/conftest.py | 17 ++++++
tests/test_session.py | 123 +++++++++++++++++-------------------------
2 files changed, 66 insertions(+), 74 deletions(-)
diff --git a/tests/conftest.py b/tests/conftest.py
index 81ef0ca8..1bd77d72 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,6 +24,9 @@
from workos.utils.http_client import AsyncHTTPClient, HTTPClient, SyncHTTPClient
from workos.utils.request_helper import DEFAULT_LIST_RESPONSE_LIMIT
+from jwt import PyJWKClient
+from unittest.mock import Mock, patch
+from functools import wraps
def _get_test_client_setup(
http_client_class_name: str,
@@ -302,3 +305,17 @@ def inner(
assert request_kwargs["params"][param] == params[param]
return inner
+
+def with_jwks_mock(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ # Create mock JWKS client
+ mock_jwks = Mock(spec=PyJWKClient)
+ mock_signing_key = Mock()
+ mock_signing_key.key = kwargs['TEST_CONSTANTS']["PUBLIC_KEY"]
+ mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key
+
+ # Apply the mock
+ with patch('workos.session.PyJWKClient', return_value=mock_jwks):
+ return func(*args, **kwargs)
+ return wrapper
\ No newline at end of file
diff --git a/tests/test_session.py b/tests/test_session.py
index a898949c..9e3cc021 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -4,6 +4,7 @@
from jwt import PyJWKClient
from datetime import datetime, timezone
+from tests.conftest import with_jwks_mock
from workos.session import SessionModule
from workos.types.user_management.authentication_response import RefreshTokenAuthenticationResponse
from workos.types.user_management.session import (
@@ -60,34 +61,13 @@ def TEST_CONSTANTS():
}
@pytest.fixture
-def mock_user_management(TEST_CONSTANTS):
- mock_jwks = Mock(spec=PyJWKClient)
- mock_jwk_set = Mock()
- mock_jwk_set.keys = [Mock(Algorithm="RS256")]
- mock_jwks.get_jwk_set.return_value = mock_jwk_set
-
-
-
+def mock_user_management():
mock = Mock()
mock.get_jwks_url.return_value = "https://api.workos.com/user_management/sso/jwks/client_123"
- mock.authenticate_with_refresh_token.return_value = RefreshTokenAuthenticationResponse(
- **{
- "access_token": "access_token_123",
- "refresh_token": "refresh_token_123",
- "user": {
- "object": "user",
- "id": TEST_CONSTANTS["USER_ID"],
- "email": "user@example.com",
- "first_name": "Test",
- "last_name": "User",
- "email_verified": True,
- "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
- "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
- }
- }
- )
+
return mock
+@with_jwks_mock
def test_initialize_session_module(TEST_CONSTANTS, mock_user_management):
session = SessionModule(
user_management=mock_user_management,
@@ -99,15 +79,17 @@ def test_initialize_session_module(TEST_CONSTANTS, mock_user_management):
assert session.client_id == TEST_CONSTANTS["CLIENT_ID"]
assert session.cookie_password is not None
-def test_initialize_without_cookie_password(mock_user_management):
+@with_jwks_mock
+def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management):
with pytest.raises(ValueError, match="cookie_password is required"):
SessionModule(
user_management=mock_user_management,
- client_id="client_123",
- session_data="session_data",
+ client_id=TEST_CONSTANTS["CLIENT_ID"],
+ session_data=TEST_CONSTANTS["SESSION_DATA"],
cookie_password=""
)
+@with_jwks_mock
def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management):
session = SessionModule(
user_management=mock_user_management,
@@ -120,6 +102,7 @@ def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_manag
assert response.reason == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED
+@with_jwks_mock
def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
session = SessionModule(
user_management=mock_user_management,
@@ -132,6 +115,7 @@ def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_managemen
assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+@with_jwks_mock
def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management):
invalid_session_data = SessionModule.seal_data({ "access_token": "invalid_session_data" }, TEST_CONSTANTS["COOKIE_PASSWORD"])
session = SessionModule(
@@ -145,6 +129,7 @@ def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management):
assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT
+@with_jwks_mock
def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
session = SessionModule(
user_management=mock_user_management,
@@ -225,6 +210,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
assert response.user.id == TEST_CONSTANTS["USER_ID"]
assert response.impersonator is None
+@with_jwks_mock
def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
session = SessionModule(
user_management=mock_user_management,
@@ -238,13 +224,8 @@ def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
assert isinstance(response, RefreshWithSessionCookieErrorResponse)
assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+@with_jwks_mock
def test_refresh_success(TEST_CONSTANTS, mock_user_management):
- # Create mock JWKS client
- mock_jwks = Mock(spec=PyJWKClient)
- mock_signing_key = Mock()
- mock_signing_key.key = TEST_CONSTANTS["PUBLIC_KEY"]
- mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key
-
test_user = {
"object": "user",
"id": TEST_CONSTANTS["USER_ID"],
@@ -272,51 +253,45 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
**mock_response
)
+ session = SessionModule(
+ user_management=mock_user_management,
+ client_id=TEST_CONSTANTS["CLIENT_ID"],
+ session_data=session_data,
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ )
+
with (
- patch(
- 'workos.session.PyJWKClient',
- return_value=mock_jwks
+ patch.object(
+ session,
+ "is_valid_jwt",
+ return_value=True
),
- ):
- session = SessionModule(
- user_management=mock_user_management,
- client_id=TEST_CONSTANTS["CLIENT_ID"],
- session_data=session_data,
- cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
- )
-
- with (
- patch.object(
- session,
- "is_valid_jwt",
- return_value=True
- ),
- patch(
- "jwt.decode",
- return_value={
- "sid": TEST_CONSTANTS["SESSION_ID"],
- "org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
- "role": "admin",
- "permissions": ["read"],
- "entitlements": ["feature_1"]
- }
- )
- ):
- response = session.refresh()
-
- assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
- assert response.authenticated is True
- assert response.user.id == test_user["id"]
-
- # Verify the refresh token was used correctly
- mock_user_management.authenticate_with_refresh_token.assert_called_once_with(
- refresh_token="refresh_token_12345",
- organization_id=None,
- session={
- "seal_session": True,
- "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"]
+ patch(
+ "jwt.decode",
+ return_value={
+ "sid": TEST_CONSTANTS["SESSION_ID"],
+ "org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
+ "role": "admin",
+ "permissions": ["read"],
+ "entitlements": ["feature_1"]
}
)
+ ):
+ response = session.refresh()
+
+ assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
+ assert response.authenticated is True
+ assert response.user.id == test_user["id"]
+
+ # Verify the refresh token was used correctly
+ mock_user_management.authenticate_with_refresh_token.assert_called_once_with(
+ refresh_token="refresh_token_12345",
+ organization_id=None,
+ session={
+ "seal_session": True,
+ "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"]
+ }
+ )
def test_seal_data(TEST_CONSTANTS):
From 38fa6c9eb0b2d5b2f23f6dcf98e7902d4b23c9de Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Fri, 22 Nov 2024 15:22:04 +0100
Subject: [PATCH 07/14] linting
---
setup.py | 5 +-
tests/conftest.py | 9 +-
tests/test_session.py | 126 +++++++++++++-----------
workos/session.py | 48 +++++----
workos/types/user_management/session.py | 17 +++-
workos/user_management.py | 57 ++++++++---
6 files changed, 163 insertions(+), 99 deletions(-)
diff --git a/setup.py b/setup.py
index bfd5dbcb..103be545 100644
--- a/setup.py
+++ b/setup.py
@@ -10,10 +10,11 @@
with open(os.path.join(base_dir, "workos", "__about__.py")) as f:
exec(f.read(), about)
+
def read_requirements(filename):
with open(filename) as f:
- return [line.strip() for line in f
- if line.strip() and not line.startswith('#')]
+ return [line.strip() for line in f if line.strip() and not line.startswith("#")]
+
setup(
name=about["__package_name__"],
diff --git a/tests/conftest.py b/tests/conftest.py
index 1bd77d72..7f9f058a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -28,6 +28,7 @@
from unittest.mock import Mock, patch
from functools import wraps
+
def _get_test_client_setup(
http_client_class_name: str,
) -> Tuple[Literal["async", "sync"], ClientConfiguration, HTTPClient]:
@@ -306,16 +307,18 @@ def inner(
return inner
+
def with_jwks_mock(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Create mock JWKS client
mock_jwks = Mock(spec=PyJWKClient)
mock_signing_key = Mock()
- mock_signing_key.key = kwargs['TEST_CONSTANTS']["PUBLIC_KEY"]
+ mock_signing_key.key = kwargs["TEST_CONSTANTS"]["PUBLIC_KEY"]
mock_jwks.get_signing_key_from_jwt.return_value = mock_signing_key
# Apply the mock
- with patch('workos.session.PyJWKClient', return_value=mock_jwks):
+ with patch("workos.session.PyJWKClient", return_value=mock_jwks):
return func(*args, **kwargs)
- return wrapper
\ No newline at end of file
+
+ return wrapper
diff --git a/tests/test_session.py b/tests/test_session.py
index 9e3cc021..ee65f331 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -6,7 +6,9 @@
from tests.conftest import with_jwks_mock
from workos.session import SessionModule
-from workos.types.user_management.authentication_response import RefreshTokenAuthenticationResponse
+from workos.types.user_management.authentication_response import (
+ RefreshTokenAuthenticationResponse,
+)
from workos.types.user_management.session import (
AuthenticateWithSessionCookieFailureReason,
AuthenticateWithSessionCookieSuccessResponse,
@@ -18,13 +20,11 @@
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
+
@pytest.fixture(scope="session")
def TEST_CONSTANTS():
# Generate RSA key pair for testing
- private_key = rsa.generate_private_key(
- public_exponent=65537,
- key_size=2048
- )
+ private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key = private_key.public_key()
@@ -32,7 +32,7 @@ def TEST_CONSTANTS():
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
- encryption_algorithm=serialization.NoEncryption()
+ encryption_algorithm=serialization.NoEncryption(),
)
return {
@@ -56,29 +56,34 @@ def TEST_CONSTANTS():
"iat": int(datetime.now(timezone.utc).timestamp()),
},
private_pem,
- algorithm="RS256"
- )
+ algorithm="RS256",
+ ),
}
+
@pytest.fixture
def mock_user_management():
mock = Mock()
- mock.get_jwks_url.return_value = "https://api.workos.com/user_management/sso/jwks/client_123"
+ mock.get_jwks_url.return_value = (
+ "https://api.workos.com/user_management/sso/jwks/client_123"
+ )
return mock
+
@with_jwks_mock
def test_initialize_session_module(TEST_CONSTANTS, mock_user_management):
session = SessionModule(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=TEST_CONSTANTS["SESSION_DATA"],
- cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)
assert session.client_id == TEST_CONSTANTS["CLIENT_ID"]
assert session.cookie_password is not None
+
@with_jwks_mock
def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management):
with pytest.raises(ValueError, match="cookie_password is required"):
@@ -86,21 +91,26 @@ def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=TEST_CONSTANTS["SESSION_DATA"],
- cookie_password=""
+ cookie_password="",
)
+
@with_jwks_mock
def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management):
session = SessionModule(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=None,
- cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)
response = session.authenticate()
- assert response.reason == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED
+ assert (
+ response.reason
+ == AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED
+ )
+
@with_jwks_mock
def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
@@ -108,34 +118,41 @@ def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_managemen
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data="invalid_session_data",
- cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)
response = session.authenticate()
- assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ assert (
+ response.reason
+ == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ )
+
@with_jwks_mock
def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management):
- invalid_session_data = SessionModule.seal_data({ "access_token": "invalid_session_data" }, TEST_CONSTANTS["COOKIE_PASSWORD"])
+ invalid_session_data = SessionModule.seal_data(
+ {"access_token": "invalid_session_data"}, TEST_CONSTANTS["COOKIE_PASSWORD"]
+ )
session = SessionModule(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=invalid_session_data,
- cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)
response = session.authenticate()
assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT
+
@with_jwks_mock
def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
session = SessionModule(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=TEST_CONSTANTS["SESSION_DATA"],
- cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)
# Mock the session data that would be unsealed
@@ -151,7 +168,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
"iat": int(datetime.now(timezone.utc).timestamp()),
},
TEST_CONSTANTS["PRIVATE_KEY"],
- algorithm="RS256"
+ algorithm="RS256",
),
"user": {
"object": "user",
@@ -161,7 +178,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
"created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
"updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
},
- "impersonator": None
+ "impersonator": None,
}
# Mock the JWT payload that would be decoded
@@ -170,33 +187,22 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
"org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
"role": "admin",
"permissions": ["read"],
- "entitlements": ["feature_1"]
+ "entitlements": ["feature_1"],
}
with (
# Mock unsealing the session data
- patch.object(
- SessionModule,
- "unseal_data",
- return_value=mock_session
- ),
+ patch.object(SessionModule, "unseal_data", return_value=mock_session),
# Mock JWT validation
- patch.object(
- session,
- "is_valid_jwt",
- return_value=True
- ),
+ patch.object(session, "is_valid_jwt", return_value=True),
# Mock JWT decoding
- patch(
- "jwt.decode",
- return_value=mock_jwt_payload
- ),
+ patch("jwt.decode", return_value=mock_jwt_payload),
# Mock JWT signing key retrieval
patch.object(
session.jwks,
"get_signing_key_from_jwt",
- return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"])
- )
+ return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]),
+ ),
):
response = session.authenticate()
@@ -210,19 +216,24 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
assert response.user.id == TEST_CONSTANTS["USER_ID"]
assert response.impersonator is None
+
@with_jwks_mock
def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
session = SessionModule(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data="invalid_session_data",
- cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)
response = session.refresh()
assert isinstance(response, RefreshWithSessionCookieErrorResponse)
- assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ assert (
+ response.reason
+ == AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ )
+
@with_jwks_mock
def test_refresh_success(TEST_CONSTANTS, mock_user_management):
@@ -237,35 +248,31 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
"updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
}
- session_data = SessionModule.seal_data({
- "refresh_token": "refresh_token_12345",
- "user": test_user
- }, TEST_CONSTANTS["COOKIE_PASSWORD"])
+ session_data = SessionModule.seal_data(
+ {"refresh_token": "refresh_token_12345", "user": test_user},
+ TEST_CONSTANTS["COOKIE_PASSWORD"],
+ )
mock_response = {
"access_token": TEST_CONSTANTS["TEST_TOKEN"],
"refresh_token": "refresh_token_123",
"sealed_session": session_data,
- "user": test_user
+ "user": test_user,
}
- mock_user_management.authenticate_with_refresh_token.return_value = RefreshTokenAuthenticationResponse(
- **mock_response
+ mock_user_management.authenticate_with_refresh_token.return_value = (
+ RefreshTokenAuthenticationResponse(**mock_response)
)
session = SessionModule(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=session_data,
- cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"]
+ cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)
with (
- patch.object(
- session,
- "is_valid_jwt",
- return_value=True
- ),
+ patch.object(session, "is_valid_jwt", return_value=True),
patch(
"jwt.decode",
return_value={
@@ -273,9 +280,9 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
"org_id": TEST_CONSTANTS["ORGANIZATION_ID"],
"role": "admin",
"permissions": ["read"],
- "entitlements": ["feature_1"]
- }
- )
+ "entitlements": ["feature_1"],
+ },
+ ),
):
response = session.refresh()
@@ -289,8 +296,8 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
organization_id=None,
session={
"seal_session": True,
- "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"]
- }
+ "cookie_password": TEST_CONSTANTS["COOKIE_PASSWORD"],
+ },
)
@@ -303,6 +310,9 @@ def test_seal_data(TEST_CONSTANTS):
unsealed = SessionModule.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"])
assert unsealed == test_data
+
def test_unseal_invalid_data(TEST_CONSTANTS):
with pytest.raises(Exception): # Adjust exception type based on your implementation
- SessionModule.unseal_data("invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"])
+ SessionModule.unseal_data(
+ "invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"]
+ )
diff --git a/workos/session.py b/workos/session.py
index 91590b23..dc4bf2d1 100644
--- a/workos/session.py
+++ b/workos/session.py
@@ -12,6 +12,7 @@
RefreshWithSessionCookieSuccessResponse,
)
+
class SessionModule:
def __init__(
self,
@@ -19,7 +20,7 @@ def __init__(
user_management: Any,
client_id: str,
session_data: str,
- cookie_password: str
+ cookie_password: str,
) -> None:
# If the cookie password is not provided, throw an error
if cookie_password is None or cookie_password == "":
@@ -33,7 +34,7 @@ def __init__(
self.jwks = PyJWKClient(self.user_management.get_jwks_url())
# Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
- self.jwk_algorithms = ['RS256']
+ self.jwk_algorithms = ["RS256"]
def authenticate(
self,
@@ -43,24 +44,28 @@ def authenticate(
]:
if self.session_data is None:
return AuthenticateWithSessionCookieErrorResponse(
- authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED
+ authenticated=False,
+ reason=AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED,
)
try:
session = self.unseal_data(self.session_data, self.cookie_password)
except Exception:
return AuthenticateWithSessionCookieErrorResponse(
- authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ authenticated=False,
+ reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE,
)
if not session["access_token"]:
return AuthenticateWithSessionCookieErrorResponse(
- authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ authenticated=False,
+ reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE,
)
if not self.is_valid_jwt(session["access_token"]):
return AuthenticateWithSessionCookieErrorResponse(
- authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT
+ authenticated=False,
+ reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT,
)
signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"])
@@ -83,29 +88,30 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[
RefreshWithSessionCookieSuccessResponse,
RefreshWithSessionCookieErrorResponse,
]:
- cookie_password = self.cookie_password if options is None else options.get("cookie_password")
+ cookie_password = (
+ self.cookie_password if options is None else options.get("cookie_password")
+ )
organization_id = None if options is None else options.get("organization_id")
try:
session = self.unseal_data(self.session_data, cookie_password)
except Exception:
return RefreshWithSessionCookieErrorResponse(
- authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ authenticated=False,
+ reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE,
)
if not session["refresh_token"] or not session["user"]:
return RefreshWithSessionCookieErrorResponse(
- authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
+ authenticated=False,
+ reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE,
)
try:
auth_response = self.user_management.authenticate_with_refresh_token(
refresh_token=session["refresh_token"],
organization_id=organization_id,
- session={
- "seal_session": True,
- "cookie_password": cookie_password
- }
+ session={"seal_session": True, "cookie_password": cookie_password},
)
self.session_data = auth_response.sealed_session
@@ -114,7 +120,9 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[
signing_key = self.jwks.get_signing_key_from_jwt(auth_response.access_token)
decoded = jwt.decode(
- auth_response.access_token, signing_key.key, algorithms=self.jwk_algorithms
+ auth_response.access_token,
+ signing_key.key,
+ algorithms=self.jwk_algorithms,
)
return RefreshWithSessionCookieSuccessResponse(
@@ -137,11 +145,11 @@ def get_logout_url(self) -> str:
auth_response = self.authenticate()
if not auth_response.authenticated:
- raise ValueError(f"Failed to extract session ID for logout URL: {auth_response.reason}")
+ raise ValueError(
+ f"Failed to extract session ID for logout URL: {auth_response.reason}"
+ )
- return self.user_management.get_logout_url(
- session_id=auth_response.session_id
- )
+ return self.user_management.get_logout_url(session_id=auth_response.session_id)
def is_valid_jwt(self, token: str) -> bool:
try:
@@ -156,11 +164,11 @@ def seal_data(data: Dict[str, Any], key: str) -> str:
fernet = Fernet(key)
# Encrypt and convert bytes to string
encrypted_bytes = fernet.encrypt(json.dumps(data).encode())
- return encrypted_bytes.decode('utf-8')
+ return encrypted_bytes.decode("utf-8")
@staticmethod
def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]:
fernet = Fernet(key)
# Convert string back to bytes before decryption
- encrypted_bytes = sealed_data.encode('utf-8')
+ encrypted_bytes = sealed_data.encode("utf-8")
return json.loads(fernet.decrypt(encrypted_bytes).decode())
diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py
index 1c8cc6a7..d6c859e5 100644
--- a/workos/types/user_management/session.py
+++ b/workos/types/user_management/session.py
@@ -5,10 +5,12 @@
from workos.types.user_management.user import User
from workos.types.workos_model import WorkOSModel
+
class AuthenticateWithSessionCookieFailureReason(Enum):
- INVALID_JWT = 'invalid_jwt'
- INVALID_SESSION_COOKIE = 'invalid_session_cookie'
- NO_SESSION_COOKIE_PROVIDED = 'no_session_cookie_provided'
+ INVALID_JWT = "invalid_jwt"
+ INVALID_SESSION_COOKIE = "invalid_session_cookie"
+ NO_SESSION_COOKIE_PROVIDED = "no_session_cookie_provided"
+
class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
authenticated: bool = True
@@ -20,11 +22,15 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
impersonator: Optional[Impersonator] = None
entitlements: Optional[List[str]] = None
+
class AuthenticateWithSessionCookieErrorResponse(WorkOSModel):
authenticated: bool = False
reason: Union[AuthenticateWithSessionCookieFailureReason, str]
-class RefreshWithSessionCookieSuccessResponse(AuthenticateWithSessionCookieSuccessResponse):
+
+class RefreshWithSessionCookieSuccessResponse(
+ AuthenticateWithSessionCookieSuccessResponse
+):
sealed_session: str
@@ -32,6 +38,7 @@ class RefreshWithSessionCookieErrorResponse(WorkOSModel):
authenticated: bool = False
reason: Union[AuthenticateWithSessionCookieFailureReason, str]
+
class SessionConfig(TypedDict):
seal_session: bool
- cookie_password: str
\ No newline at end of file
+ cookie_password: str
diff --git a/workos/user_management.py b/workos/user_management.py
index 976d36af..cdde4006 100644
--- a/workos/user_management.py
+++ b/workos/user_management.py
@@ -111,7 +111,9 @@ class UserManagementModule(Protocol):
_client_configuration: ClientConfiguration
- def load_sealed_session(self, *, sealed_session: str, cookie_password: str) -> SyncOrAsync[SessionModule]:
+ def load_sealed_session(
+ self, *, sealed_session: str, cookie_password: str
+ ) -> SyncOrAsync[SessionModule]:
"""Load a sealed session and return the session data.
Args:
@@ -822,8 +824,15 @@ def __init__(
self._client_configuration = client_configuration
self._http_client = http_client
- def load_sealed_session(self, *, session_data: str, cookie_password: str) -> SessionModule:
- return SessionModule(user_management=self, client_id=self._http_client.client_id, session_data=session_data, cookie_password=cookie_password)
+ def load_sealed_session(
+ self, *, session_data: str, cookie_password: str
+ ) -> SessionModule:
+ return SessionModule(
+ user_management=self,
+ client_id=self._http_client.client_id,
+ session_data=session_data,
+ cookie_password=cookie_password,
+ )
def get_user(self, user_id: str) -> User:
response = self._http_client.request(
@@ -1034,8 +1043,13 @@ def _authenticate_with(
json=json,
)
- if payload.get("session") is not None and payload.get("session").get("seal_session") is True:
- response["sealed_session"] = SessionModule.seal_data(response, payload.get("session").get("cookie_password"))
+ if (
+ payload.get("session") is not None
+ and payload.get("session").get("seal_session") is True
+ ):
+ response["sealed_session"] = SessionModule.seal_data(
+ response, payload.get("session").get("cookie_password")
+ )
return response_model.model_validate(response)
@@ -1066,7 +1080,11 @@ def authenticate_with_code(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> AuthKitAuthenticationResponse:
- if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""):
+ if session is not None and (
+ session.get("seal_session") is True
+ and session.get("cookie_password") is None
+ or ""
+ ):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithCodeParameters = {
@@ -1167,7 +1185,11 @@ def authenticate_with_refresh_token(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> RefreshTokenAuthenticationResponse:
- if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""):
+ if session is not None and (
+ session.get("seal_session") is True
+ and session.get("cookie_password") is None
+ or ""
+ ):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithRefreshTokenParameters = {
@@ -1623,8 +1645,13 @@ async def _authenticate_with(
json=json,
)
- if payload.get("session") is not None and payload.get("session").get("seal_session") is True:
- response["sealed_session"] = SessionModule.seal_data(response, payload.get("session").get("cookie_password"))
+ if (
+ payload.get("session") is not None
+ and payload.get("session").get("seal_session") is True
+ ):
+ response["sealed_session"] = SessionModule.seal_data(
+ response, payload.get("session").get("cookie_password")
+ )
return response_model.model_validate(response)
@@ -1657,7 +1684,11 @@ async def authenticate_with_code(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> AuthKitAuthenticationResponse:
- if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""):
+ if session is not None and (
+ session.get("seal_session") is True
+ and session.get("cookie_password") is None
+ or ""
+ ):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithCodeParameters = {
@@ -1766,7 +1797,11 @@ async def authenticate_with_refresh_token(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> RefreshTokenAuthenticationResponse:
- if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""):
+ if session is not None and (
+ session.get("seal_session") is True
+ and session.get("cookie_password") is None
+ or ""
+ ):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithRefreshTokenParameters = {
From 0b2e91ca2869f2094a25a91be9a726363a2e2b96 Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Wed, 27 Nov 2024 13:52:07 +0100
Subject: [PATCH 08/14] Make tests and mypy happy
---
tests/test_session.py | 35 ++++++-----
workos/session.py | 30 ++++++----
workos/types/user_management/session.py | 10 ++--
workos/user_management.py | 80 +++++++++++--------------
4 files changed, 76 insertions(+), 79 deletions(-)
diff --git a/tests/test_session.py b/tests/test_session.py
index ee65f331..9c38372a 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -5,7 +5,7 @@
from datetime import datetime, timezone
from tests.conftest import with_jwks_mock
-from workos.session import SessionModule
+from workos.session import Session
from workos.types.user_management.authentication_response import (
RefreshTokenAuthenticationResponse,
)
@@ -73,7 +73,7 @@ def mock_user_management():
@with_jwks_mock
def test_initialize_session_module(TEST_CONSTANTS, mock_user_management):
- session = SessionModule(
+ session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=TEST_CONSTANTS["SESSION_DATA"],
@@ -87,7 +87,7 @@ def test_initialize_session_module(TEST_CONSTANTS, mock_user_management):
@with_jwks_mock
def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management):
with pytest.raises(ValueError, match="cookie_password is required"):
- SessionModule(
+ Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=TEST_CONSTANTS["SESSION_DATA"],
@@ -97,7 +97,7 @@ def test_initialize_without_cookie_password(TEST_CONSTANTS, mock_user_management
@with_jwks_mock
def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_management):
- session = SessionModule(
+ session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=None,
@@ -114,7 +114,7 @@ def test_authenticate_no_session_cookie_provided(TEST_CONSTANTS, mock_user_manag
@with_jwks_mock
def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
- session = SessionModule(
+ session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data="invalid_session_data",
@@ -131,10 +131,10 @@ def test_authenticate_invalid_session_cookie(TEST_CONSTANTS, mock_user_managemen
@with_jwks_mock
def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management):
- invalid_session_data = SessionModule.seal_data(
+ invalid_session_data = Session.seal_data(
{"access_token": "invalid_session_data"}, TEST_CONSTANTS["COOKIE_PASSWORD"]
)
- session = SessionModule(
+ session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=invalid_session_data,
@@ -143,12 +143,14 @@ def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management):
response = session.authenticate()
+ print(response)
+
assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT
@with_jwks_mock
def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
- session = SessionModule(
+ session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=TEST_CONSTANTS["SESSION_DATA"],
@@ -192,7 +194,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
with (
# Mock unsealing the session data
- patch.object(SessionModule, "unseal_data", return_value=mock_session),
+ patch.object(Session, "unseal_data", return_value=mock_session),
# Mock JWT validation
patch.object(session, "is_valid_jwt", return_value=True),
# Mock JWT decoding
@@ -219,7 +221,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
@with_jwks_mock
def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):
- session = SessionModule(
+ session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data="invalid_session_data",
@@ -248,7 +250,7 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
"updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
}
- session_data = SessionModule.seal_data(
+ session_data = Session.seal_data(
{"refresh_token": "refresh_token_12345", "user": test_user},
TEST_CONSTANTS["COOKIE_PASSWORD"],
)
@@ -264,7 +266,7 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
RefreshTokenAuthenticationResponse(**mock_response)
)
- session = SessionModule(
+ session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=session_data,
@@ -303,16 +305,19 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
def test_seal_data(TEST_CONSTANTS):
test_data = {"test": "data"}
- sealed = SessionModule.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"])
+ sealed = Session.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"])
assert isinstance(sealed, str)
# Test unsealing
- unsealed = SessionModule.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"])
+ unsealed = Session.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"])
+ print(test_data)
+ print(unsealed)
+
assert unsealed == test_data
def test_unseal_invalid_data(TEST_CONSTANTS):
with pytest.raises(Exception): # Adjust exception type based on your implementation
- SessionModule.unseal_data(
+ Session.unseal_data(
"invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"]
)
diff --git a/workos/session.py b/workos/session.py
index dc4bf2d1..be48a0ff 100644
--- a/workos/session.py
+++ b/workos/session.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
import json
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional, Union, cast
import jwt
from jwt import PyJWKClient
from cryptography.fernet import Fernet
@@ -12,12 +13,11 @@
RefreshWithSessionCookieSuccessResponse,
)
-
-class SessionModule:
+class Session:
def __init__(
self,
*,
- user_management: Any,
+ user_management: "UserManagementModule", # type: ignore
client_id: str,
session_data: str,
cookie_password: str,
@@ -42,7 +42,7 @@ def authenticate(
AuthenticateWithSessionCookieSuccessResponse,
AuthenticateWithSessionCookieErrorResponse,
]:
- if self.session_data is None:
+ if self.session_data is None or self.session_data == "":
return AuthenticateWithSessionCookieErrorResponse(
authenticated=False,
reason=AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED,
@@ -56,7 +56,7 @@ def authenticate(
reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE,
)
- if not session["access_token"]:
+ if not session.get("access_token", None):
return AuthenticateWithSessionCookieErrorResponse(
authenticated=False,
reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE,
@@ -84,14 +84,15 @@ def authenticate(
impersonator=session.get("impersonator", None),
)
- def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[
+ def refresh(
+ self, *, organization_id: Optional[str] = None, cookie_password: Optional[str] = None
+ ) -> Union[
RefreshWithSessionCookieSuccessResponse,
RefreshWithSessionCookieErrorResponse,
]:
cookie_password = (
- self.cookie_password if options is None else options.get("cookie_password")
+ self.cookie_password if cookie_password is None else cookie_password
)
- organization_id = None if options is None else options.get("organization_id")
try:
session = self.unseal_data(self.session_data, cookie_password)
@@ -101,7 +102,7 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[
reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE,
)
- if not session["refresh_token"] or not session["user"]:
+ if not session.get("refresh_token", None) or not session.get("user", None):
return RefreshWithSessionCookieErrorResponse(
authenticated=False,
reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE,
@@ -144,12 +145,14 @@ def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[
def get_logout_url(self) -> str:
auth_response = self.authenticate()
- if not auth_response.authenticated:
+ if isinstance(auth_response, AuthenticateWithSessionCookieErrorResponse):
raise ValueError(
f"Failed to extract session ID for logout URL: {auth_response.reason}"
)
- return self.user_management.get_logout_url(session_id=auth_response.session_id)
+ result = self.user_management.get_logout_url(session_id=auth_response.session_id)
+ return str(result)
+
def is_valid_jwt(self, token: str) -> bool:
try:
@@ -171,4 +174,5 @@ def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]:
fernet = Fernet(key)
# Convert string back to bytes before decryption
encrypted_bytes = sealed_data.encode("utf-8")
- return json.loads(fernet.decrypt(encrypted_bytes).decode())
+ decrypted_str = fernet.decrypt(encrypted_bytes).decode()
+ return cast(Dict[str, Any], json.loads(decrypted_str))
diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py
index d6c859e5..7c81a0c3 100644
--- a/workos/types/user_management/session.py
+++ b/workos/types/user_management/session.py
@@ -1,6 +1,6 @@
from typing import List, Optional, TypedDict, Union
from enum import Enum
-
+from typing_extensions import Literal
from workos.types.user_management.impersonator import Impersonator
from workos.types.user_management.user import User
from workos.types.workos_model import WorkOSModel
@@ -13,7 +13,7 @@ class AuthenticateWithSessionCookieFailureReason(Enum):
class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
- authenticated: bool = True
+ authenticated: Literal[True]
session_id: str
organization_id: Optional[str] = None
role: Optional[str] = None
@@ -24,7 +24,7 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
class AuthenticateWithSessionCookieErrorResponse(WorkOSModel):
- authenticated: bool = False
+ authenticated: Literal[False]
reason: Union[AuthenticateWithSessionCookieFailureReason, str]
@@ -35,10 +35,10 @@ class RefreshWithSessionCookieSuccessResponse(
class RefreshWithSessionCookieErrorResponse(WorkOSModel):
- authenticated: bool = False
+ authenticated: Literal[False]
reason: Union[AuthenticateWithSessionCookieFailureReason, str]
-class SessionConfig(TypedDict):
+class SessionConfig(TypedDict, total=False):
seal_session: bool
cookie_password: str
diff --git a/workos/user_management.py b/workos/user_management.py
index cdde4006..1d2e02c1 100644
--- a/workos/user_management.py
+++ b/workos/user_management.py
@@ -1,6 +1,6 @@
-from typing import Optional, Protocol, Sequence, Set, Type
+from typing import Optional, Protocol, Sequence, Set, Type, cast
from workos._client_configuration import ClientConfiguration
-from workos.session import SessionModule
+from workos.session import Session
from workos.types.list_resource import (
ListArgs,
ListMetadata,
@@ -113,7 +113,7 @@ class UserManagementModule(Protocol):
def load_sealed_session(
self, *, sealed_session: str, cookie_password: str
- ) -> SyncOrAsync[SessionModule]:
+ ) -> SyncOrAsync[Session]:
"""Load a sealed session and return the session data.
Args:
@@ -121,7 +121,7 @@ def load_sealed_session(
cookie_password (str): The cookie password to use to decrypt the session data.
Returns:
- SessionModule: The session module.
+ Session: The session module.
"""
...
@@ -825,12 +825,12 @@ def __init__(
self._http_client = http_client
def load_sealed_session(
- self, *, session_data: str, cookie_password: str
- ) -> SessionModule:
- return SessionModule(
+ self, *, sealed_session: str, cookie_password: str
+ ) -> Session:
+ return Session(
user_management=self,
client_id=self._http_client.client_id,
- session_data=session_data,
+ session_data=sealed_session,
cookie_password=cookie_password,
)
@@ -1043,15 +1043,16 @@ def _authenticate_with(
json=json,
)
- if (
- payload.get("session") is not None
- and payload.get("session").get("seal_session") is True
- ):
- response["sealed_session"] = SessionModule.seal_data(
- response, payload.get("session").get("cookie_password")
+ response_data = dict(response)
+
+ session = cast(Optional[SessionConfig], payload.get("session", None))
+
+ if session is not None and session.get("seal_session") is True:
+ response_data["sealed_session"] = Session.seal_data(
+ response_data, str(session.get("cookie_password"))
)
- return response_model.model_validate(response)
+ return response_model.model_validate(response_data)
def authenticate_with_password(
self,
@@ -1080,11 +1081,7 @@ def authenticate_with_code(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> AuthKitAuthenticationResponse:
- if session is not None and (
- session.get("seal_session") is True
- and session.get("cookie_password") is None
- or ""
- ):
+ if session is not None and session.get("seal_session") and not session.get("cookie_password"):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithCodeParameters = {
@@ -1185,11 +1182,7 @@ def authenticate_with_refresh_token(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> RefreshTokenAuthenticationResponse:
- if session is not None and (
- session.get("seal_session") is True
- and session.get("cookie_password") is None
- or ""
- ):
+ if session is not None and session.get("seal_session") and not session.get("cookie_password"):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithRefreshTokenParameters = {
@@ -1273,10 +1266,7 @@ def get_magic_auth(self, magic_auth_id: str) -> MagicAuth:
return MagicAuth.model_validate(response)
def create_magic_auth(
- self,
- *,
- email: str,
- invitation_token: Optional[str] = None,
+ self, *, email: str, invitation_token: Optional[str] = None
) -> MagicAuth:
json = {
"email": email,
@@ -1435,6 +1425,11 @@ def __init__(
self._client_configuration = client_configuration
self._http_client = http_client
+ async def load_sealed_session(
+ self, *, sealed_session: str, cookie_password: str
+ ) -> Session:
+ raise NotImplementedError("Async load_sealed_session not implemented")
+
async def get_user(self, user_id: str) -> User:
response = await self._http_client.request(
USER_DETAIL_PATH.format(user_id), method=REQUEST_METHOD_GET
@@ -1645,15 +1640,16 @@ async def _authenticate_with(
json=json,
)
- if (
- payload.get("session") is not None
- and payload.get("session").get("seal_session") is True
- ):
- response["sealed_session"] = SessionModule.seal_data(
- response, payload.get("session").get("cookie_password")
+ response_data = dict(response)
+
+ session = cast(Optional[SessionConfig], payload.get("session", None))
+
+ if session is not None and session.get("seal_session") is True:
+ response_data["sealed_session"] = Session.seal_data(
+ response_data, str(session.get("cookie_password"))
)
- return response_model.model_validate(response)
+ return response_model.model_validate(response_data)
async def authenticate_with_password(
self,
@@ -1684,11 +1680,7 @@ async def authenticate_with_code(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> AuthKitAuthenticationResponse:
- if session is not None and (
- session.get("seal_session") is True
- and session.get("cookie_password") is None
- or ""
- ):
+ if session is not None and session.get("seal_session") and not session.get("cookie_password"):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithCodeParameters = {
@@ -1797,11 +1789,7 @@ async def authenticate_with_refresh_token(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> RefreshTokenAuthenticationResponse:
- if session is not None and (
- session.get("seal_session") is True
- and session.get("cookie_password") is None
- or ""
- ):
+ if session is not None and session.get("seal_session") and not session.get("cookie_password"):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithRefreshTokenParameters = {
From 0410381657f0c75f3450b3d8c92c975f191505e1 Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Wed, 27 Nov 2024 13:54:32 +0100
Subject: [PATCH 09/14] make black happy too
---
tests/test_session.py | 4 +---
workos/session.py | 13 +++++++++----
workos/user_management.py | 24 ++++++++++++++++++++----
3 files changed, 30 insertions(+), 11 deletions(-)
diff --git a/tests/test_session.py b/tests/test_session.py
index 9c38372a..4646fcd0 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -318,6 +318,4 @@ def test_seal_data(TEST_CONSTANTS):
def test_unseal_invalid_data(TEST_CONSTANTS):
with pytest.raises(Exception): # Adjust exception type based on your implementation
- Session.unseal_data(
- "invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"]
- )
+ Session.unseal_data("invalid_sealed_data", TEST_CONSTANTS["COOKIE_PASSWORD"])
diff --git a/workos/session.py b/workos/session.py
index be48a0ff..5a649734 100644
--- a/workos/session.py
+++ b/workos/session.py
@@ -13,11 +13,12 @@
RefreshWithSessionCookieSuccessResponse,
)
+
class Session:
def __init__(
self,
*,
- user_management: "UserManagementModule", # type: ignore
+ user_management: "UserManagementModule", # type: ignore
client_id: str,
session_data: str,
cookie_password: str,
@@ -85,7 +86,10 @@ def authenticate(
)
def refresh(
- self, *, organization_id: Optional[str] = None, cookie_password: Optional[str] = None
+ self,
+ *,
+ organization_id: Optional[str] = None,
+ cookie_password: Optional[str] = None,
) -> Union[
RefreshWithSessionCookieSuccessResponse,
RefreshWithSessionCookieErrorResponse,
@@ -150,10 +154,11 @@ def get_logout_url(self) -> str:
f"Failed to extract session ID for logout URL: {auth_response.reason}"
)
- result = self.user_management.get_logout_url(session_id=auth_response.session_id)
+ result = self.user_management.get_logout_url(
+ session_id=auth_response.session_id
+ )
return str(result)
-
def is_valid_jwt(self, token: str) -> bool:
try:
signing_key = self.jwks.get_signing_key_from_jwt(token)
diff --git a/workos/user_management.py b/workos/user_management.py
index 1d2e02c1..3534c678 100644
--- a/workos/user_management.py
+++ b/workos/user_management.py
@@ -1081,7 +1081,11 @@ def authenticate_with_code(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> AuthKitAuthenticationResponse:
- if session is not None and session.get("seal_session") and not session.get("cookie_password"):
+ if (
+ session is not None
+ and session.get("seal_session")
+ and not session.get("cookie_password")
+ ):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithCodeParameters = {
@@ -1182,7 +1186,11 @@ def authenticate_with_refresh_token(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> RefreshTokenAuthenticationResponse:
- if session is not None and session.get("seal_session") and not session.get("cookie_password"):
+ if (
+ session is not None
+ and session.get("seal_session")
+ and not session.get("cookie_password")
+ ):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithRefreshTokenParameters = {
@@ -1680,7 +1688,11 @@ async def authenticate_with_code(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> AuthKitAuthenticationResponse:
- if session is not None and session.get("seal_session") and not session.get("cookie_password"):
+ if (
+ session is not None
+ and session.get("seal_session")
+ and not session.get("cookie_password")
+ ):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithCodeParameters = {
@@ -1789,7 +1801,11 @@ async def authenticate_with_refresh_token(
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> RefreshTokenAuthenticationResponse:
- if session is not None and session.get("seal_session") and not session.get("cookie_password"):
+ if (
+ session is not None
+ and session.get("seal_session")
+ and not session.get("cookie_password")
+ ):
raise ValueError("cookie_password is required when sealing session")
payload: AuthenticateWithRefreshTokenParameters = {
From 25e6e5afbf7bbb9429ded355abfa8e9dfe28ac19 Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Wed, 27 Nov 2024 13:57:37 +0100
Subject: [PATCH 10/14] Forgot import
---
workos/session.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/workos/session.py b/workos/session.py
index 5a649734..55cca562 100644
--- a/workos/session.py
+++ b/workos/session.py
@@ -12,13 +12,14 @@
RefreshWithSessionCookieErrorResponse,
RefreshWithSessionCookieSuccessResponse,
)
+from workos.user_management import UserManagementModule
class Session:
def __init__(
self,
*,
- user_management: "UserManagementModule", # type: ignore
+ user_management: "UserManagementModule",
client_id: str,
session_data: str,
cookie_password: str,
From 28c213037b839f2e6889e2998727aae7e0ebb508 Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Wed, 27 Nov 2024 14:29:31 +0100
Subject: [PATCH 11/14] Satisfy type checker
---
workos/session.py | 28 ++++++++++++++++++++--------
1 file changed, 20 insertions(+), 8 deletions(-)
diff --git a/workos/session.py b/workos/session.py
index 55cca562..df8970c1 100644
--- a/workos/session.py
+++ b/workos/session.py
@@ -1,10 +1,15 @@
from __future__ import annotations
+from typing import TYPE_CHECKING
+
import json
from typing import Any, Dict, Optional, Union, cast
import jwt
from jwt import PyJWKClient
from cryptography.fernet import Fernet
+from workos.types.user_management.authentication_response import (
+ RefreshTokenAuthenticationResponse,
+)
from workos.types.user_management.session import (
AuthenticateWithSessionCookieFailureReason,
AuthenticateWithSessionCookieSuccessResponse,
@@ -12,7 +17,9 @@
RefreshWithSessionCookieErrorResponse,
RefreshWithSessionCookieSuccessResponse,
)
-from workos.user_management import UserManagementModule
+
+if TYPE_CHECKING:
+ from workos.user_management import UserManagementModule
class Session:
@@ -114,14 +121,19 @@ def refresh(
)
try:
- auth_response = self.user_management.authenticate_with_refresh_token(
- refresh_token=session["refresh_token"],
- organization_id=organization_id,
- session={"seal_session": True, "cookie_password": cookie_password},
+ auth_response = cast(
+ RefreshTokenAuthenticationResponse,
+ self.user_management.authenticate_with_refresh_token(
+ refresh_token=session["refresh_token"],
+ organization_id=organization_id,
+ session={"seal_session": True, "cookie_password": cookie_password},
+ ),
)
- self.session_data = auth_response.sealed_session
- self.cookie_password = cookie_password
+ self.session_data = str(auth_response.sealed_session)
+ self.cookie_password = (
+ cookie_password if cookie_password is not None else self.cookie_password
+ )
signing_key = self.jwks.get_signing_key_from_jwt(auth_response.access_token)
@@ -133,7 +145,7 @@ def refresh(
return RefreshWithSessionCookieSuccessResponse(
authenticated=True,
- sealed_session=auth_response.sealed_session,
+ sealed_session=str(auth_response.sealed_session),
session_id=decoded["sid"],
organization_id=decoded.get("org_id", None),
role=decoded.get("role", None),
From 3cbc6a602471d62d9933e374928fffb7a65be297 Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Wed, 27 Nov 2024 14:40:21 +0100
Subject: [PATCH 12/14] 3.8 compatibility and remove print statements
---
tests/test_session.py | 39 +++++++++++++--------------------------
1 file changed, 13 insertions(+), 26 deletions(-)
diff --git a/tests/test_session.py b/tests/test_session.py
index 4646fcd0..0ff10baf 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -143,8 +143,6 @@ def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management):
response = session.authenticate()
- print(response)
-
assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT
@@ -192,19 +190,12 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
"entitlements": ["feature_1"],
}
- with (
- # Mock unsealing the session data
- patch.object(Session, "unseal_data", return_value=mock_session),
- # Mock JWT validation
- patch.object(session, "is_valid_jwt", return_value=True),
- # Mock JWT decoding
- patch("jwt.decode", return_value=mock_jwt_payload),
- # Mock JWT signing key retrieval
- patch.object(
- session.jwks,
- "get_signing_key_from_jwt",
- return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]),
- ),
+ with patch.object(Session, "unseal_data", return_value=mock_session), patch.object(
+ session, "is_valid_jwt", return_value=True
+ ), patch("jwt.decode", return_value=mock_jwt_payload), patch.object(
+ session.jwks,
+ "get_signing_key_from_jwt",
+ return_value=Mock(key=TEST_CONSTANTS["PUBLIC_KEY"]),
):
response = session.authenticate()
@@ -273,9 +264,8 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)
- with (
- patch.object(session, "is_valid_jwt", return_value=True),
- patch(
+ with patch.object(session, "is_valid_jwt", return_value=True) as _:
+ with patch(
"jwt.decode",
return_value={
"sid": TEST_CONSTANTS["SESSION_ID"],
@@ -284,13 +274,12 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
"permissions": ["read"],
"entitlements": ["feature_1"],
},
- ),
- ):
- response = session.refresh()
+ ):
+ response = session.refresh()
- assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
- assert response.authenticated is True
- assert response.user.id == test_user["id"]
+ assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
+ assert response.authenticated is True
+ assert response.user.id == test_user["id"]
# Verify the refresh token was used correctly
mock_user_management.authenticate_with_refresh_token.assert_called_once_with(
@@ -310,8 +299,6 @@ def test_seal_data(TEST_CONSTANTS):
# Test unsealing
unsealed = Session.unseal_data(sealed, TEST_CONSTANTS["COOKIE_PASSWORD"])
- print(test_data)
- print(unsealed)
assert unsealed == test_data
From 6b6c66af41859b5d72107cb27b727dd9ba77fa48 Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Mon, 2 Dec 2024 15:35:06 +0100
Subject: [PATCH 13/14] Use sequence instead of list
---
workos/types/user_management/session.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py
index 7c81a0c3..76739f9d 100644
--- a/workos/types/user_management/session.py
+++ b/workos/types/user_management/session.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, TypedDict, Union
+from typing import Optional, Sequence, TypedDict, Union
from enum import Enum
from typing_extensions import Literal
from workos.types.user_management.impersonator import Impersonator
@@ -17,10 +17,10 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
session_id: str
organization_id: Optional[str] = None
role: Optional[str] = None
- permissions: Optional[List[str]] = None
+ permissions: Optional[Sequence[str]] = None
user: User
impersonator: Optional[Impersonator] = None
- entitlements: Optional[List[str]] = None
+ entitlements: Optional[Sequence[str]] = None
class AuthenticateWithSessionCookieErrorResponse(WorkOSModel):
From c9daffce40eacd3e79a3a72e8f2b3af76a83ad5a Mon Sep 17 00:00:00 2001
From: Paul Asjes
Date: Mon, 2 Dec 2024 15:47:13 +0100
Subject: [PATCH 14/14] Make is_valid_jwt private
---
tests/test_session.py | 4 ++--
workos/session.py | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/tests/test_session.py b/tests/test_session.py
index 0ff10baf..fbb82717 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -191,7 +191,7 @@ def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
}
with patch.object(Session, "unseal_data", return_value=mock_session), patch.object(
- session, "is_valid_jwt", return_value=True
+ session, "_is_valid_jwt", return_value=True
), patch("jwt.decode", return_value=mock_jwt_payload), patch.object(
session.jwks,
"get_signing_key_from_jwt",
@@ -264,7 +264,7 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)
- with patch.object(session, "is_valid_jwt", return_value=True) as _:
+ with patch.object(session, "_is_valid_jwt", return_value=True) as _:
with patch(
"jwt.decode",
return_value={
diff --git a/workos/session.py b/workos/session.py
index df8970c1..fea062ca 100644
--- a/workos/session.py
+++ b/workos/session.py
@@ -71,7 +71,7 @@ def authenticate(
reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE,
)
- if not self.is_valid_jwt(session["access_token"]):
+ if not self._is_valid_jwt(session["access_token"]):
return AuthenticateWithSessionCookieErrorResponse(
authenticated=False,
reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT,
@@ -172,7 +172,7 @@ def get_logout_url(self) -> str:
)
return str(result)
- def is_valid_jwt(self, token: str) -> bool:
+ def _is_valid_jwt(self, token: str) -> bool:
try:
signing_key = self.jwks.get_signing_key_from_jwt(token)
jwt.decode(token, signing_key.key, algorithms=self.jwk_algorithms)