Skip to content

Commit

Permalink
feat: add feature flag endpoint (#531)
Browse files Browse the repository at this point in the history
* add feature endpoint

* add tests

* cleanup

* add caching

* set ttl back

* add option to skip caching

* update shared

* more cleanup

* rename RolloutIdentifier

use Feature util instead

cleanup

don't use django redis cache abstraction

* remove unused import

* update shared

* remove unused imports

* more renaming whoops

* fix tests
  • Loading branch information
daniel-codecov authored May 1, 2024
1 parent 8601246 commit 03b5754
Show file tree
Hide file tree
Showing 8 changed files with 386 additions and 7 deletions.
23 changes: 23 additions & 0 deletions api/internal/feature/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from shared.django_apps.rollouts.models import FeatureFlag, RolloutUniverse

FEATURES_CACHE_REDIS_KEY = "features_endpoint_cache"


def get_flag_cache_redis_key(flag_name):
return FEATURES_CACHE_REDIS_KEY + ":" + flag_name


def get_identifier(feature_flag: FeatureFlag, identifier_data):
"""
Returns the appropriate identifier string based on the rollout identifier type.
"""
if feature_flag.rollout_universe == RolloutUniverse.OWNER_ID:
return identifier_data["user_id"]
elif feature_flag.rollout_universe == RolloutUniverse.REPO_ID:
return identifier_data["repo_id"]
elif feature_flag.rollout_universe == RolloutUniverse.EMAIL:
return identifier_data["email"]
elif feature_flag.rollout_universe == RolloutUniverse.ORG_ID:
return identifier_data["org_id"]
else:
raise ValueError("Unknown RolloutUniverse type")
15 changes: 15 additions & 0 deletions api/internal/feature/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from rest_framework import serializers


class FeatureIdentifierDataSerializer(serializers.Serializer):
email = serializers.CharField(max_length=200, allow_blank=True)
user_id = serializers.IntegerField()
repo_id = serializers.IntegerField()
org_id = serializers.IntegerField()


class FeatureRequestSerializer(serializers.Serializer):
feature_flags = serializers.ListField(
child=serializers.CharField(max_length=200), allow_empty=True
)
identifier_data = FeatureIdentifierDataSerializer()
104 changes: 104 additions & 0 deletions api/internal/feature/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import logging
import pickle

from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
from shared.django_apps.rollouts.models import FeatureFlag
from shared.rollouts import Feature

from api.internal.feature.helpers import get_flag_cache_redis_key, get_identifier
from services.redis_configuration import get_redis_connection
from utils.config import get_config

from .serializers import FeatureRequestSerializer

log = logging.getLogger(__name__)


class FeaturesView(APIView):
skip_feature_cache = get_config("setup", "skip_feature_cache", default=False)
timeout = 300

def __init__(self, *args, **kwargs):
self.redis = get_redis_connection()
super().__init__(*args, **kwargs)

def get_many_from_redis(self, keys):
ret = self.redis.mget(keys)
return {k: pickle.loads(v) for k, v in zip(keys, ret) if v is not None}

def set_many_to_redis(self, data):
pipeline = self.redis.pipeline()
pipeline.mset({k: pickle.dumps(v) for k, v in data.items()})

# Setting timeout for each key as redis does not support timeout
# with mset().
for key in data:
pipeline.expire(key, self.timeout)
pipeline.execute()

def post(self, request):
serializer = FeatureRequestSerializer(data=request.data)
if serializer.is_valid():
flag_evaluations = {}
identifier_data = serializer.validated_data["identifier_data"]
feature_flag_names = serializer.validated_data["feature_flags"]

feature_flag_cache_keys = [
get_flag_cache_redis_key(flag_name) for flag_name in feature_flag_names
]
cache_misses = []

if not self.skip_feature_cache:
# fetch flags from cache
cached_flags = self.get_many_from_redis(feature_flag_cache_keys)

for ind in range(len(feature_flag_cache_keys)):
cache_key = feature_flag_cache_keys[ind]
flag_name = feature_flag_names[ind]

# if flag is in cache, make the evaluation. Otherwise, we'll
# fetch the flag from DB later
if cache_key in cached_flags:
feature_flag = cached_flags[cache_key]
identifier = get_identifier(feature_flag, identifier_data)

flag_evaluations[flag_name] = Feature(
flag_name, feature_flag, list(feature_flag.variants.all())
).check_value_no_fetch(identifier=identifier)
else:
cache_misses.append(flag_name)
else:
cache_misses = feature_flag_names
log.warning(
"skip_feature_cache for Feature should only be turned on in development environments, and should not be used in production"
)

flags_to_add_to_cache = {}

# fetch flags not in cache
missed_feature_flags = FeatureFlag.objects.filter(
name__in=cache_misses
).prefetch_related(
"variants"
) # include the feature flag variants aswell

# evaluate the remaining flags
for feature_flag in missed_feature_flags:
identifier = get_identifier(feature_flag, identifier_data)

flag_evaluations[feature_flag.name] = Feature(
feature_flag.name, feature_flag, list(feature_flag.variants.all())
).check_value_no_fetch(identifier=identifier)
flags_to_add_to_cache[
get_flag_cache_redis_key(feature_flag.name)
] = feature_flag

# add the new flags to cache
if len(flags_to_add_to_cache) >= 1:
self.set_many_to_redis(flags_to_add_to_cache)

return Response(flag_evaluations, status=status.HTTP_200_OK)
else:
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
230 changes: 230 additions & 0 deletions api/internal/tests/test_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
import json

import pytest
from django.urls import reverse
from rest_framework.test import APITestCase
from shared.django_apps.rollouts.models import (
FeatureFlag,
FeatureFlagVariant,
RolloutUniverse,
)

from codecov_auth.tests.factories import OwnerFactory
from utils.test_utils import Client


class FeatureEndpointTests(APITestCase):
def setUp(self):
self.client = Client()
self.owner = OwnerFactory(plan="users-free", plan_user_count=5)
self.client.force_login_owner(self.owner)

def send_feature_request(self, data: dict):
return self.client.post(
reverse("features"), data=json.dumps(data), content_type="application/json"
)

def test_invalid_request_body(self):
data = {
"feature_flagsssss": ["fjdsioj"],
"identifier_dataa": {
"email": "dsfio",
"user_id": 1,
"org_id": 2,
"repo_id": 3,
},
}

res = self.send_feature_request(data)
self.assertEqual(res.status_code, 400)

def test_valid_request_body(self):
data = {
"feature_flags": [],
"identifier_data": {
"email": "daniel.yu@sentry.io",
"user_id": 0,
"org_id": 0,
"repo_id": 0,
},
}

res = self.send_feature_request(data)
self.assertEqual(res.status_code, 200)

def test_variant_assigned_true(self):
feature_a = FeatureFlag.objects.create(
name="feature_a", proportion=1.0, salt="random_salt"
)
FeatureFlagVariant.objects.create(
name="enabled",
feature_flag=feature_a,
proportion=1.0,
value=True,
)

feature_b = FeatureFlag.objects.create(
name="feature_b", proportion=1.0, salt="random_salt"
)
FeatureFlagVariant.objects.create(
name="enabled",
feature_flag=feature_b,
proportion=1.0,
value=True,
)

data = {
"feature_flags": ["feature_a", "feature_b"],
"identifier_data": {
"email": "d",
"user_id": 1,
"org_id": 1,
"repo_id": 1,
},
}

res = self.send_feature_request(data)

self.assertEqual(res.status_code, 200)
self.assertEqual(res.data["feature_a"], True)
self.assertEqual(res.data["feature_b"], True)

def test_variant_assigned_false(self):
feature_aaa = FeatureFlag.objects.create(
name="feature_aaa", proportion=1.0, salt="random_salt"
)
FeatureFlagVariant.objects.create(
name="disabled",
feature_flag=feature_aaa,
proportion=1.0,
value=False,
)

feature_bbb = FeatureFlag.objects.create(
name="feature_bbb", proportion=1.0, salt="random_salt"
)
FeatureFlagVariant.objects.create(
name="disabled",
feature_flag=feature_bbb,
proportion=1.0,
value=False,
)

data = {
"feature_flags": ["feature_aaa", "feature_bbb"],
"identifier_data": {
"email": "d",
"user_id": 1,
"org_id": 1,
"repo_id": 1,
},
}

res = self.send_feature_request(data)

self.assertEqual(res.status_code, 200)
self.assertEqual(res.data["feature_aaa"], False)
self.assertEqual(res.data["feature_bbb"], False)


@pytest.mark.django_db
@pytest.mark.parametrize(
"rollout_universe,o_emails,o_owner_ids,o_repo_ids,o_org_ids,o_values",
[
(
RolloutUniverse.EMAIL,
(["david@gmail.com"], ["daniel@gmail.com"]),
([], []),
([], []),
([], []),
(1, 2),
),
(
RolloutUniverse.OWNER_ID,
([], []),
(["1"], ["2"]),
([], []),
([], []),
(3, 4),
),
(
RolloutUniverse.REPO_ID,
([], []),
([], []),
(["21"], ["31"]),
([], []),
(5, 6),
),
(
RolloutUniverse.ORG_ID,
([], []),
([], []),
([], []),
(["11"], ["21"]),
(7, 8),
),
],
)
def test_overrides_by_email(
rollout_universe, o_emails, o_owner_ids, o_repo_ids, o_org_ids, o_values
):
overrides = FeatureFlag.objects.create(
name="overrides_" + str(rollout_universe),
proportion=1.0,
rollout_universe=rollout_universe,
)
FeatureFlagVariant.objects.create(
name="overrides_a",
feature_flag=overrides,
proportion=1 / 3,
value=o_values[0],
override_emails=o_emails[0],
override_owner_ids=o_owner_ids[0],
override_repo_ids=o_repo_ids[0],
override_org_ids=o_org_ids[0],
)
FeatureFlagVariant.objects.create(
name="overrides_b",
feature_flag=overrides,
proportion=1 / 3,
value=o_values[1],
override_emails=o_emails[1],
override_owner_ids=o_owner_ids[1],
override_repo_ids=o_repo_ids[1],
override_org_ids=o_org_ids[1],
)
FeatureFlagVariant.objects.create(
name="overrides_c",
feature_flag=overrides,
proportion=1 / 3,
value="dfjosijsdiofjdos",
)

data1 = {
"feature_flags": ["overrides_" + str(rollout_universe)],
"identifier_data": {
"email": o_emails[0][0] if o_emails[0] else "",
"user_id": o_owner_ids[0][0] if o_owner_ids[0] else 0,
"org_id": o_org_ids[0][0] if o_org_ids[0] else 0,
"repo_id": o_repo_ids[0][0] if o_repo_ids[0] else 0,
},
}
mock = FeatureEndpointTests()
mock.setUp()
res1 = mock.send_feature_request(data1)

data2 = {
"feature_flags": ["overrides_" + str(rollout_universe)],
"identifier_data": {
"email": o_emails[1][0] if o_emails[1] else "",
"user_id": o_owner_ids[1][0] if o_owner_ids[1] else 0,
"org_id": o_org_ids[1][0] if o_org_ids[1] else 0,
"repo_id": o_repo_ids[1][0] if o_repo_ids[1] else 0,
},
}
res2 = mock.send_feature_request(data2)

assert res1.status_code == 200
assert res1.data["overrides_" + str(rollout_universe)] == o_values[0]
assert res2.status_code == 200
assert res2.data["overrides_" + str(rollout_universe)] == o_values[1]
2 changes: 2 additions & 0 deletions api/internal/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from api.internal.compare.views import CompareViewSet
from api.internal.coverage.views import CoverageViewSet
from api.internal.enterprise_urls import urlpatterns as enterprise_urlpatterns
from api.internal.feature.views import FeaturesView
from api.internal.owner.views import (
AccountDetailsViewSet,
InvoiceViewSet,
Expand Down Expand Up @@ -76,4 +77,5 @@
"<str:service>/<str:owner_username>/repos/<str:repo_name>/",
include(compare_router.urls),
),
path("features", FeaturesView.as_view(), name="features"),
]
Loading

0 comments on commit 03b5754

Please sign in to comment.