diff --git a/api/environments/models.py b/api/environments/models.py index 2c3000ea2b01..225f68d21703 100644 --- a/api/environments/models.py +++ b/api/environments/models.py @@ -8,7 +8,7 @@ from django.conf import settings from django.contrib.contenttypes.fields import GenericRelation from django.core.cache import caches -from django.db import models +from django.db import connection, models, transaction from django.db.models import Max, Prefetch, Q, QuerySet from django.utils import timezone from django_lifecycle import ( # type: ignore[import-untyped] @@ -514,20 +514,40 @@ def trait_persistence_allowed(self, request: Request) -> bool: def get_segments_from_cache(self) -> typing.List[Segment]: """ Get any segments that have been overridden in this environment. + + Uses REPEATABLE READ isolation for PostgreSQL to ensure all prefetch + queries see the same database snapshot, avoiding race conditions + during concurrent segment updates. """ segments = environment_segments_cache.get(self.id) if not segments: - segments = list( - Segment.live_objects.filter( - feature_segments__feature_states__environment=self - ).prefetch_related( - "rules", - "rules__conditions", - "rules__rules", - "rules__rules__conditions", - "rules__rules__rules", - ) + # Use REPEATABLE READ isolation to ensure prefetch queries + # see a consistent snapshot, preventing race conditions where + # rules are fetched but their conditions are deleted by a + # concurrent PATCH before they can be prefetched. + # + # Only attempt to set isolation level if we're not already in an + # atomic block (nested transactions can't change isolation level). + use_repeatable_read = ( + connection.vendor == "postgresql" and not connection.in_atomic_block ) + with transaction.atomic(): + if use_repeatable_read: + with connection.cursor() as cursor: + cursor.execute( + "SET TRANSACTION ISOLATION LEVEL REPEATABLE READ" + ) + segments = list( + Segment.live_objects.filter( + feature_segments__feature_states__environment=self + ).prefetch_related( + "rules", + "rules__conditions", + "rules__rules", + "rules__rules__conditions", + "rules__rules__rules", + ) + ) environment_segments_cache.set(self.id, segments) return segments # type: ignore[no-any-return] diff --git a/api/projects/services.py b/api/projects/services.py index 6a5d752e1190..98a30816667f 100644 --- a/api/projects/services.py +++ b/api/projects/services.py @@ -3,30 +3,53 @@ from django.apps import apps from django.conf import settings from django.core.cache import caches +from django.db import connection, transaction if typing.TYPE_CHECKING: - from django.db.models import QuerySet - from segments.models import Segment project_segments_cache = caches[settings.PROJECT_SEGMENTS_CACHE_LOCATION] -def get_project_segments_from_cache(project_id: int) -> "QuerySet[Segment]": +def get_project_segments_from_cache(project_id: int) -> "list[Segment]": + """ + Get all segments for a project from cache or database. + + Uses REPEATABLE READ isolation for PostgreSQL to ensure all prefetch + queries see the same database snapshot, avoiding race conditions + during concurrent segment updates. + """ Segment = apps.get_model("segments", "Segment") segments = project_segments_cache.get(project_id) if not segments: - # This is optimised to account for rules nested one levels deep (since we - # don't support anything above that from the UI at the moment). Anything - # past that will require additional queries / thought on how to optimise. - segments = Segment.live_objects.filter(project_id=project_id).prefetch_related( - "rules", - "rules__conditions", - "rules__rules", - "rules__rules__conditions", - "rules__rules__rules", + # Use REPEATABLE READ isolation to ensure prefetch queries + # see a consistent snapshot, preventing race conditions where + # rules are fetched but their conditions are deleted by a + # concurrent PATCH before they can be prefetched. + # + # Only attempt to set isolation level if we're not already in an + # atomic block (nested transactions can't change isolation level). + use_repeatable_read = ( + connection.vendor == "postgresql" and not connection.in_atomic_block ) + with transaction.atomic(): + if use_repeatable_read: + with connection.cursor() as cursor: + cursor.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ") + # This is optimised to account for rules nested one levels deep + # (since we don't support anything above that from the UI at the + # moment). Anything past that will require additional queries / + # thought on how to optimise. + segments = list( + Segment.live_objects.filter(project_id=project_id).prefetch_related( + "rules", + "rules__conditions", + "rules__rules", + "rules__rules__conditions", + "rules__rules__rules", + ) + ) project_segments_cache.set( project_id, segments, timeout=settings.CACHE_PROJECT_SEGMENTS_SECONDS diff --git a/api/tests/integration/environments/identities/test_integration_segment_patch_atomic.py b/api/tests/integration/environments/identities/test_integration_segment_patch_atomic.py new file mode 100644 index 000000000000..793846553aa9 --- /dev/null +++ b/api/tests/integration/environments/identities/test_integration_segment_patch_atomic.py @@ -0,0 +1,262 @@ +import json +import threading +import time +from typing import Any, cast + +import pytest +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from api_keys.models import MasterAPIKey +from organisations.models import Organisation + + +# This integration test reproduces a race condition during segment PATCH updates. +# When nested segment rules are deleted before new rules are created, the segment +# can temporarily evaluate as "match all". That causes identities without the +# required traits to receive an enabled feature state from a segment override. +# The test continuously PATCHes the same ruleset while polling identity feature +# states to surface any incorrect, transient enables. +@pytest.mark.django_db(transaction=True) +def test_segment_patch_atomic__looped_repro__detects_mismatch( # type: ignore[no-untyped-def] + admin_client, + admin_user, + environment, + environment_api_key, + feature, + organisation, + project, +): + # Given: a segment that should only match identities with specific traits. + # The feature override is enabled for that segment, so non-matching + # identities must always see the feature disabled. + rules_payload = _build_rules_payload() + segment_id = _create_segment( + admin_client=admin_client, + project_id=project, + rules_payload=rules_payload, + ) + feature_segment_id = _create_feature_segment( + admin_client=admin_client, + environment_id=environment, + feature_id=feature, + segment_id=segment_id, + ) + _create_feature_segment_override( + admin_client=admin_client, + environment_id=environment, + feature_id=feature, + feature_segment_id=feature_segment_id, + ) + identity_id = _create_identity( + admin_client=admin_client, + environment_api_key=environment_api_key, + identifier="disabled-identity", + ) + + # API endpoints under test: segment PATCH and identity feature state listing. + patch_url = reverse( + "api-v1:projects:project-segments-detail", + args=[project, segment_id], + ) + identity_feature_states_url = reverse( + "api-v1:environments:identity-featurestates-all", + args=(environment_api_key, identity_id), + ) + + # Use a master API key for PATCH requests so that concurrent writes + # are authenticated independently of the admin session client. + organisation_obj = Organisation.objects.get(id=organisation) + master_key_data = cast(Any, MasterAPIKey.objects).create_key( + name="test_key", + organisation=organisation_obj, + is_admin=True, + ) + _, master_key = master_key_data + patch_client = APIClient() + patch_client.credentials(HTTP_AUTHORIZATION="Api-Key " + master_key) + + # Use an authenticated admin client for polling identity feature states. + poll_client = APIClient() + poll_client.force_authenticate(user=admin_user) + + # Shared state used to coordinate the concurrent loops. + stop_event = threading.Event() + end_time = time.monotonic() + 10 + patch_errors: list[str] = [] + poll_errors: list[str] = [] + mismatches: list[dict[str, Any]] = [] + + def patch_loop() -> None: + # Repeatedly PATCH the same ruleset to simulate real-world churn in + # segment updates. This is intended to hit the race window where the + # rules are temporarily empty. + while time.monotonic() < end_time and not stop_event.is_set(): + response = patch_client.patch( + patch_url, + data=json.dumps( + { + "name": "Atomic Patch Segment", + "rules": rules_payload, + } + ), + content_type="application/json", + ) + if response.status_code != status.HTTP_200_OK: + patch_errors.append( + f"Unexpected patch response: {response.status_code}" + ) + stop_event.set() + return + + def poll_loop() -> None: + # Continuously fetch identity feature states while PATCH is running. + # Any enabled feature for the non-matching identity indicates the + # segment temporarily evaluated as true. + while time.monotonic() < end_time and not stop_event.is_set(): + response = poll_client.get(identity_feature_states_url) + if response.status_code != status.HTTP_200_OK: + poll_errors.append( + f"Unexpected feature states response: {response.status_code}" + ) + stop_event.set() + return + + response_json = response.json() + feature_state = next( + ( + feature_state + for feature_state in response_json + if feature_state["feature"]["id"] == feature + ), + None, + ) + if feature_state is None: + poll_errors.append("Feature state missing from response") + stop_event.set() + return + + if feature_state["enabled"] is True: + mismatches.append(feature_state) + stop_event.set() + return + + # When: execute concurrent PATCH and polling loops for up to 10 seconds. + patch_thread = threading.Thread(target=patch_loop) + patch_thread.start() + poll_loop() + stop_event.set() + patch_thread.join(timeout=2) + + # Then: failures indicate either bad API responses or a reproduced mismatch. + assert not patch_thread.is_alive() + assert not patch_errors + assert not poll_errors + assert not mismatches + + +def _build_rules_payload() -> list[dict[str, Any]]: + conditions = [ + { + "operator": "EQUAL", + "property": "flagEnabledId", + "value": f"enabled-{index}", + } + for index in range(10) + ] + return [ + { + "type": "ANY", + "conditions": conditions, + "rules": [], + } + ] + + +def _create_segment( + admin_client: APIClient, + project_id: int, + rules_payload: list[dict[str, Any]], +) -> int: + create_segment_url = reverse( + "api-v1:projects:project-segments-list", args=[project_id] + ) + response = admin_client.post( + create_segment_url, + data=json.dumps( + { + "name": "Atomic Patch Segment", + "project": project_id, + "rules": rules_payload, + } + ), + content_type="application/json", + ) + assert response.status_code == status.HTTP_201_CREATED + return int(response.json()["id"]) + + +def _create_feature_segment( + admin_client: APIClient, + environment_id: int, + feature_id: int, + segment_id: int, +) -> int: + create_feature_segment_url = reverse("api-v1:features:feature-segment-list") + response = admin_client.post( + create_feature_segment_url, + data=json.dumps( + { + "feature": feature_id, + "segment": segment_id, + "environment": environment_id, + } + ), + content_type="application/json", + ) + assert response.status_code == status.HTTP_201_CREATED + return int(response.json()["id"]) + + +def _create_feature_segment_override( + admin_client: APIClient, + environment_id: int, + feature_id: int, + feature_segment_id: int, +) -> None: + create_feature_state_url = reverse("api-v1:features:featurestates-list") + response = admin_client.post( + create_feature_state_url, + data=json.dumps( + { + "enabled": True, + "feature_state_value": { + "type": "unicode", + "string_value": "segment override", + }, + "feature": feature_id, + "environment": environment_id, + "feature_segment": feature_segment_id, + } + ), + content_type="application/json", + ) + assert response.status_code == status.HTTP_201_CREATED + + +def _create_identity( + admin_client: APIClient, + environment_api_key: str, + identifier: str, +) -> int: + create_identity_url = reverse( + "api-v1:environments:environment-identities-list", + args=[environment_api_key], + ) + response = admin_client.post( + create_identity_url, + data={"identifier": identifier}, + ) + assert response.status_code == status.HTTP_201_CREATED + return int(response.json()["id"]) diff --git a/api/tests/unit/projects/test_unit_projects_models.py b/api/tests/unit/projects/test_unit_projects_models.py index 5d50fcef9370..51db81265118 100644 --- a/api/tests/unit/projects/test_unit_projects_models.py +++ b/api/tests/unit/projects/test_unit_projects_models.py @@ -76,8 +76,8 @@ def test_get_segments_from_cache_set_to_empty_list( # Since we're calling the live_objects manager in the method, # only one copy of the segment should be returned, not the # other versioned copy of the segment. - assert segments.count() == 1 - assert segments.first() == segment + assert len(segments) == 1 + assert segments[0] == segment # And correct calls to cache are made mock_project_segments_cache.get.assert_called_once_with(project.id)