42 changes: 31 additions & 11 deletions api/environments/models.py
@@ -8,7 +8,7 @@
from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation
from django.core.cache import caches
from django.db import models
from django.db import connection, models, transaction
from django.db.models import Max, Prefetch, Q, QuerySet
from django.utils import timezone
from django_lifecycle import ( # type: ignore[import-untyped]
@@ -514,20 +514,40 @@ def trait_persistence_allowed(self, request: Request) -> bool:
def get_segments_from_cache(self) -> typing.List[Segment]:
"""
Get any segments that have been overridden in this environment.

Uses REPEATABLE READ isolation for PostgreSQL to ensure all prefetch
queries see the same database snapshot, avoiding race conditions
during concurrent segment updates.
"""
segments = environment_segments_cache.get(self.id)
if not segments:
segments = list(
Segment.live_objects.filter(
feature_segments__feature_states__environment=self
).prefetch_related(
"rules",
"rules__conditions",
"rules__rules",
"rules__rules__conditions",
"rules__rules__rules",
)
# Use REPEATABLE READ isolation to ensure prefetch queries
# see a consistent snapshot, preventing race conditions where
# rules are fetched but their conditions are deleted by a
# concurrent PATCH before they can be prefetched.
#
# Only attempt to set isolation level if we're not already in an
# atomic block (nested transactions can't change isolation level).
use_repeatable_read = (
connection.vendor == "postgresql" and not connection.in_atomic_block
)
with transaction.atomic():
if use_repeatable_read:
with connection.cursor() as cursor:
cursor.execute(
"SET TRANSACTION ISOLATION LEVEL REPEATABLE READ"
)
segments = list(
Segment.live_objects.filter(
feature_segments__feature_states__environment=self
).prefetch_related(
"rules",
"rules__conditions",
"rules__rules",
"rules__rules__conditions",
"rules__rules__rules",
)
)
environment_segments_cache.set(self.id, segments)
return segments # type: ignore[no-any-return]

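A note on the guard used in this diff and mirrored in api/projects/services.py below: PostgreSQL only honours SET TRANSACTION ISOLATION LEVEL when it runs before the first query of the transaction, and a nested Django atomic() block is implemented as a savepoint inside an already-running transaction, so the isolation level can no longer be changed there. That is why connection.in_atomic_block gates the statement, falling back to the default READ COMMITTED level instead of erroring. The same logic as a reusable sketch (the helper name is hypothetical, not part of this PR):

from contextlib import contextmanager
from typing import Iterator

from django.db import connection, transaction


@contextmanager
def snapshot_reads() -> Iterator[None]:
    # Hypothetical helper mirroring the inline pattern above: on PostgreSQL,
    # and only when no transaction is already open, upgrade the fresh
    # transaction to REPEATABLE READ so every query inside the block reads
    # one consistent snapshot. Anywhere else it degrades to a plain atomic
    # block at the default isolation level.
    use_repeatable_read = (
        connection.vendor == "postgresql" and not connection.in_atomic_block
    )
    with transaction.atomic():
        if use_repeatable_read:
            with connection.cursor() as cursor:
                # Must execute before any other statement in this transaction.
                cursor.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ")
        yield

With a helper like this, both call sites would reduce to wrapping their list(...) evaluation in "with snapshot_reads():", keeping the isolation subtlety in one place.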
47 changes: 35 additions & 12 deletions api/projects/services.py
@@ -3,30 +3,53 @@
from django.apps import apps
from django.conf import settings
from django.core.cache import caches
from django.db import connection, transaction

if typing.TYPE_CHECKING:
from django.db.models import QuerySet

from segments.models import Segment

project_segments_cache = caches[settings.PROJECT_SEGMENTS_CACHE_LOCATION]


def get_project_segments_from_cache(project_id: int) -> "QuerySet[Segment]":
def get_project_segments_from_cache(project_id: int) -> "list[Segment]":
"""
Get all segments for a project from cache or database.

Uses REPEATABLE READ isolation for PostgreSQL to ensure all prefetch
queries see the same database snapshot, avoiding race conditions
during concurrent segment updates.
"""
Segment = apps.get_model("segments", "Segment")

segments = project_segments_cache.get(project_id)
if not segments:
# This is optimised to account for rules nested one level deep (since we
# don't support anything above that from the UI at the moment). Anything
# past that will require additional queries / thought on how to optimise.
segments = Segment.live_objects.filter(project_id=project_id).prefetch_related(
"rules",
"rules__conditions",
"rules__rules",
"rules__rules__conditions",
"rules__rules__rules",
# Use REPEATABLE READ isolation to ensure prefetch queries
# see a consistent snapshot, preventing race conditions where
# rules are fetched but their conditions are deleted by a
# concurrent PATCH before they can be prefetched.
#
# Only attempt to set isolation level if we're not already in an
# atomic block (nested transactions can't change isolation level).
use_repeatable_read = (
connection.vendor == "postgresql" and not connection.in_atomic_block
)
with transaction.atomic():
if use_repeatable_read:
with connection.cursor() as cursor:
cursor.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ")
# This is optimised to account for rules nested one level deep
# (since we don't support anything above that from the UI at the
# moment). Anything past that will require additional queries /
# thought on how to optimise.
segments = list(
Segment.live_objects.filter(project_id=project_id).prefetch_related(
"rules",
"rules__conditions",
"rules__rules",
"rules__rules__conditions",
"rules__rules__rules",
)
)

project_segments_cache.set(
project_id, segments, timeout=settings.CACHE_PROJECT_SEGMENTS_SECONDS
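Note the signature change above from a lazy QuerySet[Segment] to a materialised list[Segment]: the list(...) call forces every query, including the prefetches, to execute inside the transaction, which is the only way the REPEATABLE READ snapshot can cover them. Callers that used queryset methods need the list equivalents, which is exactly the swap the unit-test diff at the bottom of this PR makes. A minimal caller sketch (the project id is a placeholder):

from projects.services import get_project_segments_from_cache

segments = get_project_segments_from_cache(project_id=1)  # placeholder id
count = len(segments)                      # previously segments.count()
first = segments[0] if segments else None  # previously segments.first()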
@@ -0,0 +1,262 @@
import json
import threading
import time
from typing import Any, cast

import pytest
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient

from api_keys.models import MasterAPIKey
from organisations.models import Organisation


# This integration test reproduces a race condition during segment PATCH updates.
# When nested segment rules are deleted before new rules are created, the segment
# can temporarily evaluate as "match all". That causes identities without the
# required traits to receive an enabled feature state from a segment override.
# The test continuously PATCHes the same ruleset while polling identity feature
# states to surface any incorrect, transient enables.
@pytest.mark.django_db(transaction=True)
def test_segment_patch_atomic__looped_repro__detects_mismatch( # type: ignore[no-untyped-def]
admin_client,
admin_user,
environment,
environment_api_key,
feature,
organisation,
project,
):
# Given: a segment that should only match identities with specific traits.
# The feature override is enabled for that segment, so non-matching
# identities must always see the feature disabled.
rules_payload = _build_rules_payload()
segment_id = _create_segment(
admin_client=admin_client,
project_id=project,
rules_payload=rules_payload,
)
feature_segment_id = _create_feature_segment(
admin_client=admin_client,
environment_id=environment,
feature_id=feature,
segment_id=segment_id,
)
_create_feature_segment_override(
admin_client=admin_client,
environment_id=environment,
feature_id=feature,
feature_segment_id=feature_segment_id,
)
identity_id = _create_identity(
admin_client=admin_client,
environment_api_key=environment_api_key,
identifier="disabled-identity",
)

# API endpoints under test: segment PATCH and identity feature state listing.
patch_url = reverse(
"api-v1:projects:project-segments-detail",
args=[project, segment_id],
)
identity_feature_states_url = reverse(
"api-v1:environments:identity-featurestates-all",
args=(environment_api_key, identity_id),
)

# Use a master API key for PATCH requests so that concurrent writes
# are authenticated independently of the admin session client.
organisation_obj = Organisation.objects.get(id=organisation)
master_key_data = cast(Any, MasterAPIKey.objects).create_key(
name="test_key",
organisation=organisation_obj,
is_admin=True,
)
_, master_key = master_key_data
patch_client = APIClient()
patch_client.credentials(HTTP_AUTHORIZATION="Api-Key " + master_key)

# Use an authenticated admin client for polling identity feature states.
poll_client = APIClient()
poll_client.force_authenticate(user=admin_user)

# Shared state used to coordinate the concurrent loops.
stop_event = threading.Event()
end_time = time.monotonic() + 10
patch_errors: list[str] = []
poll_errors: list[str] = []
mismatches: list[dict[str, Any]] = []

def patch_loop() -> None:
# Repeatedly PATCH the same ruleset to simulate real-world churn in
# segment updates. This is intended to hit the race window where the
# rules are temporarily empty.
while time.monotonic() < end_time and not stop_event.is_set():
response = patch_client.patch(
patch_url,
data=json.dumps(
{
"name": "Atomic Patch Segment",
"rules": rules_payload,
}
),
content_type="application/json",
)
if response.status_code != status.HTTP_200_OK:
patch_errors.append(
f"Unexpected patch response: {response.status_code}"
)
stop_event.set()
return

def poll_loop() -> None:
# Continuously fetch identity feature states while PATCH is running.
# Any enabled feature for the non-matching identity indicates the
# segment temporarily evaluated as true.
while time.monotonic() < end_time and not stop_event.is_set():
response = poll_client.get(identity_feature_states_url)
if response.status_code != status.HTTP_200_OK:
poll_errors.append(
f"Unexpected feature states response: {response.status_code}"
)
stop_event.set()
return

response_json = response.json()
feature_state = next(
(
feature_state
for feature_state in response_json
if feature_state["feature"]["id"] == feature
),
None,
)
if feature_state is None:
poll_errors.append("Feature state missing from response")
stop_event.set()
return

if feature_state["enabled"] is True:
mismatches.append(feature_state)
stop_event.set()
return

# When: execute concurrent PATCH and polling loops for up to 10 seconds.
patch_thread = threading.Thread(target=patch_loop)
patch_thread.start()
poll_loop()
stop_event.set()
patch_thread.join(timeout=2)

# Then: failures indicate either bad API responses or a reproduced mismatch.
assert not patch_thread.is_alive()
assert not patch_errors
assert not poll_errors
assert not mismatches


def _build_rules_payload() -> list[dict[str, Any]]:
conditions = [
{
"operator": "EQUAL",
"property": "flagEnabledId",
"value": f"enabled-{index}",
}
for index in range(10)
]
return [
{
"type": "ANY",
"conditions": conditions,
"rules": [],
}
]


def _create_segment(
admin_client: APIClient,
project_id: int,
rules_payload: list[dict[str, Any]],
) -> int:
create_segment_url = reverse(
"api-v1:projects:project-segments-list", args=[project_id]
)
response = admin_client.post(
create_segment_url,
data=json.dumps(
{
"name": "Atomic Patch Segment",
"project": project_id,
"rules": rules_payload,
}
),
content_type="application/json",
)
assert response.status_code == status.HTTP_201_CREATED
return int(response.json()["id"])


def _create_feature_segment(
admin_client: APIClient,
environment_id: int,
feature_id: int,
segment_id: int,
) -> int:
create_feature_segment_url = reverse("api-v1:features:feature-segment-list")
response = admin_client.post(
create_feature_segment_url,
data=json.dumps(
{
"feature": feature_id,
"segment": segment_id,
"environment": environment_id,
}
),
content_type="application/json",
)
assert response.status_code == status.HTTP_201_CREATED
return int(response.json()["id"])


def _create_feature_segment_override(
admin_client: APIClient,
environment_id: int,
feature_id: int,
feature_segment_id: int,
) -> None:
create_feature_state_url = reverse("api-v1:features:featurestates-list")
response = admin_client.post(
create_feature_state_url,
data=json.dumps(
{
"enabled": True,
"feature_state_value": {
"type": "unicode",
"string_value": "segment override",
},
"feature": feature_id,
"environment": environment_id,
"feature_segment": feature_segment_id,
}
),
content_type="application/json",
)
assert response.status_code == status.HTTP_201_CREATED


def _create_identity(
admin_client: APIClient,
environment_api_key: str,
identifier: str,
) -> int:
create_identity_url = reverse(
"api-v1:environments:environment-identities-list",
args=[environment_api_key],
)
response = admin_client.post(
create_identity_url,
data={"identifier": identifier},
)
assert response.status_code == status.HTTP_201_CREATED
return int(response.json()["id"])
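Two details make this repro effective. First, the "match all" hazard the test's leading comment describes is a vacuous-truth problem: if any layer of segment evaluation folds a rule's conditions with all()-style semantics, a condition list that is empty mid-update matches unconditionally. A plain-Python illustration of the hazard, not Flagsmith's actual evaluator:

conditions: list[bool] = []     # conditions deleted, replacements not yet written
assert all(conditions) is True  # an empty ALL-style rule is vacuously true

Second, the test opts into pytest.mark.django_db(transaction=True). Under the default marker, pytest-django wraps the whole test in a single transaction that is rolled back afterwards, so the PATCH thread's writes would never be committed and the polling thread could never observe them; real committed transactions are what let the two threads race at all.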
4 changes: 2 additions & 2 deletions api/tests/unit/projects/test_unit_projects_models.py
@@ -76,8 +76,8 @@ def test_get_segments_from_cache_set_to_empty_list(
# Since we're calling the live_objects manager in the method,
# only one copy of the segment should be returned, not the
# other versioned copy of the segment.
assert segments.count() == 1
assert segments.first() == segment
assert len(segments) == 1
assert segments[0] == segment

# And correct calls to cache are made
mock_project_segments_cache.get.assert_called_once_with(project.id)