diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 37c4d43241..6b6be33894 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,11 +19,11 @@ permissions: jobs: lint: name: Run Lint - uses: codecov/gha-workflows/.github/workflows/lint.yml@v1.2.21 + uses: codecov/gha-workflows/.github/workflows/lint.yml@v1.2.23 build: name: Build API - uses: codecov/gha-workflows/.github/workflows/build-app.yml@v1.2.21 + uses: codecov/gha-workflows/.github/workflows/build-app.yml@v1.2.23 secrets: inherit with: repo: ${{ vars.CODECOV_IMAGE_V2 || 'codecov/self-hosted-api' }} @@ -31,7 +31,7 @@ jobs: codecovstartup: name: Codecov Startup needs: build - uses: codecov/gha-workflows/.github/workflows/codecov-startup.yml@v1.2.21 + uses: codecov/gha-workflows/.github/workflows/codecov-startup.yml@v1.2.23 secrets: inherit # ats: @@ -47,7 +47,7 @@ jobs: test: name: Test needs: [build] - uses: codecov/gha-workflows/.github/workflows/run-tests.yml@v1.2.21 + uses: codecov/gha-workflows/.github/workflows/run-tests.yml@v1.2.23 secrets: inherit with: repo: ${{ vars.CODECOV_IMAGE_V2 || 'codecov/self-hosted-api' }} @@ -55,7 +55,7 @@ jobs: build-self-hosted: name: Build Self Hosted API needs: [build, test] - uses: codecov/gha-workflows/.github/workflows/self-hosted.yml@v1.2.21 + uses: codecov/gha-workflows/.github/workflows/self-hosted.yml@v1.2.23 secrets: inherit with: repo: ${{ vars.CODECOV_IMAGE_V2 || 'codecov/self-hosted-api' }} @@ -64,7 +64,7 @@ jobs: name: Push Staging Image needs: [build, test] if: ${{ github.event_name == 'push' && github.event.ref == 'refs/heads/staging' && github.repository_owner == 'codecov' }} - uses: codecov/gha-workflows/.github/workflows/push-env.yml@v1.2.21 + uses: codecov/gha-workflows/.github/workflows/push-env.yml@v1.2.23 secrets: inherit with: environment: staging @@ -74,7 +74,7 @@ jobs: name: Push Production Image needs: [build, test] if: ${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 
github.repository_owner == 'codecov' }} - uses: codecov/gha-workflows/.github/workflows/push-env.yml@v1.2.21 + uses: codecov/gha-workflows/.github/workflows/push-env.yml@v1.2.23 secrets: inherit with: environment: production @@ -85,7 +85,7 @@ jobs: needs: [build-self-hosted, test] secrets: inherit if: ${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && github.repository_owner == 'codecov' }} - uses: codecov/gha-workflows/.github/workflows/self-hosted.yml@v1.2.21 + uses: codecov/gha-workflows/.github/workflows/self-hosted.yml@v1.2.23 with: push_rolling: true repo: ${{ vars.CODECOV_IMAGE_V2 || 'codecov/self-hosted-api' }} diff --git a/.github/workflows/pr_detect_shared_changes.yml b/.github/workflows/pr_detect_shared_changes.yml new file mode 100644 index 0000000000..15d4e96d9f --- /dev/null +++ b/.github/workflows/pr_detect_shared_changes.yml @@ -0,0 +1,14 @@ +name: Detect dep version changes + +on: + pull_request: + +permissions: + pull-requests: "write" + +jobs: + shared-change-checker: + name: See if shared changed + uses: codecov/gha-workflows/.github/workflows/diff-dep.yml@main + with: + dep: 'shared' diff --git a/VERSION b/VERSION index cbbcf2d0e7..2405744b2b 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -24.6.1 \ No newline at end of file +24.9.1 \ No newline at end of file diff --git a/api/internal/branch/views.py b/api/internal/branch/views.py index e28fb38761..f3a984d009 100644 --- a/api/internal/branch/views.py +++ b/api/internal/branch/views.py @@ -1,11 +1,8 @@ -from django.db.models import F, OuterRef, Subquery -from django_filters.rest_framework import DjangoFilterBackend -from rest_framework import filters, mixins, viewsets +from django.db.models import OuterRef, Subquery +from rest_framework import mixins from api.shared.branch.mixins import BranchViewSetMixin -from api.shared.mixins import RepoPropertyMixin -from api.shared.permissions import RepositoryArtifactPermissions -from core.models import Branch, Commit +from 
core.models import Commit from .serializers import BranchSerializer diff --git a/api/internal/chart/helpers.py b/api/internal/chart/helpers.py index 11e6d553b6..76de872e53 100644 --- a/api/internal/chart/helpers.py +++ b/api/internal/chart/helpers.py @@ -3,7 +3,7 @@ from cerberus import Validator from dateutil import parser from django.db import connection -from django.db.models import Case, F, FloatField, Value, When +from django.db.models import Case, FloatField, Value, When from django.db.models.fields.json import KeyTextTransform from django.db.models.functions import Cast, Trunc from django.utils import timezone @@ -11,7 +11,7 @@ from rest_framework.exceptions import ValidationError from codecov_auth.models import Owner -from core.models import Commit, Repository +from core.models import Repository class ChartParamValidator(Validator): diff --git a/api/internal/chart/urls.py b/api/internal/chart/urls.py index 0bbf753d89..3091df9f29 100644 --- a/api/internal/chart/urls.py +++ b/api/internal/chart/urls.py @@ -1,4 +1,4 @@ -from django.urls import path, re_path +from django.urls import re_path from .views import OrganizationChartHandler, RepositoryChartHandler diff --git a/api/internal/chart/views.py b/api/internal/chart/views.py index fdbac6222c..152a9e3765 100644 --- a/api/internal/chart/views.py +++ b/api/internal/chart/views.py @@ -6,6 +6,7 @@ from api.shared.mixins import RepositoriesMixin from api.shared.permissions import ChartPermissions from core.models import Commit +from utils import round_decimals_down from .filters import apply_default_filters, apply_simple_filters from .helpers import ( @@ -102,7 +103,9 @@ def post(self, request, *args, **kwargs): complexity = [ { "date": commit.timestamp, - "complexity_ratio": round(commit.complexity_ratio * 100, 2), + "complexity_ratio": round_decimals_down( + commit.complexity_ratio * 100, 2 + ), "commitid": commit.commitid, } for commit in commits @@ -136,7 +139,9 @@ def post(self, request, *args, **kwargs): 
complexity = [ { "date": commit.truncated_date, - "complexity_ratio": round(commit.complexity_ratio * 100, 2), + "complexity_ratio": round_decimals_down( + commit.complexity_ratio * 100, 2 + ), "commitid": commit.commitid, } for commit in complexity_grouped_queryset diff --git a/api/internal/enterprise_urls.py b/api/internal/enterprise_urls.py index 1fc7d509fa..cd8448b546 100644 --- a/api/internal/enterprise_urls.py +++ b/api/internal/enterprise_urls.py @@ -1,16 +1,12 @@ from django.urls import include, path -from api.internal.self_hosted.views import SettingsViewSet, UserViewSet -from utils.routers import OptionalTrailingSlashRouter, RetrieveUpdateDestroyRouter +from api.internal.self_hosted.views import UserViewSet +from utils.routers import OptionalTrailingSlashRouter self_hosted_router = OptionalTrailingSlashRouter() self_hosted_router.register(r"users", UserViewSet, basename="selfhosted-users") -settings_router = RetrieveUpdateDestroyRouter() -settings_router.register(r"settings", SettingsViewSet, basename="selfhosted-settings") - urlpatterns = [ path("license/", include("api.internal.license.urls")), - path("", include(settings_router.urls)), path("", include(self_hosted_router.urls)), ] diff --git a/api/internal/owner/serializers.py b/api/internal/owner/serializers.py index b29ceebc48..b6ceccb735 100644 --- a/api/internal/owner/serializers.py +++ b/api/internal/owner/serializers.py @@ -143,7 +143,11 @@ def validate_value(self, value): return value def validate(self, plan): - owner = self.context["view"].owner + current_org = self.context["view"].owner + if current_org.account: + raise serializers.ValidationError( + detail="You cannot update your plan manually, for help or changes to plan, connect with sales@codecov.io" + ) # Validate quantity here because we need access to whole plan object if plan["value"] in PAID_PLANS: @@ -156,16 +160,19 @@ def validate(self, plan): "Quantity for paid plan must be greater than 1" ) - plan_service = 
PlanService(current_org=owner) + plan_service = PlanService(current_org=current_org) is_org_trialing = plan_service.is_org_trialing - if plan["quantity"] < owner.activated_user_count and not is_org_trialing: + if ( + plan["quantity"] < current_org.activated_user_count + and not is_org_trialing + ): raise serializers.ValidationError( "Quantity cannot be lower than currently activated user count" ) if ( - plan["quantity"] == owner.plan_user_count - and plan["value"] == owner.plan + plan["quantity"] == current_org.plan_user_count + and plan["value"] == current_org.plan and not is_org_trialing ): raise serializers.ValidationError( @@ -190,6 +197,9 @@ class SubscriptionDetailSerializer(serializers.Serializer): current_period_end = serializers.IntegerField() customer = StripeCustomerSerializer() collection_method = serializers.CharField() + tax_ids = serializers.ListField( + source="customer.tax_ids.data", read_only=True, allow_null=True + ) trial_end = serializers.IntegerField() @@ -249,6 +259,10 @@ class AccountDetailsSerializer(serializers.ModelSerializer): root_organization = RootOrganizationSerializer() schedule_detail = serializers.SerializerMethodField() apply_cancellation_discount = serializers.BooleanField(write_only=True) + activated_student_count = serializers.SerializerMethodField() + activated_user_count = serializers.SerializerMethodField() + delinquent = serializers.SerializerMethodField() + uses_invoice = serializers.SerializerMethodField() class Meta: model = Owner @@ -258,21 +272,22 @@ class Meta: fields = read_only_fields + ( "activated_student_count", "activated_user_count", + "apply_cancellation_discount", "checkout_session_id", + "delinquent", "email", "inactive_user_count", "name", "nb_active_private_repos", - "plan", "plan_auto_activate", "plan_provider", - "uses_invoice", + "plan", "repo_total_credits", "root_organization", "schedule_detail", "student_count", "subscription_detail", - "apply_cancellation_discount", + "uses_invoice", ) def 
_get_billing(self): @@ -292,6 +307,26 @@ def get_schedule_detail(self, owner): def get_checkout_session_id(self, _): return self.context.get("checkout_session_id") + def get_activated_student_count(self, owner): + if owner.account: + return owner.account.activated_student_count + return owner.activated_student_count + + def get_activated_user_count(self, owner): + if owner.account: + return owner.account.activated_user_count + return owner.activated_user_count + + def get_delinquent(self, owner): + if owner.account: + return owner.account.is_delinquent + return owner.delinquent + + def get_uses_invoice(self, owner): + if owner.account: + return owner.account.invoice_billing.filter(is_active=True).exists() + return owner.uses_invoice + def update(self, instance, validated_data): if "pretty_plan" in validated_data: desired_plan = validated_data.pop("pretty_plan") diff --git a/api/internal/owner/views.py b/api/internal/owner/views.py index 76a4653321..96fe636844 100644 --- a/api/internal/owner/views.py +++ b/api/internal/owner/views.py @@ -1,12 +1,12 @@ import logging -from dataclasses import asdict from django.db.models import F from django_filters import rest_framework as django_filters from rest_framework import filters, mixins, status, viewsets from rest_framework.decorators import action -from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError +from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.response import Response +from shared.django_apps.codecov_auth.models import Owner from api.shared.mixins import OwnerPropertyMixin from api.shared.owner.mixins import OwnerViewSetMixin, UserViewSetMixin @@ -19,7 +19,6 @@ from .serializers import ( AccountDetailsSerializer, OwnerSerializer, - StripeInvoiceSerializer, UserSerializer, ) @@ -30,31 +29,6 @@ class OwnerViewSet(OwnerViewSetMixin, mixins.RetrieveModelMixin): serializer_class = OwnerSerializer -class InvoiceViewSet( - viewsets.GenericViewSet, - 
mixins.ListModelMixin, - mixins.RetrieveModelMixin, - OwnerPropertyMixin, -): - serializer_class = StripeInvoiceSerializer - permission_classes = [MemberOfOrgPermissions] - pagination_class = None - - def get_queryset(self): - return BillingService( - requesting_user=self.request.current_owner - ).list_filtered_invoices(self.owner, 100) - - def get_object(self): - invoice_id = self.kwargs.get("pk") - invoice = BillingService( - requesting_user=self.request.current_owner - ).get_invoice(self.owner, invoice_id) - if not invoice: - raise NotFound(f"Invoice {invoice_id} does not exist for that account") - return invoice - - class AccountDetailsViewSet( viewsets.GenericViewSet, mixins.UpdateModelMixin, @@ -82,6 +56,14 @@ def destroy(self, request, *args, **kwargs): return Response(status=status.HTTP_204_NO_CONTENT) def get_object(self): + if self.owner.account: + # gets the related account and invoice_billing objects from db in 1 query + # otherwise, each reference to owner.account would be an additional query + self.owner = ( + Owner.objects.filter(pk=self.owner.ownerid) + .select_related("account__invoice_billing") + .first() + ) return self.owner @action(detail=False, methods=["patch"]) @@ -109,12 +91,25 @@ def update_email(self, request, *args, **kwargs): @action(detail=False, methods=["patch"]) @stripe_safe def update_billing_address(self, request, *args, **kwargs): + name = request.data.get("name") + if not name: + raise ValidationError(detail="No name sent") billing_address = request.data.get("billing_address") if not billing_address: raise ValidationError(detail="No billing_address sent") owner = self.get_object() + + formatted_address = { + "line1": billing_address["line_1"], + "line2": billing_address["line_2"], + "city": billing_address["city"], + "state": billing_address["state"], + "postal_code": billing_address["postal_code"], + "country": billing_address["country"], + } + billing = BillingService(requesting_user=request.current_owner) - 
billing.update_billing_address(owner, billing_address) + billing.update_billing_address(owner, name, billing_address=formatted_address) return Response(self.get_serializer(owner).data) diff --git a/api/internal/repo/repository_actions.py b/api/internal/repo/repository_actions.py deleted file mode 100644 index fc51bd8ac1..0000000000 --- a/api/internal/repo/repository_actions.py +++ /dev/null @@ -1,64 +0,0 @@ -import asyncio -import logging - -from asgiref.sync import async_to_sync -from django.conf import settings - -from utils.config import get_config -from webhook_handlers.constants import ( - BitbucketWebhookEvents, - GitHubWebhookEvents, - GitLabWebhookEvents, -) - -log = logging.getLogger(__name__) - - -WEBHOOK_EVENTS = { - "github": GitHubWebhookEvents.repository_events, - "github_enterprise": [ - "pull_request", - "delete", - "push", - "public", - "status", - "repository", - ], - "bitbucket": BitbucketWebhookEvents.subscribed_events, - # https://confluence.atlassian.com/bitbucketserver/post-service-webhook-for-bitbucket-server-776640367.html - "bitbucket_server": [], - "gitlab": GitLabWebhookEvents.subscribed_events, - "gitlab_enterprise": GitLabWebhookEvents.subscribed_events, -} - - -@async_to_sync -async def delete_webhook_on_provider(repository_service, repo): - """ - Deletes webhook on provider - """ - return await repository_service.delete_webhook(hookid=repo.hookid) - - -def create_webhook_on_provider(repository_service, repo): - """ - Creates webhook on provider - """ - - webhook_url = settings.WEBHOOK_URL - - log.info( - "Resetting webhook with webhook url: %s" - % f"{webhook_url}/webhooks/{repository_service.service}" - ) - - return async_to_sync(repository_service.post_webhook)( - f"Codecov Webhook. 
{webhook_url}", - f"{webhook_url}/webhooks/{repository_service.service}", - WEBHOOK_EVENTS[repository_service.service], - get_config( - repository_service.service, - "webhook_secret", - default="testixik8qdauiab1yiffydimvi72ekq", - ), - ).get("id") diff --git a/api/internal/repo/utils.py b/api/internal/repo/utils.py deleted file mode 100644 index a3835c5628..0000000000 --- a/api/internal/repo/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -from django.conf import settings -from shared.encryption.yaml_secret import yaml_secret_encryptor - - -def encode_secret_string(value): - ## Reminder -- this should probably be rewritten to reuse the same code - ## as in the new worker, whenever the API starts using the new worker. - encryptor = yaml_secret_encryptor - return "secret:%s" % encryptor.encode(value).decode() diff --git a/api/internal/repo/views.py b/api/internal/repo/views.py index 90178b1ca5..ff39f9ff62 100644 --- a/api/internal/repo/views.py +++ b/api/internal/repo/views.py @@ -1,28 +1,18 @@ import logging -import uuid from django.utils import timezone from django_filters import rest_framework as django_filters -from rest_framework import filters, mixins, status, viewsets -from rest_framework.decorators import action +from rest_framework import filters, mixins from rest_framework.exceptions import PermissionDenied -from rest_framework.response import Response from api.internal.repo.filter import RepositoryOrderingFilter from api.shared.repo.filter import RepositoryFilters from api.shared.repo.mixins import RepositoryViewSetMixin -from services.decorators import torngit_safe -from services.repo_providers import RepoProviderService -from services.task import TaskService -from .repository_actions import create_webhook_on_provider, delete_webhook_on_provider from .serializers import ( RepoDetailsSerializer, - RepoSerializer, RepoWithMetricsSerializer, - SecretStringPayloadSerializer, ) -from .utils import encode_secret_string log = logging.getLogger(__name__) @@ -89,59 +79,3 
@@ def perform_update(self, serializer): if owner.has_legacy_plan and owner.repo_credits <= 0: raise PermissionDenied("Private repository limit reached.") return super().perform_update(serializer) - - @action(detail=True, methods=["patch"], url_path="regenerate-upload-token") - def regenerate_upload_token(self, request, *args, **kwargs): - repo = self.get_object() - repo.upload_token = uuid.uuid4() - repo.save() - return Response(self.get_serializer(repo).data) - - @action(detail=True, methods=["patch"]) - def erase(self, request, *args, **kwargs): - self._assert_is_admin() - repo = self.get_object() - TaskService().delete_timeseries(repository_id=repo.repoid) - TaskService().flush_repo(repository_id=repo.repoid) - return Response(RepoSerializer(repo).data) - - @action(detail=True, methods=["post"]) - def encode(self, request, *args, **kwargs): - serializer = SecretStringPayloadSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - - owner, repo = self.owner, self.get_object() - - to_encode = "/".join( - ( - owner.service, - owner.service_id, - repo.service_id, - serializer.validated_data["value"], - ) - ) - - return Response( - SecretStringPayloadSerializer( - {"value": encode_secret_string(to_encode)} - ).data, - status=status.HTTP_201_CREATED, - ) - - @action(detail=True, methods=["put"], url_path="reset-webhook") - @torngit_safe - def reset_webhook(self, request, *args, **kwargs): - repo = self.get_object() - repository_service = RepoProviderService().get_adapter( - self.request.current_owner, repo - ) - - if repo.hookid: - delete_webhook_on_provider(repository_service, repo) - repo.hookid = None - repo.save() - - repo.hookid = create_webhook_on_provider(repository_service, repo) - repo.save() - - return Response(self.get_serializer(repo).data, status=status.HTTP_200_OK) diff --git a/api/internal/self_hosted/serializers.py b/api/internal/self_hosted/serializers.py index 24f5ce178e..384a02315c 100644 --- 
a/api/internal/self_hosted/serializers.py +++ b/api/internal/self_hosted/serializers.py @@ -32,20 +32,3 @@ def update(self, instance, validated_data): # re-query for object to get updated `activated` value return self.context["view"].get_queryset().filter(pk=instance.pk).first() - - -class SettingsSerializer(serializers.Serializer): - # this name is used to be consistent with org-level auto activation - plan_auto_activate = serializers.BooleanField() - - seats_used = serializers.IntegerField() - seats_limit = serializers.IntegerField() - - def update(self, instance, validated_data): - if "plan_auto_activate" in validated_data: - if validated_data["plan_auto_activate"] is True: - self_hosted.enable_autoactivation() - else: - self_hosted.disable_autoactivation() - - return self.context["view"]._get_settings() diff --git a/api/internal/self_hosted/views.py b/api/internal/self_hosted/views.py index 29bccd38a3..246030bcbb 100644 --- a/api/internal/self_hosted/views.py +++ b/api/internal/self_hosted/views.py @@ -11,7 +11,7 @@ from .filters import UserFilters from .permissions import AdminPermissions -from .serializers import SettingsSerializer, UserSerializer +from .serializers import UserSerializer class UserViewSet( @@ -63,31 +63,3 @@ def current_update(self, request): serializer.is_valid(raise_exception=True) serializer.save() return Response(serializer.data) - - -class SettingsViewSet( - viewsets.GenericViewSet, - mixins.RetrieveModelMixin, - mixins.UpdateModelMixin, -): - serializer_class = SettingsSerializer - permission_classes = [AdminPermissions] - - def retrieve(self, request, *args, **kwargs): - serializer = self.get_serializer(self._get_settings()) - return Response(serializer.data) - - def update(self, request, *args, **kwargs): - serializer = self.get_serializer( - self._get_settings(), data=request.data, partial=True - ) - serializer.is_valid(raise_exception=True) - serializer.save() - return Response(serializer.data) - - def _get_settings(self): - return 
{ - "plan_auto_activate": self_hosted.is_autoactivation_enabled(), - "seats_used": self_hosted.activated_owners().count(), - "seats_limit": self_hosted.license_seats(), - } diff --git a/api/internal/tests/test_charts.py b/api/internal/tests/test_charts.py index f17ec27341..56f42b426d 100644 --- a/api/internal/tests/test_charts.py +++ b/api/internal/tests/test_charts.py @@ -1,4 +1,4 @@ -from datetime import date, datetime, time, timedelta +from datetime import datetime, timedelta from decimal import Decimal from math import isclose from random import randint @@ -751,7 +751,7 @@ def test_first_complete_commit_date_returns_date_of_first_complete_commit_in_rep ) self.user.permission = [repo1.repoid, repo2.repoid] self.user.save() - older_incomplete_commit = G( + G( model=Commit, repository=repo1, branch=repo1.branch, @@ -765,9 +765,7 @@ def test_first_complete_commit_date_returns_date_of_first_complete_commit_in_rep state="complete", timestamp=timezone.now() - timedelta(days=3), ) - commit2 = G( - model=Commit, repository=repo2, branch=repo2.branch, state="complete" - ) + G(model=Commit, repository=repo2, branch=repo2.branch, state="complete") qr = ChartQueryRunner( self.user, diff --git a/api/internal/tests/test_pagination.py b/api/internal/tests/test_pagination.py index b72cd96150..b774eda77d 100644 --- a/api/internal/tests/test_pagination.py +++ b/api/internal/tests/test_pagination.py @@ -1,4 +1,3 @@ -import pytest from rest_framework.reverse import reverse from rest_framework.test import APITestCase diff --git a/api/internal/tests/test_permissions.py b/api/internal/tests/test_permissions.py index 99bc29aa09..b06e28e30a 100644 --- a/api/internal/tests/test_permissions.py +++ b/api/internal/tests/test_permissions.py @@ -2,7 +2,6 @@ from django.test import TestCase, override_settings from rest_framework.exceptions import APIException -from rest_framework.test import APIRequestFactory from api.internal.tests.test_utils import ( GetAdminErrorProviderAdapter, @@ -193,5 
+192,5 @@ def test_is_admin_on_provider_handles_torngit_exception(self, mock_get_provider) org = OwnerFactory() user = OwnerFactory() - with self.assertRaises(APIException) as e: + with self.assertRaises(APIException): self.permissions_class._is_admin_on_provider(user, org) diff --git a/api/internal/tests/test_views.py b/api/internal/tests/test_views.py index 0f80053589..94f8b7ac4d 100644 --- a/api/internal/tests/test_views.py +++ b/api/internal/tests/test_views.py @@ -193,7 +193,7 @@ def test_get_pulls_no_head_commit_returns_null_for_head_totals(self, mock_provid assert response.status_code == status.HTTP_200_OK assert [p for p in response.data["results"] if p["pullid"] == 13][0][ "head_totals" - ] == None + ] is None def test_get_pulls_no_base_commit_returns_null_for_base_totals(self, mock_provider): mock_provider.return_value = True, True @@ -213,7 +213,7 @@ def test_get_pulls_no_base_commit_returns_null_for_base_totals(self, mock_provid assert response.status_code == status.HTTP_200_OK assert [p for p in response.data["results"] if p["pullid"] == 13][0][ "base_totals" - ] == None + ] is None def test_get_pulls_as_inactive_user_returns_403(self, mock_provider): self.org.plan = "users-inappm" @@ -336,9 +336,7 @@ def setUp(self): other_org = OwnerFactory(username="other_org") # Create different types of repos / pulls repo = RepositoryFactory(author=self.org, name="testRepoName", active=True) - other_repo = RepositoryFactory( - author=other_org, name="otherRepoName", active=True - ) + RepositoryFactory(author=other_org, name="otherRepoName", active=True) repo_with_permission = [repo.repoid] self.current_owner = OwnerFactory( username="codecov-user", @@ -389,7 +387,6 @@ def test_get_pull_no_permissions(self, mock_provider): self.assertEqual(response.status_code, 404) def test_get_pull_as_inactive_user_returns_403(self, mock_provider): - mock_provider = True, True self.org.plan = "users-inappm" self.org.plan_auto_activate = False self.org.save() @@ -533,7 +530,7 @@ 
def test_get_commits(self, mock_provider): "complexity": 3.0, "complexity_total": 5.0, "complexity_ratio": 60.0, - "coverage": 79.17, + "coverage": 79.16, "diff": 0, "files": 3, "hits": 19, @@ -569,7 +566,7 @@ def test_get_commits(self, mock_provider): "complexity": 2.0, "complexity_total": 5.0, "complexity_ratio": 40.0, - "coverage": 79.17, + "coverage": 79.16, "diff": 0, "files": 3, "hits": 19, @@ -631,7 +628,6 @@ def test_fetch_commits_no_permissions(self, mock_provider): assert response.status_code == 404 def test_fetch_commits_inactive_user_returns_403(self, mock_provider): - mock_provider = True, True self.org.plan = "users-inappm" self.org.plan_auto_activate = False self.org.save() diff --git a/api/internal/tests/unit/views/test_compare_flags_view.py b/api/internal/tests/unit/views/test_compare_flags_view.py index b8a3867e99..6573d846da 100644 --- a/api/internal/tests/unit/views/test_compare_flags_view.py +++ b/api/internal/tests/unit/views/test_compare_flags_view.py @@ -92,7 +92,7 @@ def test_compare_flags___success( "complexity": 0, "complexity_total": 0, "complexity_ratio": 0, - "coverage": 79.17, + "coverage": 79.16, "diff": 0, "files": 3, "hits": 19, @@ -143,7 +143,7 @@ def test_compare_flags___success( "complexity": 0, "complexity_total": 0, "complexity_ratio": 0, - "coverage": 79.17, + "coverage": 79.16, "diff": 0, "files": 3, "hits": 19, @@ -283,7 +283,7 @@ def test_compare_flags_with_report_with_cff_and_non_cff( "complexity": 0, "complexity_total": 0, "complexity_ratio": 0, - "coverage": 79.17, + "coverage": 79.16, "diff": 0, "files": 3, "hits": 19, @@ -380,7 +380,7 @@ def test_compare_flags_doesnt_crash_if_base_doesnt_have_flags( diff_totals_mock.return_value = ReportTotals() # should not crash - response = self._get_compare_flags( + self._get_compare_flags( kwargs={ "service": self.repo.author.service, "owner_username": self.repo.author.username, diff --git a/api/internal/tests/unit/views/test_compare_view.py 
b/api/internal/tests/unit/views/test_compare_view.py index f5fe027551..74f13f4418 100644 --- a/api/internal/tests/unit/views/test_compare_view.py +++ b/api/internal/tests/unit/views/test_compare_view.py @@ -46,7 +46,6 @@ def build_commits(client): return repo, commit_base, commit_head -@patch("services.comparison.Comparison.has_unmerged_base_commits", lambda self: False) @patch("services.archive.ArchiveService.read_chunks", lambda obj, sha: "") @patch( "api.shared.repo.repository_accessors.RepoAccessors.get_repo_permissions", diff --git a/api/internal/tests/views/cassetes/test_account_viewset/AccountViewSetTests/test_update_payment_method.yaml b/api/internal/tests/views/cassetes/test_account_viewset/AccountViewSetTests/test_update_payment_method.yaml new file mode 100644 index 0000000000..acef70d627 --- /dev/null +++ b/api/internal/tests/views/cassetes/test_account_viewset/AccountViewSetTests/test_update_payment_method.yaml @@ -0,0 +1,91 @@ +interactions: +- request: + body: default_payment_method=pm_123 + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '29' + Content-Type: + - application/x-www-form-urlencoded + Idempotency-Key: + - 7c40f9e9-3a01-4109-87bf-3218dfbcf27f + Stripe-Version: + - '2024-04-10' + User-Agent: + - Stripe/v1 PythonBindings/9.6.0 + X-Stripe-Client-User-Agent: + - '{"bindings_version": "9.6.0", "lang": "python", "publisher": "stripe", "httplib": + "requests", "lang_version": "3.12.4", "platform": "Linux-6.6.31-linuxkit-aarch64-with-glibc2.36", + "uname": "Linux 2b87f96d1995 6.6.31-linuxkit #1 SMP Thu May 23 08:36:57 UTC + 2024 aarch64 "}' + method: POST + uri: https://api.stripe.com/v1/subscriptions/djfos + response: + body: + string: "{\n \"error\": {\n \"code\": \"resource_missing\",\n \"doc_url\": + \"https://stripe.com/docs/error-codes/resource-missing\",\n \"message\": + \"No such PaymentMethod: 'pm_123'\",\n \"param\": \"default_payment_method\",\n + \ 
\"request_log_url\": \"https://dashboard.stripe.com/test/logs/req_xT5h1VWY7P75Lu?t=1719007484\",\n + \ \"type\": \"invalid_request_error\"\n }\n}\n" + headers: + Access-Control-Allow-Credentials: + - 'true' + Access-Control-Allow-Methods: + - GET,HEAD,PUT,PATCH,POST,DELETE + Access-Control-Allow-Origin: + - '*' + Access-Control-Expose-Headers: + - Request-Id, Stripe-Manage-Version, Stripe-Should-Retry, X-Stripe-External-Auth-Required, + X-Stripe-Privileged-Session-Required + Access-Control-Max-Age: + - '300' + Cache-Control: + - no-cache, no-store + Connection: + - keep-alive + Content-Length: + - '346' + Content-Security-Policy: + - report-uri https://q.stripe.com/csp-report?p=v1%2Fsubscriptions%2F%3Asubscription_exposed_id; + block-all-mixed-content; default-src 'none'; base-uri 'none'; form-action + 'none'; frame-ancestors 'none'; img-src 'self'; script-src 'self' 'report-sample'; + style-src 'self' + Content-Type: + - application/json + Cross-Origin-Opener-Policy-Report-Only: + - same-origin; report-to="coop" + Date: + - Fri, 21 Jun 2024 22:04:44 GMT + Idempotency-Key: + - 7c40f9e9-3a01-4109-87bf-3218dfbcf27f + Original-Request: + - req_xT5h1VWY7P75Lu + Report-To: + - '{"group":"coop","max_age":8640,"endpoints":[{"url":"https://q.stripe.com/coop-report?s=billing-api-srv"}],"include_subdomains":true}' + Reporting-Endpoints: + - coop="https://q.stripe.com/coop-report?s=billing-api-srv" + Request-Id: + - req_xT5h1VWY7P75Lu + Server: + - nginx + Strict-Transport-Security: + - max-age=63072000; includeSubDomains; preload + Stripe-Version: + - '2024-04-10' + Vary: + - Origin + X-Content-Type-Options: + - nosniff + X-Stripe-Priority-Routing-Enabled: + - 'true' + X-Stripe-Routing-Context-Priority-Tier: + - api-testmode + status: + code: 400 + message: Bad Request +version: 1 diff --git a/api/internal/tests/views/test_account_viewset.py b/api/internal/tests/views/test_account_viewset.py index 5205de4d11..5fb5ccbdfc 100644 --- 
a/api/internal/tests/views/test_account_viewset.py +++ b/api/internal/tests/views/test_account_viewset.py @@ -8,6 +8,10 @@ from rest_framework import status from rest_framework.reverse import reverse from rest_framework.test import APITestCase +from shared.django_apps.codecov_auth.tests.factories import ( + AccountFactory, + InvoiceBillingFactory, +) from stripe import StripeError from api.internal.tests.test_utils import GetAdminProviderAdapter @@ -170,6 +174,7 @@ def test_retrieve_account_gets_account_fields(self): "student_count": 0, "schedule_detail": None, "uses_invoice": False, + "delinquent": None, } @patch("services.billing.stripe.SubscriptionSchedule.retrieve") @@ -190,6 +195,7 @@ def test_retrieve_account_gets_account_fields_when_there_are_scheduled_details( "latest_invoice": None, "schedule_id": "sub_sched_456", "collection_method": "charge_automatically", + "tax_ids": None, } mock_retrieve_subscription.return_value = MockSubscription(subscription_params) @@ -244,6 +250,7 @@ def test_retrieve_account_gets_account_fields_when_there_are_scheduled_details( "customer": {"id": "cus_LK&*Hli8YLIO", "discount": None, "email": None}, "collection_method": "charge_automatically", "trial_end": None, + "tax_ids": None, }, "checkout_session_id": None, "name": owner.name, @@ -262,6 +269,7 @@ def test_retrieve_account_gets_account_fields_when_there_are_scheduled_details( }, }, "uses_invoice": False, + "delinquent": None, } @patch("services.billing.stripe.SubscriptionSchedule.retrieve") @@ -283,6 +291,7 @@ def test_retrieve_account_returns_last_phase_when_more_than_one_scheduled_phases "schedule_id": "sub_sched_456678999", "collection_method": "charge_automatically", "trial_end": 1633512445, + "tax_ids": None, } mock_retrieve_subscription.return_value = MockSubscription(subscription_params) @@ -344,6 +353,7 @@ def test_retrieve_account_returns_last_phase_when_more_than_one_scheduled_phases "customer": {"id": "cus_LK&*Hli8YLIO", "discount": None, "email": None}, 
"collection_method": "charge_automatically", "trial_end": 1633512445, + "tax_ids": None, }, "checkout_session_id": None, "name": owner.name, @@ -362,6 +372,7 @@ def test_retrieve_account_returns_last_phase_when_more_than_one_scheduled_phases }, }, "uses_invoice": False, + "delinquent": None, } @patch("services.billing.stripe.Subscription.retrieve") @@ -381,6 +392,7 @@ def test_retrieve_account_gets_none_for_schedule_details_when_schedule_is_nonexi "latest_invoice": None, "schedule_id": None, "collection_method": "charge_automatically", + "tax_ids": None, } mock_retrieve_subscription.return_value = MockSubscription(subscription_params) @@ -415,6 +427,7 @@ def test_retrieve_account_gets_none_for_schedule_details_when_schedule_is_nonexi "customer": {"id": "cus_LK&*Hli8YLIO", "discount": None, "email": None}, "collection_method": "charge_automatically", "trial_end": None, + "tax_ids": None, }, "checkout_session_id": None, "name": owner.name, @@ -426,6 +439,7 @@ def test_retrieve_account_gets_none_for_schedule_details_when_schedule_is_nonexi "student_count": 0, "schedule_detail": None, "uses_invoice": False, + "delinquent": None, } def test_retrieve_account_gets_account_students(self): @@ -435,8 +449,8 @@ def test_retrieve_account_gets_account_students(self): ) self.current_owner.organizations = [owner.ownerid] self.current_owner.save() - student_1 = OwnerFactory(organizations=[owner.ownerid], student=True) - student_2 = OwnerFactory(organizations=[owner.ownerid], student=True) + OwnerFactory(organizations=[owner.ownerid], student=True) + OwnerFactory(organizations=[owner.ownerid], student=True) response = self._retrieve( kwargs={"service": owner.service, "owner_username": owner.username} ) @@ -459,6 +473,7 @@ def test_retrieve_account_gets_account_students(self): "student_count": 3, "schedule_detail": None, "uses_invoice": False, + "delinquent": None, } def test_account_with_free_user_plan(self): @@ -563,6 +578,7 @@ def 
test_retrieve_subscription_with_stripe_invoice_data(self, mock_subscription) "latest_invoice": json.load(f)["data"][0], "schedule_id": None, "collection_method": "charge_automatically", + "tax_ids": None, } mock_subscription.return_value = MockSubscription(subscription_params) @@ -588,6 +604,7 @@ def test_retrieve_subscription_with_stripe_invoice_data(self, mock_subscription) "customer": {"id": "cus_LK&*Hli8YLIO", "discount": None, "email": None}, "collection_method": "charge_automatically", "trial_end": None, + "tax_ids": None, } @patch("services.billing.stripe.Subscription.retrieve") @@ -642,6 +659,26 @@ def test_update_can_set_plan_auto_activate_to_false(self): assert self.current_owner.plan_auto_activate is False assert response.data["plan_auto_activate"] is False + def test_update_can_set_plan_auto_activate_on_org_with_account(self): + self.current_owner.account = AccountFactory() + self.current_owner.plan_auto_activate = True + self.current_owner.save() + + response = self._update( + kwargs={ + "service": self.current_owner.service, + "owner_username": self.current_owner.username, + }, + data={"plan_auto_activate": False}, + ) + + assert response.status_code == status.HTTP_200_OK + + self.current_owner.refresh_from_db() + + assert self.current_owner.plan_auto_activate is False + assert response.data["plan_auto_activate"] is False + def test_update_can_set_plan_to_users_basic(self): self.current_owner.plan = PlanName.CODECOV_PRO_YEARLY_LEGACY.value self.current_owner.save() @@ -714,6 +751,7 @@ def test_update_can_upgrade_to_paid_plan_for_existing_customer_and_set_plan_info "latest_invoice": json.load(f)["data"][0], "schedule_id": None, "collection_method": "charge_automatically", + "tax_ids": None, } retrieve_subscription_mock.return_value = MockSubscription(subscription_params) @@ -924,6 +962,48 @@ def test_update_must_fail_if_team_plan_and_too_many_users(self): == "Quantity for Team plan cannot exceed 10" ) + def 
test_update_quantity_must_fail_if_account(self): + desired_plans = [ + {"quantity": 10}, + ] + self.current_owner.account = AccountFactory() + self.current_owner.save() + for desired_plan in desired_plans: + response = self._update( + kwargs={ + "service": self.current_owner.service, + "owner_username": self.current_owner.username, + }, + data={"plan": desired_plan}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert ( + str(response.data["plan"]["non_field_errors"][0]) + == "You cannot update your plan manually, for help or changes to plan, connect with sales@codecov.io" + ) + + def test_update_plan_must_fail_if_account(self): + desired_plans = [ + {"value": PlanName.CODECOV_PRO_YEARLY.value}, + ] + self.current_owner.account = AccountFactory() + self.current_owner.save() + for desired_plan in desired_plans: + response = self._update( + kwargs={ + "service": self.current_owner.service, + "owner_username": self.current_owner.username, + }, + data={"plan": desired_plan}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert ( + str(response.data["plan"]["non_field_errors"][0]) + == "You cannot update your plan manually, for help or changes to plan, connect with sales@codecov.io" + ) + def test_update_quantity_must_be_at_least_2_if_paid_plan(self): desired_plan = {"value": PlanName.CODECOV_PRO_YEARLY.value, "quantity": 1} response = self._update( @@ -952,8 +1032,13 @@ def test_update_payment_method_without_body(self): @patch("services.billing.stripe.Subscription.retrieve") @patch("services.billing.stripe.PaymentMethod.attach") @patch("services.billing.stripe.Customer.modify") + @patch("services.billing.stripe.Subscription.modify") def test_update_payment_method( - self, modify_customer_mock, attach_payment_mock, retrieve_subscription_mock + self, + modify_subscription_mock, + modify_customer_mock, + attach_payment_mock, + retrieve_subscription_mock, ): self.current_owner.stripe_customer_id = "flsoe" 
self.current_owner.stripe_subscription_id = "djfos" @@ -977,6 +1062,7 @@ def test_update_payment_method( "latest_invoice": json.load(f)["data"][0], "schedule_id": None, "collection_method": "charge_automatically", + "tax_ids": None, } retrieve_subscription_mock.return_value = MockSubscription(subscription_params) @@ -998,6 +1084,11 @@ def test_update_payment_method( invoice_settings={"default_payment_method": payment_method_id}, ) + modify_subscription_mock.assert_called_once_with( + self.current_owner.stripe_subscription_id, + default_payment_method=payment_method_id, + ) + @patch("services.billing.StripeService.update_payment_method") def test_update_payment_method_handles_stripe_error(self, upm_mock): code, message = 402, "Oops, nope" @@ -1077,6 +1168,34 @@ def test_update_billing_address_without_body(self): response = self.client.patch(url, format="json") assert response.status_code == status.HTTP_400_BAD_REQUEST + def test_update_billing_address_without_name(self): + kwargs = { + "service": self.current_owner.service, + "owner_username": self.current_owner.username, + } + billing_address = { + "line_1": "45 Fremont St.", + "line_2": "", + "city": "San Francisco", + "state": "CA", + "country": "US", + "postal_code": "94105", + } + data = {"billing_address": billing_address} + url = reverse("account_details-update-billing-address", kwargs=kwargs) + response = self.client.patch(url, data=data, format="json") + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_update_billing_address_without_address(self): + kwargs = { + "service": self.current_owner.service, + "owner_username": self.current_owner.username, + } + data = {"name": "John Doe"} + url = reverse("account_details-update-billing-address", kwargs=kwargs) + response = self.client.patch(url, data=data, format="json") + assert response.status_code == status.HTTP_400_BAD_REQUEST + @patch("services.billing.StripeService.update_billing_address") def 
test_update_billing_address_handles_stripe_error(self, stripe_mock): code, message = 402, "Oops, nope" @@ -1098,18 +1217,27 @@ def test_update_billing_address_handles_stripe_error(self, stripe_mock): "service": self.current_owner.service, "owner_username": self.current_owner.username, } - data = {"billing_address": billing_address} + data = {"name": "John Doe", "billing_address": billing_address} url = reverse("account_details-update-billing-address", kwargs=kwargs) response = self.client.patch(url, data=data, format="json") assert response.status_code == code assert response.data["detail"] == message @patch("services.billing.stripe.Subscription.retrieve") + @patch("services.billing.stripe.Customer.retrieve") + @patch("services.billing.stripe.PaymentMethod.modify") @patch("services.billing.stripe.Customer.modify") - def test_update_billing_address(self, modify_customer_mock, retrieve_mock): + def test_update_billing_address( + self, + modify_customer_mock, + modify_payment_mock, + retrieve_customer_mock, + retrieve_sub_mock, + ): self.current_owner.stripe_customer_id = "flsoe" self.current_owner.stripe_subscription_id = "djfos" self.current_owner.save() + f = open("./services/tests/samples/stripe_invoice.json") billing_address = { "line_1": "45 Fremont St.", @@ -1119,17 +1247,51 @@ def test_update_billing_address(self, modify_customer_mock, retrieve_mock): "country": "US", "postal_code": "94105", } + + formatted_address = { + "line1": "45 Fremont St.", + "line2": "", + "city": "San Francisco", + "state": "CA", + "country": "US", + "postal_code": "94105", + } + + default_payment_method = { + "id": "pm_123", + "card": { + "brand": "visa", + "exp_month": 12, + "exp_year": 2024, + "last4": "abcd", + }, + } + + subscription_params = { + "default_payment_method": default_payment_method, + "cancel_at_period_end": False, + "current_period_end": 1633512445, + "latest_invoice": json.load(f)["data"][0], + "schedule_id": None, + "collection_method": "charge_automatically", + 
"tax_ids": None, + } + + retrieve_sub_mock.return_value = MockSubscription(subscription_params) + kwargs = { "service": self.current_owner.service, "owner_username": self.current_owner.username, } - data = {"billing_address": billing_address} + data = {"name": "John Doe", "billing_address": billing_address} url = reverse("account_details-update-billing-address", kwargs=kwargs) response = self.client.patch(url, data=data, format="json") assert response.status_code == status.HTTP_200_OK + retrieve_customer_mock.assert_called_once() + modify_payment_mock.assert_called_once() modify_customer_mock.assert_called_once_with( - self.current_owner.stripe_customer_id, address=billing_address + self.current_owner.stripe_customer_id, address=formatted_address ) @patch("api.shared.permissions.get_provider") @@ -1321,6 +1483,7 @@ def test_update_apply_cancellation_discount( "duration_in_months": 6, "created": int(datetime(2023, 1, 1, 0, 0, 0).timestamp()), }, + "tax_ids": None, } retrieve_subscription_mock.return_value = MockSubscription(subscription_params) @@ -1366,6 +1529,7 @@ def test_update_apply_cancellation_discount_yearly( "latest_invoice": None, "schedule_id": None, "collection_method": "charge_automatically", + "tax_ids": None, } retrieve_subscription_mock.return_value = MockSubscription(subscription_params) @@ -1381,7 +1545,7 @@ def test_update_apply_cancellation_discount_yearly( assert not modify_customer_mock.called assert not coupon_create_mock.called assert response.status_code == status.HTTP_200_OK - assert response.json()["subscription_detail"]["customer"]["discount"] == None + assert response.json()["subscription_detail"]["customer"]["discount"] is None @patch("services.task.TaskService.delete_owner") def test_destroy_triggers_delete_owner_task(self, delete_owner_mock): @@ -1401,6 +1565,103 @@ def test_destroy_not_own_account_returns_404(self): ) assert response.status_code == status.HTTP_404_NOT_FOUND + def test_retrieve_org_with_account(self): + account = 
AccountFactory( + name="Hello World", + plan_seat_count=5, + free_seat_count=3, + plan="users-enterprisey", + is_delinquent=False, + ) + InvoiceBillingFactory(is_active=True, account=account) + org_1 = OwnerFactory( + account=account, + service=Service.GITHUB.value, + username="Test", + delinquent=True, + uses_invoice=False, + ) + org_2 = OwnerFactory( + account=account, + service=Service.GITHUB.value, + ) + activated_owner = OwnerFactory( + user=UserFactory(), organizations=[org_1.ownerid, org_2.ownerid] + ) + account.users.add(activated_owner.user) + student_owner = OwnerFactory( + user=UserFactory(), + student=True, + organizations=[org_1.ownerid, org_2.ownerid], + ) + account.users.add(student_owner.user) + other_activated_owner = OwnerFactory( + user=UserFactory(), organizations=[org_2.ownerid] + ) + account.users.add(other_activated_owner.user) + other_student_owner = OwnerFactory( + user=UserFactory(), + student=True, + organizations=[org_2.ownerid], + ) + account.users.add(other_student_owner.user) + org_1.plan_activated_users = [activated_owner.ownerid, student_owner.ownerid] + org_1.admins = [activated_owner.ownerid] + org_1.save() + org_2.plan_activated_users = [ + activated_owner.ownerid, + student_owner.ownerid, + other_activated_owner.ownerid, + other_student_owner.ownerid, + ] + org_2.save() + + self.client.force_login_owner(activated_owner) + response = self._retrieve( + kwargs={"service": Service.GITHUB.value, "owner_username": org_1.username} + ) + assert response.status_code == status.HTTP_200_OK + # these fields are all overridden by account fields if the org has an account + self.assertEqual(org_1.activated_user_count, 1) + self.assertEqual(org_1.activated_student_count, 1) + self.assertTrue(org_1.delinquent) + self.assertFalse(org_1.uses_invoice) + self.assertEqual(org_1.plan_user_count, 1) + expected_response = { + "activated_user_count": 2, + "activated_student_count": 2, + "delinquent": False, + "uses_invoice": True, + "plan": { + 
"marketing_name": "Enterprise Cloud", + "value": PlanName.ENTERPRISE_CLOUD_YEARLY.value, + "billing_rate": "annually", + "base_unit_price": 10, + "benefits": [ + "Configurable # of users", + "Unlimited public repositories", + "Unlimited private repositories", + "Priority Support", + ], + "quantity": 5, + }, + "root_organization": None, + "integration_id": org_1.integration_id, + "plan_auto_activate": org_1.plan_auto_activate, + "inactive_user_count": 0, + "subscription_detail": None, + "checkout_session_id": None, + "name": org_1.name, + "email": org_1.email, + "nb_active_private_repos": 0, + "repo_total_credits": 99999999, + "plan_provider": org_1.plan_provider, + "student_count": 1, + "schedule_detail": None, + } + self.assertDictEqual(response.data["plan"], expected_response["plan"]) + self.assertDictEqual(response.data, expected_response) + @override_settings(IS_ENTERPRISE=True) class EnterpriseAccountViewSetTests(APITestCase): diff --git a/api/internal/tests/views/test_compare_viewset.py b/api/internal/tests/views/test_compare_viewset.py index 00a1326d39..c5eba93847 100644 --- a/api/internal/tests/views/test_compare_viewset.py +++ b/api/internal/tests/views/test_compare_viewset.py @@ -51,13 +51,12 @@ async def get_authenticated(self): return False, False -@patch("services.comparison.Comparison.has_unmerged_base_commits", lambda self: True) @patch("services.comparison.Comparison.head_report", new_callable=PropertyMock) @patch("services.comparison.Comparison.base_report", new_callable=PropertyMock) @patch("services.repo_providers.RepoProviderService.get_adapter") class TestCompareViewSetRetrieve(APITestCase): """ - Tests for retrieving a comparison. Does not test data that will be depracated, + Tests for retrieving a comparison. Does not test data that will be deprecated, eg base and head report fields. Tests for commits etc will be added as the compare-api refactor progresses. 
""" @@ -223,7 +222,6 @@ def test_returns_200_and_expected_files_on_success( assert response.status_code == status.HTTP_200_OK assert response.data["files"] == self.expected_files - assert response.data["has_unmerged_base_commits"] is True def test_returns_404_if_base_or_head_references_not_found( self, adapter_mock, base_report_mock, head_report_mock @@ -330,12 +328,12 @@ def test_diffs_larger_than_MAX_DIFF_SIZE_doesnt_include_lines( assert response.status_code == status.HTTP_200_OK assert ( - response.data["files"][0]["lines"] == None + response.data["files"][0]["lines"] is None ) # None means diff was truncated comparison.MAX_DIFF_SIZE = previous_max - def test_file_returns_comparefile_with_diff_and_src_data( + def test_file_returns_compare_file_with_diff_and_src_data( self, adapter_mock, base_report_mock, head_report_mock ): base_report_mock.return_value = self.base_report @@ -397,7 +395,7 @@ def test_missing_base_report_returns_none_base_totals( response = self._get_comparison() assert response.status_code == status.HTTP_200_OK - assert response.data["totals"]["base"] == None + assert response.data["totals"]["base"] is None def test_no_raw_reports_returns_404( self, adapter_mock, base_report_mock, head_report_mock @@ -468,46 +466,3 @@ def test_pull_request_pseudo_comparison_can_update_base_report( assert response.status_code == status.HTTP_200_OK assert response.data["files"] == self.expected_files - - @patch("redis.Redis.get", lambda self, key: None) - @patch("redis.Redis.set", lambda self, key, val, ex: None) - @patch( - "services.comparison.PullRequestComparison.pseudo_diff_adjusts_tracked_lines", - new_callable=PropertyMock, - ) - @patch( - "services.comparison.PullRequestComparison.allow_coverage_offsets", - new_callable=PropertyMock, - ) - @patch( - "services.comparison.PullRequestComparison.update_base_report_with_pseudo_diff" - ) - def test_pull_request_pseudo_comparison_returns_error_if_coverage_offsets_not_allowed( - self, - update_base_report_mock, - 
allow_coverage_offsets_mock, - pseudo_diff_adjusts_tracked_lines_mock, - adapter_mock, - base_report_mock, - head_report_mock, - ): - adapter_mock.return_value = self.mocked_compare_adapter - base_report_mock.return_value = self.base_report - head_report_mock.return_value = self.head_report - - pseudo_diff_adjusts_tracked_lines_mock.return_value = True - allow_coverage_offsets_mock.return_value = False - - response = self._get_comparison( - query_params={ - "pullid": PullFactory( - base=self.base.commitid, - head=self.head.commitid, - compared_to=self.base.commitid, - pullid=2, - repository=self.repo, - ).pullid - } - ) - - assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/api/internal/tests/views/test_coverage_viewset.py b/api/internal/tests/views/test_coverage_viewset.py index 6772a9f371..6a8d2b3d3a 100644 --- a/api/internal/tests/views/test_coverage_viewset.py +++ b/api/internal/tests/views/test_coverage_viewset.py @@ -76,7 +76,7 @@ def setUp(self): self.commit3 = CommitFactory( author=self.current_owner, repository=self.repo, - branch=self.branch, + branch=self.branch.name, ) with connection.cursor() as cursor: cursor.execute( diff --git a/api/internal/tests/views/test_invoice_viewset.py b/api/internal/tests/views/test_invoice_viewset.py deleted file mode 100644 index 29cfed4936..0000000000 --- a/api/internal/tests/views/test_invoice_viewset.py +++ /dev/null @@ -1,132 +0,0 @@ -import json -import os -from unittest.mock import patch - -from rest_framework import status -from rest_framework.reverse import reverse -from rest_framework.test import APITestCase -from stripe import InvalidRequestError, StripeError - -from api.internal.tests.test_utils import GetAdminProviderAdapter -from codecov_auth.tests.factories import OwnerFactory -from utils.test_utils import Client - -curr_path = os.path.dirname(__file__) - - -class InvoiceViewSetTests(APITestCase): - def setUp(self): - self.service = "gitlab" - self.current_owner = 
OwnerFactory(stripe_customer_id="1000") - self.expected_invoice = { - "number": "EF0A41E-0001", - "status": "paid", - "id": "in_19yTU92eZvKYlo2C7uDjvu6v", - "created": 1489789429, - "period_start": 1487370220, - "period_end": 1489789420, - "due_date": None, - "customer_name": "Peer Company", - "customer_address": "6639 Boulevard Dr, Westwood FL 34202 USA", - "currency": "usd", - "amount_paid": 999, - "amount_due": 999, - "amount_remaining": 0, - "total": 999, - "subtotal": 999, - "invoice_pdf": "https://pay.stripe.com/invoice/acct_1032D82eZvKYlo2C/invst_a7KV10HpLw2QxrihgVyuOkOjMZ/pdf", - "line_items": [ - { - "description": "(10) users-pr-inappm", - "amount": 120, - "quantity": 1, - "currency": "usd", - "plan_name": "users-pr-inappm", - "period": {"end": 1521326190, "start": 1518906990}, - } - ], - "footer": None, - "customer_email": "olivia.williams.03@example.com", - "customer_shipping": None, - } - - self.client = Client() - self.client.force_login_owner(self.current_owner) - - def _list(self, kwargs): - return self.client.get(reverse("invoices-list", kwargs=kwargs)) - - def _retrieve(self, kwargs): - return self.client.get(reverse("invoices-detail", kwargs=kwargs)) - - @patch("services.billing.stripe.Invoice.list") - def test_invoices_returns_100_recent_invoices(self, mock_list_filtered_invoices): - with open("./services/tests/samples/stripe_invoice.json") as f: - stripe_invoice_response = json.load(f) - # make it so there's 100 invoices, which is the max stripe returns - stripe_invoice_response["data"] = stripe_invoice_response["data"] * 100 - mock_list_filtered_invoices.return_value = stripe_invoice_response - expected_invoices = [self.expected_invoice] * 100 - - response = self._list( - kwargs={ - "service": self.current_owner.service, - "owner_username": self.current_owner.username, - } - ) - - assert response.status_code == status.HTTP_200_OK - assert len(response.data) == 100 - assert response.data == expected_invoices - - 
@patch("api.shared.permissions.get_provider") - def test_invoices_returns_404_if_user_not_admin(self, get_provider_mock): - get_provider_mock.return_value = GetAdminProviderAdapter() - owner = OwnerFactory() - response = self._list( - kwargs={"service": owner.service, "owner_username": owner.username} - ) - assert response.status_code == status.HTTP_404_NOT_FOUND - - @patch("services.billing.stripe.Invoice.retrieve") - def test_invoice(self, mock_retrieve_invoice): - with open("./services/tests/samples/stripe_invoice.json") as f: - stripe_invoice_response = json.load(f) - invoice = stripe_invoice_response["data"][0] - invoice["customer"] = self.current_owner.stripe_customer_id - mock_retrieve_invoice.return_value = invoice - response = self._retrieve( - kwargs={ - "service": self.current_owner.service, - "owner_username": self.current_owner.username, - "pk": invoice["id"], - } - ) - assert response.status_code == status.HTTP_200_OK - assert response.data == self.expected_invoice - - @patch("services.billing.stripe.Invoice.retrieve") - def test_when_invoice_not_found(self, mock_retrieve_invoice): - mock_retrieve_invoice.side_effect = InvalidRequestError( - message="not found", param="abc" - ) - response = self._retrieve( - kwargs={ - "service": self.current_owner.service, - "owner_username": self.current_owner.username, - "pk": "abc", - } - ) - assert response.status_code == status.HTTP_404_NOT_FOUND - - @patch("services.billing.stripe.Invoice.retrieve") - def test_when_no_customer_dont_match(self, mock_retrieve_invoice): - mock_retrieve_invoice.return_value = {"customer": "123456789"} - response = self._retrieve( - kwargs={ - "service": self.current_owner.service, - "owner_username": self.current_owner.username, - "pk": "abc", - } - ) - assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/api/internal/tests/views/test_owner_viewset.py b/api/internal/tests/views/test_owner_viewset.py index 144f9d88f7..36eeb74523 100644 --- 
a/api/internal/tests/views/test_owner_viewset.py +++ b/api/internal/tests/views/test_owner_viewset.py @@ -1,6 +1,5 @@ -from unittest.mock import patch - from rest_framework import status +from rest_framework.exceptions import ErrorDetail from rest_framework.reverse import reverse from rest_framework.test import APITestCase @@ -50,11 +49,19 @@ def test_retrieve_returns_owner_with_period_username(self): def test_retrieve_returns_404_if_no_matching_username(self): response = self._retrieve(kwargs={"service": "github", "owner_username": "fff"}) assert response.status_code == status.HTTP_404_NOT_FOUND - assert response.data == {"detail": "Not found."} + assert response.data == { + "detail": ErrorDetail( + string="No Owner matches the given query.", code="not_found" + ) + } def test_retrieve_owner_unknown_service_returns_404(self): response = self._retrieve( kwargs={"service": "not-real", "owner_username": "anything"} ) assert response.status_code == status.HTTP_404_NOT_FOUND - assert response.data == {"detail": "Service not found: not-real"} + assert response.data == { + "detail": ErrorDetail( + string="Service not found: not-real", code="not_found" + ) + } diff --git a/api/internal/tests/views/test_repo_view.py b/api/internal/tests/views/test_repo_view.py index b34aebf9ab..d910557a6b 100644 --- a/api/internal/tests/views/test_repo_view.py +++ b/api/internal/tests/views/test_repo_view.py @@ -1,7 +1,6 @@ import json from unittest.mock import Mock, patch -from django.test import override_settings from django.utils import timezone from rest_framework.reverse import reverse from shared.torngit.exceptions import TorngitClientGeneralError @@ -57,38 +56,6 @@ def _destroy(self, kwargs={}): } return self.client.delete(reverse("repos-detail", kwargs=kwargs)) - def _regenerate_upload_token(self, kwargs={}): - if kwargs == {}: - kwargs = { - "service": self.org.service, - "owner_username": self.org.username, - "repo_name": self.repo.name, - } - return self.client.patch( - 
reverse("repos-regenerate-upload-token", kwargs=kwargs) - ) - - def _erase(self, kwargs={}): - if kwargs == {}: - kwargs = { - "service": self.org.service, - "owner_username": self.org.username, - "repo_name": self.repo.name, - } - return self.client.patch(reverse("repos-erase", kwargs=kwargs)) - - def _encode(self, kwargs, data): - return self.client.post(reverse("repos-encode", kwargs=kwargs), data=data) - - def _reset_webhook(self, kwargs={}): - if kwargs == {}: - kwargs = { - "service": self.org.service, - "owner_username": self.org.username, - "repo_name": self.repo.name, - } - return self.client.put(reverse("repos-reset-webhook", kwargs=kwargs)) - class TestRepositoryViewSetList(RepositoryViewSetTestSuite): def setUp(self): @@ -334,7 +301,7 @@ def test_get_active_repos(self): ) def test_get_inactive_repos(self): - new_repo = RepositoryFactory(author=self.org, name="C", private=False) + RepositoryFactory(author=self.org, name="C", private=False) response = self._list(query_params={"active": False}) self.assertEqual(response.status_code, 200) @@ -345,7 +312,7 @@ def test_get_inactive_repos(self): ) def test_get_all_repos(self): - new_repo = RepositoryFactory(author=self.org, name="C", private=False) + RepositoryFactory(author=self.org, name="C", private=False) response = self._list() self.assertEqual(response.status_code, 200) @@ -356,7 +323,7 @@ def test_get_all_repos(self): ) def test_get_all_repos_by_name(self): - new_repo = RepositoryFactory(author=self.org, name="C", private=False) + RepositoryFactory(author=self.org, name="C", private=False) response = self._list(query_params={"names": ["A", "B"]}) self.assertEqual(response.status_code, 200) @@ -686,39 +653,6 @@ def test_destroy_repo_as_inactive_user_returns_403(self, mocked_get_permissions) assert response.data["detail"] == "User not activated" assert Repository.objects.filter(name="repo1").exists() - def test_regenerate_upload_token_with_permissions_succeeds( - self, mocked_get_permissions - ): - 
mocked_get_permissions.return_value = True, True - old_upload_token = self.repo.upload_token - - response = self._regenerate_upload_token() - - assert response.status_code == 200 - self.repo.refresh_from_db() - assert str(self.repo.upload_token) == response.data["upload_token"] - assert str(self.repo.upload_token) != old_upload_token - - def test_regenerate_upload_token_without_permissions_returns_403( - self, mocked_get_permissions - ): - mocked_get_permissions.return_value = False, False - response = self._regenerate_upload_token() - self.assertEqual(response.status_code, 403) - - def test_regenerate_upload_token_as_inactive_user_returns_403( - self, mocked_get_permissions - ): - mocked_get_permissions.return_value = True, True - self.org.plan = "users-inappy" - self.org.plan_auto_activate = False - self.org.save() - - response = self._regenerate_upload_token() - - assert response.status_code == 403 - assert response.data["detail"] == "User not activated" - def test_update_default_branch_with_permissions_succeeds( self, mocked_get_permissions ): @@ -740,93 +674,10 @@ def test_update_default_branch_without_write_permissions_returns_403( self, mocked_get_permissions ): mocked_get_permissions.return_value = True, False - new_default_branch = "no_write_permissions" response = self._update(data={"branch": "dev"}) self.assertEqual(response.status_code, 403) - @patch("services.task.TaskService.delete_timeseries") - @patch("services.task.TaskService.flush_repo") - def test_erase_triggers_task( - self, mocked_flush_repo, mocked_delete_timeseries, mocked_get_permissions - ): - mocked_get_permissions.return_value = True, True - self.org.admins = [self.current_owner.ownerid] - self.org.save() - - response = self._erase() - assert response.status_code == 200 - - mocked_flush_repo.assert_called_once_with(repository_id=self.repo.pk) - mocked_delete_timeseries.assert_called_once_with(repository_id=self.repo.pk) - - @patch("api.shared.permissions.get_provider") - def 
test_erase_without_admin_rights_returns_403( - self, mocked_get_provider, mocked_get_permissions - ): - mocked_get_provider.return_value = GetAdminProviderAdapter() - mocked_get_permissions.return_value = True, True - - assert self.current_owner.ownerid not in self.org.admins - - response = self._erase() - assert response.status_code == 403 - - def test_erase_as_inactive_user_returns_403(self, mocked_get_permissions): - mocked_get_permissions.return_value = True, True - self.org.plan = "users-inappy" - self.org.plan_auto_activate = False - self.org.admins = [self.current_owner.ownerid] - self.org.save() - - response = self._erase() - - assert response.status_code == 403 - assert response.data["detail"] == "User not activated" - - @override_settings(IS_ENTERPRISE=True) - @patch("api.shared.repo.mixins.RepositoryViewSetMixin.get_object") - @patch("services.self_hosted.get_config") - @patch("services.task.TaskService.delete_timeseries") - @patch("services.task.TaskService.flush_repo") - def test_erase_as_admin_self_hosted( - self, - mocked_flush_repo, - mocked_delete_timeseries, - mocked_get_config, - mocked_get_object, - mocked_get_permissions, - ): - mocked_get_permissions.return_value = True, True - self.org.admins = [self.current_owner.ownerid] - self.org.save() - - mocked_get_config.return_value = [ - {"service": "github", "username": "codecov-user"}, - ] - mocked_get_object.return_value = self.repo - - response = self._erase() - assert response.status_code == 200 - - mocked_flush_repo.assert_called_once_with(repository_id=self.repo.pk) - mocked_delete_timeseries.assert_called_once_with(repository_id=self.repo.pk) - - @override_settings(IS_ENTERPRISE=True) - @patch("services.self_hosted.get_config") - @patch("api.shared.permissions.get_provider") - def test_erase_as_non_admin_self_hosted( - self, mocked_get_provider, mocked_get_config, mocked_get_permissions - ): - mocked_get_provider.return_value = GetAdminProviderAdapter() - mocked_get_config.return_value = [ - 
{"service": "github", "username": "someone-else"}, - ] - mocked_get_permissions.return_value = True, True - - response = self._erase() - assert response.status_code == 403 - def test_retrieve_returns_yaml(self, mocked_get_permissions): mocked_get_permissions.return_value = True, False @@ -851,71 +702,13 @@ def test_activation_checks_if_credits_available_for_legacy_users( name=str(i) + "random", author=self.org, private=True, active=True ) - inactive_repo = RepositoryFactory(author=self.org, private=True, active=False) + RepositoryFactory(author=self.org, private=True, active=False) activation_data = {"active": True} response = self._update(data=activation_data) assert response.status_code == 403 - def test_encode_returns_200_on_success(self, mocked_get_permissions): - mocked_get_permissions.return_value = True, True - - to_encode = {"value": "hjrok"} - response = self._encode( - kwargs={ - "service": self.org.service, - "owner_username": self.org.username, - "repo_name": self.repo.name, - }, - data=to_encode, - ) - - assert response.status_code == 201 - - @patch("api.internal.repo.views.encode_secret_string") - def test_encode_returns_encoded_string_on_success( - self, encoder_mock, mocked_get_permissions - ): - mocked_get_permissions.return_value = True, True - encrypted_string = "string:encrypted string" - encoder_mock.return_value = encrypted_string - - to_encode = {"value": "hjrok"} - response = self._encode( - kwargs={ - "service": self.org.service, - "owner_username": self.org.username, - "repo_name": self.repo.name, - }, - data=to_encode, - ) - - assert response.status_code == 201 - assert response.data["value"] == encrypted_string - - def test_encode_secret_string_encodes_with_right_key(self, _): - from api.internal.repo.utils import encode_secret_string - - string_arg = "hi there" - to_encode = "/".join( - ( # this is the format expected by the key - self.org.service, - self.org.service_id, - self.repo.service_id, - string_arg, - ) - ) - - from 
shared.encryption.yaml_secret import yaml_secret_encryptor - - check_encryptor = yaml_secret_encryptor - - encoded = encode_secret_string(to_encode) - - # we slice to take off the word "secret" prepended by the util - assert check_encryptor.decode(encoded[7:]) == to_encode - def test_repo_bot_returns_username_if_bot_not_null(self, mocked_get_permissions): mocked_get_permissions.return_value = True, True username = "huecoTanks" @@ -935,74 +728,6 @@ def test_retrieve_with_no_commits_doesnt_crash(self, mocked_get_permissions): response = self._retrieve() assert response.status_code == 200 - @patch("api.internal.repo.views.delete_webhook_on_provider") - @patch("api.internal.repo.views.create_webhook_on_provider") - def test_reset_webhook_unsets_original_hookid_and_sets_new_if_hookid_exists( - self, create_webhook_mock, delete_webhook_mock, mocked_get_permissions - ): - mocked_get_permissions.return_value = True, True - - old_webhook_id, new_webhook_id = "123", "456" - - self.repo.hookid = old_webhook_id - self.repo.save() - - create_webhook_mock.return_value = new_webhook_id - - response = self._reset_webhook() - - assert response.status_code == 200 - - self.repo.refresh_from_db() - - assert self.repo.hookid == new_webhook_id - - @patch("api.internal.repo.views.delete_webhook_on_provider") - @patch("api.internal.repo.views.create_webhook_on_provider") - def test_reset_webhook_doesnt_delete_if_no_hookid( - self, create_webhook_mock, delete_webhook_mock, mocked_get_permissions - ): - mocked_get_permissions.return_value = True, True - create_webhook_mock.return_value = "irrelevant" - - # Make delete function throw exception, so if it's called this test fails - delete_webhook_mock.side_effect = Exception( - "Attempted to delete nonexistent webhook" - ) - - response = self._reset_webhook() - - @patch("api.internal.repo.views.delete_webhook_on_provider") - @patch("api.internal.repo.views.create_webhook_on_provider") - def 
test_reset_webhook_creates_new_webhook_even_if_no_hookid( - self, create_webhook_mock, delete_webhook_mock, mocked_get_permissions - ): - mocked_get_permissions.return_value = True, True - new_webhook_id = "123" - create_webhook_mock.return_value = new_webhook_id - - response = self._reset_webhook() - - self.repo.refresh_from_db() - assert self.repo.hookid == new_webhook_id - - @patch("api.internal.repo.views.delete_webhook_on_provider") - @patch("api.internal.repo.views.create_webhook_on_provider") - def test_reset_webhook_returns_correct_code_and_response_if_TorgitClientError_raised( - self, create_webhook_mock, delete_webhook_mock, mocked_get_permissions - ): - code = 403 - message = "No can do, buddy" - mocked_get_permissions.return_value = True, True - create_webhook_mock.side_effect = TorngitClientGeneralError( - code, response_data=None, message=message - ) - - response = self._reset_webhook() - - assert response.status_code == code - assert response.data == {"detail": message} - @patch("services.archive.ArchiveService.read_chunks", lambda obj, _: "") def test_retrieve_returns_latest_commit_data(self, mocked_get_permissions): self.maxDiff = None @@ -1017,14 +742,12 @@ def test_retrieve_returns_latest_commit_data(self, mocked_get_permissions): "filename": "test_file_1.py", "file_index": 2, "file_totals": [1, 10, 8, 2, 5, "80.00000", 6, 7, 9, 8, 20, 40, 13], - "session_totals": [[0, 10, 8, 2, 0, "80.00000", 0, 0, 0, 0, 0, 0, 0]], "diff_totals": [0, 2, 1, 1, 0, "50.00000", 0, 0, 0, 0, 0, 0, 0], }, { "filename": "test_file_2.py", "file_index": 0, "file_totals": [1, 3, 2, 1, 0, "66.66667", 0, 0, 0, 0, 0, 0, 0], - "session_totals": [[0, 3, 2, 1, 0, "66.66667", 0, 0, 0, 0, 0, 0, 0]], "diff_totals": None, }, ] @@ -1069,7 +792,7 @@ def test_retrieve_returns_latest_commit_data(self, mocked_get_permissions): "hits": 2, "misses": 1, "partials": 0, - "coverage": 66.67, + "coverage": 66.66, "branches": 0, "methods": 0, "sessions": 0, @@ -1116,9 +839,6 @@ def 
side_effect(path, *args, **kwargs): 40, 13, ], - "session_totals": [ - [0, 10, 8, 2, 0, "80.00000", 0, 0, 0, 0, 0, 0, 0] - ], "diff_totals": [ 0, 2, @@ -1153,9 +873,6 @@ def side_effect(path, *args, **kwargs): 0, 0, ], - "session_totals": [ - [0, 3, 2, 1, 0, "66.66667", 0, 0, 0, 0, 0, 0, 0] - ], "diff_totals": None, }, ] @@ -1210,7 +927,7 @@ def side_effect(path, *args, **kwargs): "hits": 2, "misses": 1, "partials": 0, - "coverage": 66.67, + "coverage": 66.66, "branches": 0, "methods": 0, "sessions": 0, @@ -1259,7 +976,7 @@ def test_latest_commit_is_none_if_dne(self, mocked_get_permissions): response = self._retrieve() - assert response.data["latest_commit"] == None + assert response.data["latest_commit"] is None def test_can_retrieve_repo_name_containing_dot(self, mocked_get_permissions): mocked_get_permissions.return_value = True, True diff --git a/api/internal/tests/views/test_self_hosted_settings_viewset.py b/api/internal/tests/views/test_self_hosted_settings_viewset.py index 2c6eb85334..a3514b713b 100644 --- a/api/internal/tests/views/test_self_hosted_settings_viewset.py +++ b/api/internal/tests/views/test_self_hosted_settings_viewset.py @@ -1,11 +1,7 @@ -from unittest.mock import patch - from django.test import TestCase, override_settings from rest_framework.reverse import reverse -from codecov_auth.models import Owner from codecov_auth.tests.factories import OwnerFactory -from services.self_hosted import activate_owner, is_autoactivation_enabled from utils.test_utils import APIClient @@ -28,64 +24,3 @@ def test_settings(self): res = self.client.get(reverse("selfhosted-users-list")) # not authenticated assert res.status_code == 403 - - -@override_settings(IS_ENTERPRISE=True, ROOT_URLCONF="api.internal.enterprise_urls") -class SettingsViewsetTestCase(TestCase): - def setUp(self): - self.current_owner = OwnerFactory() - self.client = APIClient() - self.client.force_login_owner(self.current_owner) - - @patch("services.self_hosted.license_seats") - 
@patch("services.self_hosted.is_autoactivation_enabled") - @patch("services.self_hosted.admin_owners") - def test_settings(self, admin_owners, is_autoactivation_enabled, license_seats): - admin_owners.return_value = Owner.objects.filter(pk__in=[self.current_owner.pk]) - - is_autoactivation_enabled.return_value = False - license_seats.return_value = 5 - - res = self.client.get(reverse("selfhosted-settings-detail")) - assert res.status_code == 200 - assert res.json() == { - "plan_auto_activate": False, - "seats_used": 0, - "seats_limit": 5, - } - - is_autoactivation_enabled.return_value = True - - org = OwnerFactory() - owner = OwnerFactory(organizations=[org.pk]) - activate_owner(owner) - - res = self.client.get(reverse("selfhosted-settings-detail")) - assert res.status_code == 200 - assert res.json() == { - "plan_auto_activate": True, - "seats_used": 1, - "seats_limit": 5, - } - - @patch("services.self_hosted.admin_owners") - def test_settings_update(self, admin_owners): - admin_owners.return_value = Owner.objects.filter(pk__in=[self.current_owner.pk]) - - res = self.client.patch( - reverse("selfhosted-settings-detail"), - data={"plan_auto_activate": True}, - format="json", - ) - assert res.status_code == 200 - assert res.json()["plan_auto_activate"] == True - assert is_autoactivation_enabled() == True - - res = self.client.patch( - reverse("selfhosted-settings-detail"), - data={"plan_auto_activate": False}, - format="json", - ) - assert res.status_code == 200 - assert res.json()["plan_auto_activate"] == False - assert is_autoactivation_enabled() == False diff --git a/api/internal/tests/views/test_self_hosted_user_viewset.py b/api/internal/tests/views/test_self_hosted_user_viewset.py index 523df096cc..4d3acf5d18 100644 --- a/api/internal/tests/views/test_self_hosted_user_viewset.py +++ b/api/internal/tests/views/test_self_hosted_user_viewset.py @@ -86,10 +86,8 @@ class UserViewsetAdminTestCase(UserViewsetTestCase): def test_list_users(self, admin_owners): 
admin_owners.return_value = Owner.objects.filter(pk__in=[self.current_owner.pk]) - other_owner = OwnerFactory() - other_other_owner = OwnerFactory( - oauth_token=None, organizations=[self.owner.ownerid] - ) + OwnerFactory() + OwnerFactory(oauth_token=None, organizations=[self.owner.ownerid]) activated_owner = OwnerFactory( oauth_token=None, organizations=None, @@ -128,7 +126,7 @@ def test_list_users(self, admin_owners): def test_list_users_filter_admin(self, admin_owners): admin_owners.return_value = Owner.objects.filter(pk__in=[self.current_owner.pk]) - other_owner = OwnerFactory() + OwnerFactory() res = self.client.get(reverse("selfhosted-users-list"), {"is_admin": True}) assert res.status_code == 200 diff --git a/api/internal/urls.py b/api/internal/urls.py index 9f0e4087fa..312eb44523 100644 --- a/api/internal/urls.py +++ b/api/internal/urls.py @@ -10,7 +10,6 @@ from api.internal.feature.views import FeaturesView from api.internal.owner.views import ( AccountDetailsViewSet, - InvoiceViewSet, OwnerViewSet, UserViewSet, ) @@ -29,7 +28,6 @@ owner_artifacts_router = OptionalTrailingSlashRouter() owner_artifacts_router.register(r"users", UserViewSet, basename="users") -owner_artifacts_router.register(r"invoices", InvoiceViewSet, basename="invoices") account_details_router = RetrieveUpdateDestroyRouter() account_details_router.register( diff --git a/api/public/v1/tests/views/test_pull_viewset.py b/api/public/v1/tests/views/test_pull_viewset.py index 3a364a946b..7e6424ccd2 100644 --- a/api/public/v1/tests/views/test_pull_viewset.py +++ b/api/public/v1/tests/views/test_pull_viewset.py @@ -1,5 +1,5 @@ import json -from unittest.mock import call, patch +from unittest.mock import patch from rest_framework.test import APIClient, APITestCase diff --git a/api/public/v1/views.py b/api/public/v1/views.py index 5004de6df5..37a45ec68c 100644 --- a/api/public/v1/views.py +++ b/api/public/v1/views.py @@ -7,7 +7,7 @@ from api.shared.mixins import RepoPropertyMixin from 
codecov_auth.authentication.repo_auth import RepositoryLegacyTokenAuthentication -from core.models import Commit, Pull +from core.models import Commit from services.task import TaskService from .permissions import PullUpdatePermission diff --git a/api/public/v2/compare/views.py b/api/public/v2/compare/views.py index 518e64b3e2..fd57c6f151 100644 --- a/api/public/v2/compare/views.py +++ b/api/public/v2/compare/views.py @@ -1,5 +1,3 @@ -from inspect import Parameter - from distutils.util import strtobool from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, extend_schema @@ -15,7 +13,6 @@ ImpactedFilesComparisonSerializer, ImpactedFileSegmentsSerializer, ) -from core.models import Commit from services.components import ComponentComparison, commit_components from services.decorators import torngit_safe diff --git a/api/public/v2/component/views.py b/api/public/v2/component/views.py index dde63ecb5d..8a26c6f2a9 100644 --- a/api/public/v2/component/views.py +++ b/api/public/v2/component/views.py @@ -8,6 +8,7 @@ from api.shared.mixins import RepoPropertyMixin from api.shared.permissions import RepositoryArtifactPermissions from services.components import commit_components, component_filtered_report +from utils import round_decimals_down @extend_schema( @@ -45,7 +46,9 @@ def list(self, request, *args, **kwargs): component_report = component_filtered_report(report, [component]) coverage = None if component_report.totals.coverage is not None: - coverage = round(float(component_report.totals.coverage), 2) + coverage = round_decimals_down( + float(component_report.totals.coverage), 2 + ) components_with_coverage.append( { "component_id": component.component_id, diff --git a/api/public/v2/flag/serializers.py b/api/public/v2/flag/serializers.py index f3ecde3764..0c2b94c4d0 100644 --- a/api/public/v2/flag/serializers.py +++ b/api/public/v2/flag/serializers.py @@ -3,3 +3,4 @@ class FlagSerializer(serializers.Serializer): flag_name = 
serializers.CharField(label="flag name") + coverage = serializers.FloatField(label="flag coverage") diff --git a/api/public/v2/flag/views.py b/api/public/v2/flag/views.py index 6ca6a5f6b2..a4832248a2 100644 --- a/api/public/v2/flag/views.py +++ b/api/public/v2/flag/views.py @@ -1,5 +1,11 @@ +from typing import Any + +from django.db.models import QuerySet from drf_spectacular.utils import extend_schema from rest_framework import mixins, viewsets +from rest_framework.exceptions import NotFound +from rest_framework.request import Request +from rest_framework.response import Response from api.public.v2.flag.serializers import FlagSerializer from api.public.v2.schema import repo_parameters @@ -15,11 +21,24 @@ class FlagViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, RepoPropertyMi lookup_field = "flag_name" queryset = RepositoryFlag.objects.none() - def get_queryset(self): - return self.repo.flags.all() + def get_queryset(self) -> QuerySet: + results = [ + {"flag_name": f.flag_name, "coverage": None} for f in self.repo.flags.all() + ] + try: + report = self.get_commit().full_report + if not report: + return results + except NotFound: + return results + + for i, val in enumerate(results): + flag_report = report.filter(flags=[val["flag_name"]]) + results[i]["coverage"] = flag_report.totals.coverage or 0 + return results @extend_schema(summary="Flag list") - def list(self, request, *args, **kwargs): + def list(self, request: Request, *args: Any, **kwargs: Any) -> Response: """ Returns a paginated list of flags for the specified repository """ diff --git a/api/public/v2/owner/views.py b/api/public/v2/owner/views.py index 72ec0f9219..d6dc182f3b 100644 --- a/api/public/v2/owner/views.py +++ b/api/public/v2/owner/views.py @@ -27,10 +27,7 @@ def retrieve(self, request, *args, **kwargs): @extend_schema(parameters=owner_parameters, tags=["Users"]) -class UserViewSet( - UserViewSetMixin, - mixins.ListModelMixin, -): +class UserViewSet(UserViewSetMixin, 
mixins.ListModelMixin, mixins.RetrieveModelMixin): serializer_class = UserSerializer queryset = Owner.objects.none() @@ -41,6 +38,13 @@ def list(self, request, *args, **kwargs): """ return super().list(request, *args, **kwargs) + @extend_schema(summary="User detail") + def retrieve(self, request, *args, **kwargs): + """ + Returns a user for the specified owner_username or ownerid + """ + return super().retrieve(request, *args, **kwargs) + @extend_schema( parameters=[ diff --git a/api/public/v2/repo/serializers.py b/api/public/v2/repo/serializers.py index e8dbe34c5a..a99cd57c45 100644 --- a/api/public/v2/repo/serializers.py +++ b/api/public/v2/repo/serializers.py @@ -1,5 +1,3 @@ -from cProfile import label - from rest_framework import serializers from api.public.v2.owner.serializers import OwnerSerializer diff --git a/api/public/v2/report/views.py b/api/public/v2/report/views.py index 8f141ae12a..1045b2b9df 100644 --- a/api/public/v2/report/views.py +++ b/api/public/v2/report/views.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, extend_schema diff --git a/api/public/v2/test_results/__init__.py b/api/public/v2/test_results/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/public/v2/test_results/serializers.py b/api/public/v2/test_results/serializers.py new file mode 100644 index 0000000000..161368567c --- /dev/null +++ b/api/public/v2/test_results/serializers.py @@ -0,0 +1,40 @@ +from rest_framework import serializers + +from reports.models import TestInstance + + +class TestInstanceSerializer(serializers.ModelSerializer): + id = serializers.IntegerField(label="id") + name = serializers.CharField(source="test.name", read_only=True, label="test name") + test_id = serializers.CharField(label="test id") + failure_message = serializers.CharField(label="test name") + duration_seconds = 
serializers.FloatField(label="duration in seconds") + commitid = serializers.CharField(label="commit SHA") + outcome = serializers.CharField(label="outcome") + branch = serializers.CharField(label="branch name") + repoid = serializers.IntegerField(label="repo id") + failure_rate = serializers.FloatField( + source="test.failure_rate", read_only=True, label="failure rate" + ) + commits_where_fail = serializers.ListField( + source="test.commits_where_fail", + read_only=True, + label="commits where test failed", + ) + + class Meta: + model = TestInstance + read_only_fields = ( + "id", + "test_id", + "failure_message", + "duration_seconds", + "commitid", + "outcome", + "branch", + "repoid", + "failure_rate", + "name", + "commits_where_fail", + ) + fields = read_only_fields diff --git a/api/public/v2/test_results/views.py b/api/public/v2/test_results/views.py new file mode 100644 index 0000000000..825bd8f305 --- /dev/null +++ b/api/public/v2/test_results/views.py @@ -0,0 +1,102 @@ +import django_filters +from django_filters.rest_framework import DjangoFilterBackend +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiParameter, extend_schema +from rest_framework import mixins, viewsets + +from api.shared.mixins import RepoPropertyMixin +from api.shared.permissions import RepositoryArtifactPermissions +from reports.models import TestInstance + +from .serializers import TestInstanceSerializer + + +class TestResultsFilters(django_filters.FilterSet): + commit_id = django_filters.CharFilter(field_name="commitid") + outcome = django_filters.CharFilter(field_name="outcome") + duration_min = django_filters.NumberFilter( + field_name="duration_seconds", lookup_expr="gte" + ) + duration_max = django_filters.NumberFilter( + field_name="duration_seconds", lookup_expr="lte" + ) + branch = django_filters.CharFilter(field_name="branch") + + class Meta: + model = TestInstance + fields = ["commit_id", "outcome", "duration_min", "duration_max", 
"branch"] + + +@extend_schema( + parameters=[ + OpenApiParameter( + "commit_id", + OpenApiTypes.STR, + OpenApiParameter.QUERY, + description="Commit SHA for which to return test results", + ), + OpenApiParameter( + "outcome", + OpenApiTypes.STR, + OpenApiParameter.QUERY, + description="Status of the test (failure, skip, error, pass)", + ), + OpenApiParameter( + "duration_min", + OpenApiTypes.INT, + OpenApiParameter.QUERY, + description="Minimum duration of the test in seconds", + ), + OpenApiParameter( + "duration_max", + OpenApiTypes.INT, + OpenApiParameter.QUERY, + description="Maximum duration of the test in seconds", + ), + OpenApiParameter( + "branch", + OpenApiTypes.STR, + OpenApiParameter.QUERY, + description="Branch name for which to return test results", + ), + ], + tags=["Test Results"], + summary="Retrieve test results", +) +class TestResultsView( + viewsets.GenericViewSet, + mixins.ListModelMixin, + mixins.RetrieveModelMixin, + RepoPropertyMixin, +): + serializer_class = TestInstanceSerializer + permission_classes = [RepositoryArtifactPermissions] + filter_backends = [DjangoFilterBackend] + filterset_class = TestResultsFilters + + def get_queryset(self): + return TestInstance.objects.filter(repoid=self.repo.repoid) + + @extend_schema(summary="Test results list") + def list(self, request, *args, **kwargs): + """ + Returns a list of test results for the specified repository and commit + """ + return super().list(request, *args, **kwargs) + + @extend_schema( + summary="Test results detail", + parameters=[ + OpenApiParameter( + "id", + OpenApiTypes.INT, + OpenApiParameter.PATH, + description="Test instance ID", + ), + ], + ) + def retrieve(self, request, *args, **kwargs): + """ + Returns a single test result by ID + """ + return super().retrieve(request, *args, **kwargs) diff --git a/api/public/v2/tests/test_api_commit_viewset.py b/api/public/v2/tests/test_api_commit_viewset.py index 2836bd94c5..cf8dc1d199 100644 --- 
a/api/public/v2/tests/test_api_commit_viewset.py +++ b/api/public/v2/tests/test_api_commit_viewset.py @@ -91,7 +91,7 @@ def test_commit_list_not_authenticated(self, get_repo_permissions): author = OwnerFactory() repo = RepositoryFactory(author=author, private=False) - commit = CommitFactory(repository=repo) + CommitFactory(repository=repo) self.client.logout() response = self.client.get( @@ -146,7 +146,7 @@ def test_commit_list_authenticated(self, get_repo_permissions): "hits": 19, "misses": 5, "partials": 0, - "coverage": 79.17, + "coverage": 79.16, "branches": 0, "methods": 0, "sessions": 2, @@ -282,7 +282,7 @@ def test_commit_detail_authenticated( "hits": 19, "misses": 5, "partials": 0, - "coverage": 79.17, + "coverage": 79.16, "branches": 0, "methods": 0, "sessions": 2, diff --git a/api/public/v2/tests/test_api_compare_viewset.py b/api/public/v2/tests/test_api_compare_viewset.py index 8c80a53254..25c22664f7 100644 --- a/api/public/v2/tests/test_api_compare_viewset.py +++ b/api/public/v2/tests/test_api_compare_viewset.py @@ -223,13 +223,12 @@ def __init__(self): ] -@patch("services.comparison.Comparison.has_unmerged_base_commits", lambda self: True) @patch("services.comparison.Comparison.head_report", new_callable=PropertyMock) @patch("services.comparison.Comparison.base_report", new_callable=PropertyMock) @patch("services.repo_providers.RepoProviderService.get_adapter") class TestCompareViewSetRetrieve(APITestCase): """ - Tests for retrieving a comparison. Does not test data that will be depracated, + Tests for retrieving a comparison. Does not test data that will be deprecated, eg base and head report fields. Tests for commits etc will be added as the compare-api refactor progresses. 
""" @@ -412,7 +411,6 @@ def test_returns_200_and_expected_files_on_success( assert response.status_code == status.HTTP_200_OK assert response.data["files"] == self.expected_files - assert response.data["has_unmerged_base_commits"] is True def test_returns_404_if_base_or_head_references_not_found( self, adapter_mock, base_report_mock, head_report_mock @@ -535,7 +533,7 @@ def test_pullid_with_nonexistent_head_returns_404( assert response.status_code == status.HTTP_404_NOT_FOUND - def test_file_returns_comparefile_with_diff_and_src_data( + def test_file_returns_compare_file_with_diff_and_src_data( self, adapter_mock, base_report_mock, head_report_mock ): base_report_mock.return_value = self.base_report @@ -597,7 +595,7 @@ def test_missing_base_report_returns_none_base_totals( response = self._get_comparison() assert response.status_code == status.HTTP_200_OK - assert response.data["totals"]["base"] == None + assert response.data["totals"]["base"] is None def test_no_raw_reports_returns_404( self, adapter_mock, base_report_mock, head_report_mock @@ -669,49 +667,6 @@ def test_pull_request_pseudo_comparison_can_update_base_report( assert response.status_code == status.HTTP_200_OK assert response.data["files"] == self.expected_files - @patch("redis.Redis.get", lambda self, key: None) - @patch("redis.Redis.set", lambda self, key, val, ex: None) - @patch( - "services.comparison.PullRequestComparison.pseudo_diff_adjusts_tracked_lines", - new_callable=PropertyMock, - ) - @patch( - "services.comparison.PullRequestComparison.allow_coverage_offsets", - new_callable=PropertyMock, - ) - @patch( - "services.comparison.PullRequestComparison.update_base_report_with_pseudo_diff" - ) - def test_pull_request_pseudo_comparison_returns_error_if_coverage_offsets_not_allowed( - self, - update_base_report_mock, - allow_coverage_offsets_mock, - pseudo_diff_adjusts_tracked_lines_mock, - adapter_mock, - base_report_mock, - head_report_mock, - ): - adapter_mock.return_value = 
self.mocked_compare_adapter - base_report_mock.return_value = self.base_report - head_report_mock.return_value = self.head_report - - pseudo_diff_adjusts_tracked_lines_mock.return_value = True - allow_coverage_offsets_mock.return_value = False - - response = self._get_comparison( - query_params={ - "pullid": PullFactory( - base=self.base.commitid, - head=self.head.commitid, - compared_to=self.base.commitid, - pullid=2, - repository=self.repo, - ).pullid - } - ) - - assert response.status_code == status.HTTP_400_BAD_REQUEST - def test_flags_comparison(self, adapter_mock, base_report_mock, head_report_mock): adapter_mock.return_value = self.mocked_compare_adapter base_report_mock.return_value = self.base_report @@ -787,7 +742,7 @@ def test_components_comparison( "hits": 5, "misses": 4, "partials": 0, - "coverage": 55.56, + "coverage": 55.55, "branches": 0, "methods": 0, "messages": 0, @@ -839,7 +794,7 @@ def test_components_comparison( "hits": 2, "misses": 0, "partials": 1, - "coverage": 66.67, + "coverage": 66.66, "branches": 1, "methods": 0, "messages": 0, diff --git a/api/public/v2/tests/test_api_owner_viewset.py b/api/public/v2/tests/test_api_owner_viewset.py index 71a8ee445b..a46b8252a6 100644 --- a/api/public/v2/tests/test_api_owner_viewset.py +++ b/api/public/v2/tests/test_api_owner_viewset.py @@ -1,4 +1,5 @@ from rest_framework import status +from rest_framework.exceptions import ErrorDetail from rest_framework.reverse import reverse from rest_framework.test import APITestCase @@ -37,7 +38,11 @@ def test_retrieve_returns_owner_with_period_username(self): def test_retrieve_returns_404_if_no_matching_username(self): response = self._retrieve(kwargs={"service": "github", "owner_username": "fff"}) assert response.status_code == status.HTTP_404_NOT_FOUND - assert response.data == {"detail": "Not found."} + assert response.data == { + "detail": ErrorDetail( + string="No Owner matches the given query.", code="not_found" + ) + } def 
test_retrieve_owner_unknown_service_returns_404(self): response = self._retrieve( @@ -51,6 +56,9 @@ class UserViewSetTests(APITestCase): def _list(self, kwargs): return self.client.get(reverse("api-v2-users-list", kwargs=kwargs)) + def _detail(self, kwargs): + return self.client.get(reverse("api-v2-users-detail", kwargs=kwargs)) + def setUp(self): self.org = OwnerFactory(service="github") self.current_owner = OwnerFactory(service="github", organizations=[self.org.pk]) @@ -78,3 +86,91 @@ def test_list(self): ], "total_pages": 1, } + + def test_retrieve_by_username(self): + another_user = OwnerFactory(service="github", organizations=[self.org.pk]) + response = self._detail( + kwargs={ + "service": self.org.service, + "owner_username": self.org.username, + "user_username_or_ownerid": another_user.username, + } + ) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + "service": "github", + "username": another_user.username, + "name": another_user.name, + "activated": False, + "is_admin": False, + "email": another_user.email, + } + + def test_retrieve_by_ownerid(self): + another_user = OwnerFactory(service="github", organizations=[self.org.pk]) + response = self._detail( + kwargs={ + "service": self.org.service, + "owner_username": self.org.username, + "user_username_or_ownerid": another_user.ownerid, + } + ) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + "service": "github", + "username": another_user.username, + "name": another_user.name, + "activated": False, + "is_admin": False, + "email": another_user.email, + } + + def test_retrieve_cannot_get_details_of_members_of_other_orgs(self): + another_org = OwnerFactory(service="github") + another_user = OwnerFactory(service="github", organizations=[another_org.pk]) + kwargs = { + "service": self.org.service, + "owner_username": self.org.username, + "user_username_or_ownerid": another_user.username, + } + response = self._detail(kwargs=kwargs) + assert 
response.status_code == status.HTTP_404_NOT_FOUND + + another_user.organizations.append(self.org.pk) + another_user.save() + + response = self._detail(kwargs=kwargs) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + "service": "github", + "username": another_user.username, + "name": another_user.name, + "activated": False, + "is_admin": False, + "email": another_user.email, + } + + def test_retrieve_cannot_get_details_if_not_member_of_org(self): + another_org = OwnerFactory(service="github") + another_user = OwnerFactory(service="github", organizations=[another_org.pk]) + kwargs = { + "service": another_org.service, + "owner_username": another_org.username, + "user_username_or_ownerid": another_user.username, + } + response = self._detail(kwargs=kwargs) + assert response.status_code == status.HTTP_404_NOT_FOUND + + self.current_owner.organizations.append(another_org.pk) + self.current_owner.save() + + response = self._detail(kwargs=kwargs) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + "service": "github", + "username": another_user.username, + "name": another_user.name, + "activated": False, + "is_admin": False, + "email": another_user.email, + } diff --git a/api/public/v2/tests/test_file_report_viewset.py b/api/public/v2/tests/test_file_report_viewset.py index 935252a18d..120f347538 100644 --- a/api/public/v2/tests/test_file_report_viewset.py +++ b/api/public/v2/tests/test_file_report_viewset.py @@ -1,18 +1,15 @@ -import os from unittest.mock import call, patch from urllib.parse import urlencode from django.conf import settings -from django.test import TestCase, override_settings +from django.test import TestCase from rest_framework.reverse import reverse from shared.reports.resources import Report, ReportFile, ReportLine from shared.utils.sessions import Session -from codecov_auth.models import UserToken -from codecov_auth.tests.factories import OwnerFactory, UserTokenFactory +from 
codecov_auth.tests.factories import OwnerFactory from core.models import Branch -from core.tests.factories import BranchFactory, CommitFactory, RepositoryFactory -from services.components import Component +from core.tests.factories import CommitFactory, RepositoryFactory from utils.test_utils import APIClient diff --git a/api/public/v2/tests/test_flag_viewset.py b/api/public/v2/tests/test_flag_viewset.py index e48972c24b..fb1166eaa0 100644 --- a/api/public/v2/tests/test_flag_viewset.py +++ b/api/public/v2/tests/test_flag_viewset.py @@ -2,13 +2,41 @@ from django.test import TestCase from rest_framework.reverse import reverse +from shared.reports.resources import Report, ReportFile, ReportLine +from shared.utils.sessions import Session from codecov_auth.tests.factories import OwnerFactory -from core.tests.factories import RepositoryFactory +from core.tests.factories import CommitFactory, RepositoryFactory from reports.tests.factories import RepositoryFlagFactory from utils.test_utils import APIClient +def flags_report(): + report = Report() + session_a_id, _ = report.add_session(Session(flags=["foo"])) + session_b_id, _ = report.add_session(Session(flags=["bar"])) + + file_a = ReportFile("foo/file1.py") + file_a.append(1, ReportLine.create(coverage=1, sessions=[[session_a_id, 1]])) + file_a.append(2, ReportLine.create(coverage=0, sessions=[[session_a_id, 0]])) + file_a.append(3, ReportLine.create(coverage=1, sessions=[[session_a_id, 1]])) + file_a.append(5, ReportLine.create(coverage=1, sessions=[[session_a_id, 1]])) + file_a.append(6, ReportLine.create(coverage=0, sessions=[[session_a_id, 0]])) + file_a.append(8, ReportLine.create(coverage=1, sessions=[[session_a_id, 1]])) + file_a.append(9, ReportLine.create(coverage=1, sessions=[[session_a_id, 1]])) + file_a.append(10, ReportLine.create(coverage=0, sessions=[[session_a_id, 0]])) + report.append(file_a) + + file_b = ReportFile("bar/file2.py") + file_b.append(12, ReportLine.create(coverage=1, 
sessions=[[session_b_id, 1]])) + file_b.append( + 51, ReportLine.create(coverage="1/2", type="b", sessions=[[session_b_id, 2]]) + ) + report.append(file_b) + + return report + + @patch("api.shared.repo.repository_accessors.RepoAccessors.get_repo_permissions") class FlagViewSetTestCase(TestCase): def setUp(self): @@ -38,8 +66,54 @@ def _request_flags(self): ) return self.client.get(url) - def test_flag_list(self, get_repo_permissions): + def test_flag_list_no_commit(self, get_repo_permissions): + get_repo_permissions.return_value = (True, True) + + res = self._request_flags() + assert res.status_code == 200 + assert res.json() == { + "count": 2, + "next": None, + "previous": None, + "results": [ + {"flag_name": "foo", "coverage": None}, + {"flag_name": "bar", "coverage": None}, + ], + "total_pages": 1, + } + + def test_flag_list_no_report(self, get_repo_permissions): + CommitFactory( + author=self.org, + repository=self.repo, + branch=self.repo.branch, + ) + get_repo_permissions.return_value = (True, True) + + res = self._request_flags() + assert res.status_code == 200 + assert res.json() == { + "count": 2, + "next": None, + "previous": None, + "results": [ + {"flag_name": "foo", "coverage": None}, + {"flag_name": "bar", "coverage": None}, + ], + "total_pages": 1, + } + + @patch("shared.reports.api_report_service.build_report_from_commit") + def test_flag_list_with_coverage( + self, build_report_from_commit, get_repo_permissions + ): + CommitFactory( + author=self.org, + repository=self.repo, + branch=self.repo.branch, + ) get_repo_permissions.return_value = (True, True) + build_report_from_commit.return_value = flags_report() res = self._request_flags() assert res.status_code == 200 @@ -47,6 +121,9 @@ def test_flag_list(self, get_repo_permissions): "count": 2, "next": None, "previous": None, - "results": [{"flag_name": "foo"}, {"flag_name": "bar"}], + "results": [ + {"flag_name": "foo", "coverage": 62.5}, + {"flag_name": "bar", "coverage": 100}, + ], "total_pages": 
1, } diff --git a/api/public/v2/tests/test_report_viewset.py b/api/public/v2/tests/test_report_viewset.py index 05191630cd..d59454992a 100644 --- a/api/public/v2/tests/test_report_viewset.py +++ b/api/public/v2/tests/test_report_viewset.py @@ -1,4 +1,3 @@ -import os from unittest.mock import call, patch from urllib.parse import urlencode @@ -8,7 +7,6 @@ from shared.reports.resources import Report, ReportFile, ReportLine from shared.utils.sessions import Session -from codecov_auth.models import UserToken from codecov_auth.tests.factories import OwnerFactory, UserTokenFactory from core.tests.factories import BranchFactory, CommitFactory, RepositoryFactory from services.components import Component @@ -110,7 +108,7 @@ def setUp(self): self.commit3 = CommitFactory( author=self.org, repository=self.repo, - branch=self.branch, + branch=self.branch.name, ) self.branch.head = self.commit3.commitid self.branch.save() diff --git a/api/public/v2/tests/test_test_results_view.py b/api/public/v2/tests/test_test_results_view.py new file mode 100644 index 0000000000..f8a9339c93 --- /dev/null +++ b/api/public/v2/tests/test_test_results_view.py @@ -0,0 +1,142 @@ +from unittest.mock import patch + +from django.urls import reverse +from freezegun import freeze_time +from rest_framework import status + +from codecov.tests.base_test import InternalAPITest +from codecov_auth.tests.factories import OwnerFactory +from core.tests.factories import RepositoryFactory +from reports.tests.factories import TestInstanceFactory +from utils.test_utils import APIClient + + +@freeze_time("2022-01-01T00:00:00") +class TestResultsViewsetTests(InternalAPITest): + def setUp(self): + self.org = OwnerFactory() + self.repo = RepositoryFactory(author=self.org) + self.current_owner = OwnerFactory( + permission=[self.repo.repoid], organizations=[self.org.ownerid] + ) + self.test_instances = [ + TestInstanceFactory(repoid=self.repo.repoid, commitid="1234"), + TestInstanceFactory(repoid=self.repo.repoid, 
commitid="3456"), + ] + + self.client = APIClient() + self.client.force_login_owner(self.current_owner) + + def test_list(self): + url = reverse( + "api-v2-tests-results-list", + kwargs={ + "service": self.org.service, + "owner_username": self.org.username, + "repo_name": self.repo.name, + }, + ) + res = self.client.get(url) + assert res.status_code == status.HTTP_200_OK + assert res.json() == { + "count": 2, + "next": None, + "previous": None, + "results": [ + { + "id": self.test_instances[0].id, + "name": self.test_instances[0].test.name, + "test_id": self.test_instances[0].test_id, + "failure_message": self.test_instances[0].failure_message, + "duration_seconds": self.test_instances[0].duration_seconds, + "commitid": self.test_instances[0].commitid, + "outcome": self.test_instances[0].outcome, + "branch": self.test_instances[0].branch, + "repoid": self.test_instances[0].repoid, + "failure_rate": self.test_instances[0].test.failure_rate, + "commits_where_fail": self.test_instances[ + 0 + ].test.commits_where_fail, + }, + { + "id": self.test_instances[1].id, + "name": self.test_instances[1].test.name, + "test_id": self.test_instances[1].test_id, + "failure_message": self.test_instances[1].failure_message, + "duration_seconds": self.test_instances[1].duration_seconds, + "commitid": self.test_instances[1].commitid, + "outcome": self.test_instances[1].outcome, + "branch": self.test_instances[1].branch, + "repoid": self.test_instances[1].repoid, + "failure_rate": self.test_instances[1].test.failure_rate, + "commits_where_fail": self.test_instances[ + 1 + ].test.commits_where_fail, + }, + ], + "total_pages": 1, + } + + def test_list_filters(self): + url = reverse( + "api-v2-tests-results-list", + kwargs={ + "service": self.org.service, + "owner_username": self.org.username, + "repo_name": self.repo.name, + }, + ) + res = self.client.get(f"{url}?commit_id={self.test_instances[0].commitid}") + assert res.status_code == status.HTTP_200_OK + assert res.json() == { + 
"count": 1, + "next": None, + "previous": None, + "results": [ + { + "id": self.test_instances[0].id, + "name": self.test_instances[0].test.name, + "test_id": self.test_instances[0].test_id, + "failure_message": self.test_instances[0].failure_message, + "duration_seconds": self.test_instances[0].duration_seconds, + "commitid": self.test_instances[0].commitid, + "outcome": self.test_instances[0].outcome, + "branch": self.test_instances[0].branch, + "repoid": self.test_instances[0].repoid, + "failure_rate": self.test_instances[0].test.failure_rate, + "commits_where_fail": self.test_instances[ + 0 + ].test.commits_where_fail, + }, + ], + "total_pages": 1, + } + + @patch("api.shared.repo.repository_accessors.RepoAccessors.get_repo_permissions") + def test_retrieve(self, get_repo_permissions): + get_repo_permissions.return_value = (True, True) + res = self.client.get( + reverse( + "api-v2-tests-results-detail", + kwargs={ + "service": self.org.service, + "owner_username": self.org.username, + "repo_name": self.repo.name, + "pk": self.test_instances[0].pk, + }, + ) + ) + assert res.status_code == status.HTTP_200_OK + assert res.json() == { + "id": self.test_instances[0].id, + "name": self.test_instances[0].test.name, + "test_id": self.test_instances[0].test_id, + "failure_message": self.test_instances[0].failure_message, + "duration_seconds": self.test_instances[0].duration_seconds, + "commitid": self.test_instances[0].commitid, + "outcome": self.test_instances[0].outcome, + "branch": self.test_instances[0].branch, + "repoid": self.test_instances[0].repoid, + "failure_rate": self.test_instances[0].test.failure_rate, + "commits_where_fail": self.test_instances[0].test.commits_where_fail, + } diff --git a/api/public/v2/tests/test_totals_viewset.py b/api/public/v2/tests/test_totals_viewset.py index e0ac490e0a..e2d5f7de40 100644 --- a/api/public/v2/tests/test_totals_viewset.py +++ b/api/public/v2/tests/test_totals_viewset.py @@ -1,4 +1,3 @@ -import os from unittest.mock 
import call, patch from urllib.parse import urlencode @@ -94,7 +93,7 @@ def setUp(self): self.commit3 = CommitFactory( author=self.org, repository=self.repo, - branch=self.branch, + branch=self.branch.name, ) self.branch.head = self.commit3.commitid self.branch.save() diff --git a/api/public/v2/urls.py b/api/public/v2/urls.py index 1afdb741c2..6e98edf542 100644 --- a/api/public/v2/urls.py +++ b/api/public/v2/urls.py @@ -16,6 +16,7 @@ from .pull.views import PullViewSet from .repo.views import RepositoryConfigView, RepositoryViewSet from .report.views import FileReportViewSet, ReportViewSet, TotalsViewSet +from .test_results.views import TestResultsView urls.handler404 = not_found urls.handler500 = server_error @@ -41,6 +42,9 @@ repository_artifacts_router.register( r"components", ComponentViewSet, basename="api-v2-components" ) +repository_artifacts_router.register( + r"test-results", TestResultsView, basename="api-v2-tests-results" +) compare_router = RetrieveUpdateDestroyRouter() compare_router.register(r"compare", CompareViewSet, basename="api-v2-compare") @@ -67,7 +71,7 @@ service_prefix = "/" owner_prefix = "//" repo_prefix = "//repos//" -flag_prefix = repo_prefix + "flags//" +flag_prefix = repo_prefix + "flags//" commit_prefix = repo_prefix + "commits//" urlpatterns = [ diff --git a/api/shared/commit/serializers.py b/api/shared/commit/serializers.py index 37be1881df..0eec4c533b 100644 --- a/api/shared/commit/serializers.py +++ b/api/shared/commit/serializers.py @@ -2,6 +2,8 @@ from shared.reports.resources import Report, ReportFile from shared.utils.merge import line_type +from utils import round_decimals_down + class BaseTotalsSerializer(serializers.Serializer): files = serializers.IntegerField() @@ -15,7 +17,7 @@ class BaseTotalsSerializer(serializers.Serializer): def get_coverage(self, totals) -> float: if totals.coverage is not None: - return round(float(totals.coverage), 2) + return round_decimals_down(float(totals.coverage), 2) return 0 @@ -40,11 +42,11 
@@ def get_coverage(self, totals) -> float: if totals.get("c") is None: return None else: - return round(float(totals["c"]), 2) + return round_decimals_down(float(totals["c"]), 2) def get_complexity_ratio(self, totals) -> float: return ( - round((totals["C"] / totals["N"]) * 100, 2) + round_decimals_down((totals["C"] / totals["N"]) * 100, 2) if totals["C"] and totals["N"] else 0 ) @@ -65,7 +67,7 @@ class ReportTotalsSerializer(BaseTotalsSerializer): def get_complexity_ratio(self, totals) -> float: return ( - round((totals.complexity / totals.complexity_total) * 100, 2) + round_decimals_down((totals.complexity / totals.complexity_total) * 100, 2) if totals.complexity and totals.complexity_total else 0 ) diff --git a/api/shared/compare/mixins.py b/api/shared/compare/mixins.py index 87cf2f974d..a8e288fc2c 100644 --- a/api/shared/compare/mixins.py +++ b/api/shared/compare/mixins.py @@ -1,9 +1,8 @@ import logging -from typing import Optional -from rest_framework import mixins, viewsets +from rest_framework import viewsets from rest_framework.decorators import action -from rest_framework.exceptions import NotFound, PermissionDenied +from rest_framework.exceptions import NotFound from rest_framework.response import Response from api.shared.mixins import CompareSlugMixin @@ -12,7 +11,6 @@ from services.comparison import ( CommitComparisonService, Comparison, - ComparisonReport, MissingComparisonCommit, MissingComparisonReport, PullRequestComparison, @@ -89,23 +87,10 @@ def retrieve(self, request, *args, **kwargs): comparison = self.get_object() # Some checks here for pseudo-comparisons. Basically, when pseudo-comparing, - # we sometimes might need to tweak the base report if the user allows us to - # in their yaml, or raise an error if not. 
+ # we sometimes might need to tweak the base report if isinstance(comparison, PullRequestComparison): - if ( - comparison.pseudo_diff_adjusts_tracked_lines - and comparison.allow_coverage_offsets - ): + if comparison.pseudo_diff_adjusts_tracked_lines: comparison.update_base_report_with_pseudo_diff() - elif comparison.pseudo_diff_adjusts_tracked_lines: - return Response( - data={ - "detail": "Changes found in between %.7s...%.7s (pseudo...base) " - "which prevent comparing this pull request." - % (comparison.pull.compared_to, comparison.pull.base) - }, - status=400, - ) serializer = self.get_serializer(comparison) try: diff --git a/api/shared/compare/serializers.py b/api/shared/compare/serializers.py index f874bf2022..162fd32387 100644 --- a/api/shared/compare/serializers.py +++ b/api/shared/compare/serializers.py @@ -1,5 +1,4 @@ import dataclasses -import hashlib import logging from typing import List @@ -7,7 +6,6 @@ from api.internal.commit.serializers import CommitSerializer from api.shared.commit.serializers import ReportTotalsSerializer -from compare.models import CommitComparison from services.comparison import ( Comparison, ComparisonReport, @@ -52,7 +50,6 @@ class ComparisonSerializer(serializers.Serializer): diff = serializers.SerializerMethodField() files = serializers.SerializerMethodField() untracked = serializers.SerializerMethodField() - has_unmerged_base_commits = serializers.BooleanField() def get_untracked(self, comparison) -> List[str]: return [ diff --git a/api/shared/mixins.py b/api/shared/mixins.py index 1f8696b6e4..43af0d7881 100644 --- a/api/shared/mixins.py +++ b/api/shared/mixins.py @@ -1,7 +1,6 @@ from typing import Optional from django.conf import settings -from django.db import connection from django.http import Http404 from django.shortcuts import get_object_or_404 from django.utils.functional import cached_property diff --git a/api/shared/pull/mixins.py b/api/shared/pull/mixins.py index 6d013475e1..5f9cd29171 100644 --- 
a/api/shared/pull/mixins.py +++ b/api/shared/pull/mixins.py @@ -4,7 +4,6 @@ from rest_framework import filters, viewsets from api.shared.mixins import RepoPropertyMixin -from api.shared.pagination import PaginationMixin from api.shared.permissions import RepositoryArtifactPermissions from core.models import Commit diff --git a/api/shared/repo/repository_accessors.py b/api/shared/repo/repository_accessors.py index 4734bbd986..1143710263 100644 --- a/api/shared/repo/repository_accessors.py +++ b/api/shared/repo/repository_accessors.py @@ -1,15 +1,11 @@ -import asyncio import logging from asgiref.sync import async_to_sync from django.core.exceptions import ObjectDoesNotExist from django.utils import timezone -from rest_framework.exceptions import APIException, PermissionDenied -from shared.torngit.exceptions import TorngitClientError from codecov_auth.models import Owner from core.models import Repository -from services.decorators import torngit_safe from services.repo_providers import RepoProviderService log = logging.getLogger(__name__) diff --git a/api/shared/serializers.py b/api/shared/serializers.py index 4cb79fa525..4e3b0a1b37 100644 --- a/api/shared/serializers.py +++ b/api/shared/serializers.py @@ -3,7 +3,6 @@ from rest_framework.exceptions import NotFound from core.models import Branch, Commit, Pull -from utils.config import get_config class StringListField(serializers.ListField): diff --git a/billing/admin.py b/billing/admin.py deleted file mode 100644 index 8c38f3f3da..0000000000 --- a/billing/admin.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.contrib import admin - -# Register your models here. 
diff --git a/billing/migrations/0003_delete_account.py b/billing/migrations/0003_delete_account.py new file mode 100644 index 0000000000..ce40e1bbbf --- /dev/null +++ b/billing/migrations/0003_delete_account.py @@ -0,0 +1,15 @@ +# Generated by Django 4.2.11 on 2024-07-24 00:23 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("billing", "0002_auto_20220118_1232"), + ] + + operations = [ + migrations.DeleteModel( + name="Account", + ), + ] diff --git a/billing/models.py b/billing/models.py deleted file mode 100644 index ca641d7d73..0000000000 --- a/billing/models.py +++ /dev/null @@ -1,19 +0,0 @@ -from django.db import models -from django_prometheus.models import ExportModelOperationsMixin - -# Create your models here. -from codecov.models import BaseCodecovModel -from plan.constants import PlanName - - -class PlanProviders(models.TextChoices): - GITHUB = "github" - - -class Account(ExportModelOperationsMixin("billing.account"), BaseCodecovModel): - stripe_customer_id = models.TextField(null=True) - stripe_subscription_id = models.TextField(null=True) - plan = models.TextField(default=PlanName.FREE_PLAN_NAME.value) - plan_provider = models.TextField(null=True, choices=PlanProviders.choices) - max_activated_user_count = models.SmallIntegerField(default=5) - should_auto_activate_users = models.BooleanField(default=True) diff --git a/billing/tests/test_views.py b/billing/tests/test_views.py index ffd6257341..46bdcf175f 100644 --- a/billing/tests/test_views.py +++ b/billing/tests/test_views.py @@ -1,19 +1,16 @@ import time from unittest.mock import patch -import pytest import stripe from django.conf import settings from freezegun import freeze_time -from pytest import raises from rest_framework import status from rest_framework.reverse import reverse from rest_framework.test import APIRequestFactory, APITestCase -from codecov_auth.models import Owner from codecov_auth.tests.factories import OwnerFactory from 
core.tests.factories import RepositoryFactory -from plan.constants import PlanName, TrialDaysAmount +from plan.constants import PlanName from ..constants import StripeHTTPHeaders @@ -83,7 +80,7 @@ def test_invalid_event_signature(self): assert response.status_code == status.HTTP_400_BAD_REQUEST def test_invoice_payment_succeeded_sets_owner_delinquent_false(self): - self.owner.deliquent = True + self.owner.delinquent = True self.owner.save() response = self._send_event( @@ -127,7 +124,7 @@ def test_customer_subscription_deleted_sets_plan_to_free(self): self.owner.plan_user_count = 20 self.owner.save() - response = self._send_event( + self._send_event( payload={ "type": "customer.subscription.deleted", "data": { @@ -143,8 +140,8 @@ def test_customer_subscription_deleted_sets_plan_to_free(self): assert self.owner.plan == PlanName.BASIC_PLAN_NAME.value assert self.owner.plan_user_count == 1 - assert self.owner.plan_activated_users == None - assert self.owner.stripe_subscription_id == None + assert self.owner.plan_activated_users is None + assert self.owner.stripe_subscription_id is None def test_customer_subscription_deleted_deactivates_all_repos(self): RepositoryFactory(author=self.owner, activated=True, active=True) @@ -155,7 +152,7 @@ def test_customer_subscription_deleted_deactivates_all_repos(self): self.owner.repository_set.filter(activated=True, active=True).count() == 3 ) - response = self._send_event( + self._send_event( payload={ "type": "customer.subscription.deleted", "data": { @@ -196,7 +193,7 @@ def test_customer_subscription_deleted_no_customer(self, log_info_mock): ) def test_customer_created_logs_and_doesnt_crash(self): - response = self._send_event( + self._send_event( payload={ "type": "customer.created", "data": {"object": {"id": "FOEKDCDEQ", "email": "test@email.com"}}, @@ -224,8 +221,8 @@ def test_customer_subscription_created_does_nothing_if_no_plan_id(self): ) self.owner.refresh_from_db() - assert self.owner.stripe_subscription_id == None - 
assert self.owner.stripe_customer_id == None + assert self.owner.stripe_subscription_id is None + assert self.owner.stripe_customer_id is None def test_customer_subscription_created_does_nothing_if_plan_not_paid_user_plan( self, @@ -234,7 +231,7 @@ def test_customer_subscription_created_does_nothing_if_plan_not_paid_user_plan( self.owner.stripe_customer_id = None self.owner.save() - response = self._send_event( + self._send_event( payload={ "type": "customer.subscription.created", "data": { @@ -250,8 +247,8 @@ def test_customer_subscription_created_does_nothing_if_plan_not_paid_user_plan( ) self.owner.refresh_from_db() - assert self.owner.stripe_subscription_id == None - assert self.owner.stripe_customer_id == None + assert self.owner.stripe_subscription_id is None + assert self.owner.stripe_customer_id is None def test_customer_subscription_created_sets_plan_info(self): self.owner.stripe_subscription_id = None @@ -288,16 +285,16 @@ def test_customer_subscription_created_sets_plan_info(self): @freeze_time("2023-06-19") @patch("plan.service.PlanService.expire_trial_when_upgrading") - @patch("services.billing.StripeService.update_payment_method") + @patch("services.billing.stripe.PaymentMethod.attach") + @patch("services.billing.stripe.Customer.modify") def test_customer_subscription_created_can_trigger_trial_expiration( - self, _, expire_trial_when_upgrading_mock + self, c_mock, pm_mock, expire_trial_when_upgrading_mock ): stripe_subscription_id = "FOEKDCDEQ" stripe_customer_id = "sdo050493" - plan_name = "users-pr-inappy" quantity = 20 - response = self._send_event( + self._send_event( payload={ "type": "customer.subscription.created", "data": { @@ -315,9 +312,10 @@ def test_customer_subscription_created_can_trigger_trial_expiration( expire_trial_when_upgrading_mock.assert_called_once() - @patch("services.billing.StripeService.update_payment_method") + @patch("services.billing.stripe.PaymentMethod.attach") + @patch("services.billing.stripe.Customer.modify") def 
test_customer_subscription_updated_does_not_change_subscription_if_not_paid_user_plan( - self, upm_mock + self, c_mock, pm_mock ): self.owner.plan = PlanName.BASIC_PLAN_NAME.value self.owner.plan_user_count = 0 @@ -346,18 +344,25 @@ def test_customer_subscription_updated_does_not_change_subscription_if_not_paid_ assert self.owner.plan == PlanName.BASIC_PLAN_NAME.value assert self.owner.plan_user_count == 0 assert self.owner.plan_auto_activate == False - upm_mock.assert_called_once_with(self.owner, "pm_1LhiRsGlVGuVgOrkQguJXdeV") + pm_mock.assert_called_once_with( + "pm_1LhiRsGlVGuVgOrkQguJXdeV", customer=self.owner.stripe_customer_id + ) + c_mock.assert_called_once_with( + self.owner.stripe_customer_id, + invoice_settings={"default_payment_method": "pm_1LhiRsGlVGuVgOrkQguJXdeV"}, + ) - @patch("services.billing.StripeService.update_payment_method") + @patch("services.billing.stripe.PaymentMethod.attach") + @patch("services.billing.stripe.Customer.modify") def test_customer_subscription_updated_does_not_change_subscription_if_there_is_a_schedule( - self, upm_mock + self, c_mock, pm_mock ): self.owner.plan = "users-pr-inappy" self.owner.plan_user_count = 10 self.owner.plan_auto_activate = False self.owner.save() - response = self._send_event( + self._send_event( payload={ "type": "customer.subscription.updated", "data": { @@ -379,11 +384,18 @@ def test_customer_subscription_updated_does_not_change_subscription_if_there_is_ assert self.owner.plan == "users-pr-inappy" assert self.owner.plan_user_count == 10 assert self.owner.plan_auto_activate == False - upm_mock.assert_called_once_with(self.owner, "pm_1LhiRsGlVGuVgOrkQguJXdeV") + pm_mock.assert_called_once_with( + "pm_1LhiRsGlVGuVgOrkQguJXdeV", customer=self.owner.stripe_customer_id + ) + c_mock.assert_called_once_with( + self.owner.stripe_customer_id, + invoice_settings={"default_payment_method": "pm_1LhiRsGlVGuVgOrkQguJXdeV"}, + ) - @patch("services.billing.StripeService.update_payment_method") + 
@patch("services.billing.stripe.PaymentMethod.attach") + @patch("services.billing.stripe.Customer.modify") def test_customer_subscription_updated_sets_free_and_deactivates_all_repos_if_incomplete_expired( - self, upm_mock + self, c_mock, pm_mock ): self.owner.plan = "users-pr-inappy" self.owner.plan_user_count = 10 @@ -395,7 +407,7 @@ def test_customer_subscription_updated_sets_free_and_deactivates_all_repos_if_in RepositoryFactory(author=self.owner, activated=True, active=True) assert self.owner.repository_set.count() == 3 - response = self._send_event( + self._send_event( payload={ "type": "customer.subscription.updated", "data": { @@ -419,14 +431,23 @@ def test_customer_subscription_updated_sets_free_and_deactivates_all_repos_if_in assert self.owner.plan == PlanName.BASIC_PLAN_NAME.value assert self.owner.plan_user_count == 1 assert self.owner.plan_auto_activate == False - assert self.owner.stripe_subscription_id == None + assert self.owner.stripe_subscription_id is None assert ( self.owner.repository_set.filter(active=True, activated=True).count() == 0 ) - upm_mock.assert_called_once_with(self.owner, "pm_1LhiRsGlVGuVgOrkQguJXdeV") + pm_mock.assert_called_once_with( + "pm_1LhiRsGlVGuVgOrkQguJXdeV", customer=self.owner.stripe_customer_id + ) + c_mock.assert_called_once_with( + self.owner.stripe_customer_id, + invoice_settings={"default_payment_method": "pm_1LhiRsGlVGuVgOrkQguJXdeV"}, + ) - @patch("services.billing.StripeService.update_payment_method") - def test_customer_subscription_updated_sets_fields_on_success(self, upm_mock): + @patch("services.billing.stripe.PaymentMethod.attach") + @patch("services.billing.stripe.Customer.modify") + def test_customer_subscription_updated_sets_fields_on_success( + self, c_mock, pm_mock + ): self.owner.plan = "users-free" self.owner.plan_user_count = 5 self.owner.plan_auto_activate = False @@ -456,7 +477,13 @@ def test_customer_subscription_updated_sets_fields_on_success(self, upm_mock): assert self.owner.plan == plan_name 
assert self.owner.plan_user_count == quantity assert self.owner.plan_auto_activate == True - upm_mock.assert_called_once_with(self.owner, "pm_1LhiRsGlVGuVgOrkQguJXdeV") + pm_mock.assert_called_once_with( + "pm_1LhiRsGlVGuVgOrkQguJXdeV", customer=self.owner.stripe_customer_id + ) + c_mock.assert_called_once_with( + self.owner.stripe_customer_id, + invoice_settings={"default_payment_method": "pm_1LhiRsGlVGuVgOrkQguJXdeV"}, + ) @patch("services.billing.stripe.Subscription.retrieve") def test_subscription_schedule_released_updates_owner_with_existing_subscription( @@ -572,7 +599,7 @@ def test_checkout_session_completed_sets_stripe_customer_id(self): expected_id = "fhjtwoo40" - response = self._send_event( + self._send_event( payload={ "type": "checkout.session.completed", "data": { @@ -590,7 +617,7 @@ def test_checkout_session_completed_sets_stripe_customer_id(self): @patch("billing.views.stripe.Subscription.modify") def test_customer_update_but_not_payment_method(self, subscription_modify_mock): payment_method = "pm_123" - response = self._send_event( + self._send_event( payload={ "type": "customer.updated", "data": { @@ -610,7 +637,7 @@ def test_customer_update_but_not_payment_method(self, subscription_modify_mock): def test_customer_update_payment_method(self, subscription_modify_mock): payment_method = "pm_123" old_payment_method = "pm_321" - response = self._send_event( + self._send_event( payload={ "type": "customer.updated", "data": { diff --git a/billing/views.py b/billing/views.py index 5b50385c0f..38a97ddf50 100644 --- a/billing/views.py +++ b/billing/views.py @@ -9,7 +9,6 @@ from codecov_auth.models import Owner from plan.service import PlanService -from services.billing import BillingService from .constants import StripeHTTPHeaders, StripeWebhookEvents @@ -49,7 +48,7 @@ def invoice_payment_succeeded(self, invoice: stripe.Invoice) -> None: def invoice_payment_failed(self, invoice: stripe.Invoice) -> None: log.info( - "Invoice Payment Failed - Setting 
Deliquency status True", + "Invoice Payment Failed - Setting Delinquency status True", extra=dict( stripe_customer_id=invoice.customer, stripe_subscription_id=invoice.subscription, @@ -131,8 +130,8 @@ def subscription_schedule_updated( def subscription_schedule_released(self, schedule: stripe.Subscription) -> None: subscription = stripe.Subscription.retrieve(schedule["released_subscription"]) - owner = Owner.objects.get(ownerid=subscription.metadata["obo_organization"]) - requesting_user_id = subscription.metadata["obo"] + owner = Owner.objects.get(ownerid=subscription.metadata.get("obo_organization")) + requesting_user_id = subscription.metadata.get("obo") plan_service = PlanService(current_org=owner) sub_item_plan_id = subscription.plan.id @@ -166,7 +165,7 @@ def customer_subscription_created(self, subscription: stripe.Subscription) -> No "Subscription created, but missing plan_id", extra=dict( stripe_customer_id=subscription.customer, - ownerid=subscription.metadata["obo_organization"], + ownerid=subscription.metadata.get("obo_organization"), subscription_plan=subscription.plan, ), ) @@ -177,7 +176,7 @@ def customer_subscription_created(self, subscription: stripe.Subscription) -> No "Subscription creation requested for invalid plan", extra=dict( stripe_customer_id=subscription.customer, - ownerid=subscription.metadata["obo_organization"], + ownerid=subscription.metadata.get("obo_organization"), plan_id=sub_item_plan_id, ), ) @@ -190,12 +189,12 @@ def customer_subscription_created(self, subscription: stripe.Subscription) -> No extra=dict( stripe_customer_id=subscription.customer, stripe_subscription_id=subscription.id, - ownerid=subscription.metadata["obo_organization"], + ownerid=subscription.metadata.get("obo_organization"), plan=plan_name, quantity=subscription.quantity, ), ) - owner = Owner.objects.get(ownerid=subscription.metadata["obo_organization"]) + owner = Owner.objects.get(ownerid=subscription.metadata.get("obo_organization")) 
owner.stripe_subscription_id = subscription.id owner.stripe_customer_id = subscription.customer owner.save() @@ -227,9 +226,14 @@ def customer_subscription_updated(self, subscription: stripe.Subscription) -> No # This hook will be called after a checkout session completes, updating the subscription created # with it default_payment_method = subscription.default_payment_method - if default_payment_method: - billing = BillingService(requesting_user=owner) - billing.update_payment_method(owner, default_payment_method) + if default_payment_method and owner.stripe_customer_id is not None: + stripe.PaymentMethod.attach( + default_payment_method, customer=owner.stripe_customer_id + ) + stripe.Customer.modify( + owner.stripe_customer_id, + invoice_settings={"default_payment_method": default_payment_method}, + ) subscription_schedule_id = subscription.schedule plan_service = PlanService(current_org=owner) diff --git a/codecov.yml b/codecov.yml index f8d95eca50..029ddec033 100644 --- a/codecov.yml +++ b/codecov.yml @@ -14,3 +14,6 @@ flag_management: codecov: require_ci_to_pass: false + +test_analytics: + flake_detection: true diff --git a/codecov/commands/tests/test_base.py b/codecov/commands/tests/test_base.py index 86923c52c2..92e75858e2 100644 --- a/codecov/commands/tests/test_base.py +++ b/codecov/commands/tests/test_base.py @@ -10,11 +10,11 @@ def test_base_command(): command = BaseCommand(None, "github") # test command is properly init - assert command.current_owner == None + assert command.current_owner is None assert command.service == "github" # test get_interactor interactor = command.get_interactor(BaseInteractor) - assert interactor.current_owner == None + assert interactor.current_owner is None assert interactor.current_user == AnonymousUser() assert interactor.service == "github" # test get_command diff --git a/codecov/commands/tests/test_executor.py b/codecov/commands/tests/test_executor.py index 5289e8225b..6c54bfcbae 100644 --- 
a/codecov/commands/tests/test_executor.py +++ b/codecov/commands/tests/test_executor.py @@ -16,11 +16,11 @@ def test_get_executor_from_request(): request.user = AnonymousUser() executor = get_executor_from_request(request) assert executor.service == "github" - assert executor.current_owner == None + assert executor.current_owner is None def test_get_executor_from_command(): command = OwnerCommands(None, "github") executor = get_executor_from_command(command) assert executor.service == "github" - assert executor.current_owner == None + assert executor.current_owner is None diff --git a/codecov/settings_base.py b/codecov/settings_base.py index eaffdf634b..e43d7e1879 100644 --- a/codecov/settings_base.py +++ b/codecov/settings_base.py @@ -1,7 +1,5 @@ import os -from urllib.parse import urlparse -import django_prometheus import sentry_sdk from corsheaders.defaults import default_headers from sentry_sdk.integrations.celery import CeleryIntegration @@ -9,13 +7,13 @@ from sentry_sdk.integrations.httpx import HttpxIntegration from sentry_sdk.integrations.redis import RedisIntegration from sentry_sdk.scrubber import DEFAULT_DENYLIST, EventScrubber +from shared.django_apps.db_settings import * +from shared.license import startup_license_logging from utils.config import SettingsModule, get_config, get_settings_module SECRET_KEY = get_config("django", "secret_key", default="*") -AUTH_USER_MODEL = "codecov_auth.User" - # Application definition INSTALLED_APPS = [ @@ -52,6 +50,8 @@ # New Shared Models "shared.django_apps.rollouts", "shared.django_apps.user_measurements", + "shared.django_apps.codecov_metrics", + "shared.django_apps.bundle_analysis", ] MIDDLEWARE = [ @@ -95,171 +95,10 @@ # Database # https://docs.djangoproject.com/en/2.1/ref/settings/#databases -db_url = get_config("services", "database_url") -if db_url: - db_conf = urlparse(db_url) - DATABASE_USER = db_conf.username - DATABASE_NAME = db_conf.path.replace("/", "") - DATABASE_PASSWORD = db_conf.password - 
DATABASE_HOST = db_conf.hostname - DATABASE_PORT = db_conf.port -else: - DATABASE_USER = get_config("services", "database", "username", default="postgres") - DATABASE_NAME = get_config("services", "database", "name", default="postgres") - DATABASE_PASSWORD = get_config( - "services", "database", "password", default="postgres" - ) - DATABASE_HOST = get_config("services", "database", "host", default="postgres") - DATABASE_PORT = get_config("services", "database", "port", default=5432) - -DATABASE_READ_REPLICA_ENABLED = get_config( - "setup", "database", "read_replica_enabled", default=False -) - -db_read_url = get_config("services", "database_read_url") -if db_read_url: - db_conf = urlparse(db_read_url) - DATABASE_READ_USER = db_conf.username - DATABASE_READ_NAME = db_conf.path.replace("/", "") - DATABASE_READ_PASSWORD = db_conf.password - DATABASE_READ_HOST = db_conf.hostname - DATABASE_READ_PORT = db_conf.port -else: - DATABASE_READ_USER = get_config( - "services", "database_read", "username", default="postgres" - ) - DATABASE_READ_NAME = get_config( - "services", "database_read", "name", default="postgres" - ) - DATABASE_READ_PASSWORD = get_config( - "services", "database_read", "password", default="postgres" - ) - DATABASE_READ_HOST = get_config( - "services", "database_read", "host", default="postgres" - ) - DATABASE_READ_PORT = get_config("services", "database_read", "port", default=5432) - GRAPHQL_QUERY_COST_THRESHOLD = get_config( "setup", "graphql", "query_cost_threshold", default=10000 ) -TIMESERIES_ENABLED = get_config("setup", "timeseries", "enabled", default=False) -TIMESERIES_REAL_TIME_AGGREGATES = get_config( - "setup", "timeseries", "real_time_aggregates", default=False -) - -timeseries_database_url = get_config("services", "timeseries_database_url") -if timeseries_database_url: - timeseries_database_conf = urlparse(timeseries_database_url) - TIMESERIES_DATABASE_USER = timeseries_database_conf.username - TIMESERIES_DATABASE_NAME = 
timeseries_database_conf.path.replace("/", "") - TIMESERIES_DATABASE_PASSWORD = timeseries_database_conf.password - TIMESERIES_DATABASE_HOST = timeseries_database_conf.hostname - TIMESERIES_DATABASE_PORT = timeseries_database_conf.port -else: - TIMESERIES_DATABASE_USER = get_config( - "services", "timeseries_database", "username", default="postgres" - ) - TIMESERIES_DATABASE_NAME = get_config( - "services", "timeseries_database", "name", default="postgres" - ) - TIMESERIES_DATABASE_PASSWORD = get_config( - "services", "timeseries_database", "password", default="postgres" - ) - TIMESERIES_DATABASE_HOST = get_config( - "services", "timeseries_database", "host", default="timescale" - ) - TIMESERIES_DATABASE_PORT = get_config( - "services", "timeseries_database", "port", default=5432 - ) - -TIMESERIES_DATABASE_READ_REPLICA_ENABLED = get_config( - "setup", "timeseries", "read_replica_enabled", default=False -) - -timeseries_database_read_url = get_config("services", "timeseries_database_read_url") -if timeseries_database_read_url: - timeseries_database_conf = urlparse(timeseries_database_read_url) - TIMESERIES_DATABASE_READ_USER = timeseries_database_conf.username - TIMESERIES_DATABASE_READ_NAME = timeseries_database_conf.path.replace("/", "") - TIMESERIES_DATABASE_READ_PASSWORD = timeseries_database_conf.password - TIMESERIES_DATABASE_READ_HOST = timeseries_database_conf.hostname - TIMESERIES_DATABASE_READ_PORT = timeseries_database_conf.port -else: - TIMESERIES_DATABASE_READ_USER = get_config( - "services", "timeseries_database_read", "username", default="postgres" - ) - TIMESERIES_DATABASE_READ_NAME = get_config( - "services", "timeseries_database_read", "name", default="postgres" - ) - TIMESERIES_DATABASE_READ_PASSWORD = get_config( - "services", "timeseries_database_read", "password", default="postgres" - ) - TIMESERIES_DATABASE_READ_HOST = get_config( - "services", "timeseries_database_read", "host", default="timescale" - ) - TIMESERIES_DATABASE_READ_PORT = 
get_config( - "services", "timeseries_database_read", "port", default=5432 - ) - -# this is the time in seconds django decides to keep the connection open after the request -# the default is 0 seconds, meaning django closes the connection after every request -# https://docs.djangoproject.com/en/3.1/ref/settings/#conn-max-age -CONN_MAX_AGE = int(get_config("services", "database", "conn_max_age", default=0)) - -DATABASES = { - "default": { - "ENGINE": "psqlextra.backend", - "NAME": DATABASE_NAME, - "USER": DATABASE_USER, - "PASSWORD": DATABASE_PASSWORD, - "HOST": DATABASE_HOST, - "PORT": DATABASE_PORT, - "CONN_MAX_AGE": CONN_MAX_AGE, - } -} - -if DATABASE_READ_REPLICA_ENABLED: - DATABASES["default_read"] = { - "ENGINE": "psqlextra.backend", - "NAME": DATABASE_READ_NAME, - "USER": DATABASE_READ_USER, - "PASSWORD": DATABASE_READ_PASSWORD, - "HOST": DATABASE_READ_HOST, - "PORT": DATABASE_READ_PORT, - "CONN_MAX_AGE": CONN_MAX_AGE, - } - -if TIMESERIES_ENABLED: - DATABASES["timeseries"] = { - "ENGINE": "django_prometheus.db.backends.postgresql", - "NAME": TIMESERIES_DATABASE_NAME, - "USER": TIMESERIES_DATABASE_USER, - "PASSWORD": TIMESERIES_DATABASE_PASSWORD, - "HOST": TIMESERIES_DATABASE_HOST, - "PORT": TIMESERIES_DATABASE_PORT, - "CONN_MAX_AGE": CONN_MAX_AGE, - } - - if TIMESERIES_DATABASE_READ_REPLICA_ENABLED: - DATABASES["timeseries_read"] = { - "ENGINE": "django_prometheus.db.backends.postgresql", - "NAME": TIMESERIES_DATABASE_READ_NAME, - "USER": TIMESERIES_DATABASE_READ_USER, - "PASSWORD": TIMESERIES_DATABASE_READ_PASSWORD, - "HOST": TIMESERIES_DATABASE_READ_HOST, - "PORT": TIMESERIES_DATABASE_READ_PORT, - "CONN_MAX_AGE": CONN_MAX_AGE, - } - -# See https://django-postgres-extra.readthedocs.io/en/master/settings.html -POSTGRES_EXTRA_DB_BACKEND_BASE: "django_prometheus.db.backends.postgresql" # type: ignore - -# Allows to use the pgpartition command -PSQLEXTRA_PARTITIONING_MANAGER = ( - "shared.django_apps.user_measurements.partitioning.manager" -) - DATABASE_ROUTERS 
= ["codecov.db.DatabaseRouter"] # Password validation @@ -320,8 +159,11 @@ "https://cdn.jsdelivr.net/npm/graphql-playground-react/build/static/js/middleware.js", "https://cdn.jsdelivr.net/npm/graphql-playground-react/build/favicon.png", "https://cdn.jsdelivr.net/npm/graphql-playground-react/build/static/css/index.css", + "blob:", ] +CSP_WORKER_SRC = ["'self'", "blob:"] + # Internationalization # https://docs.djangoproject.com/en/2.1/topics/i18n/ @@ -472,10 +314,6 @@ GITLAB_ENTERPRISE_URL = get_config("gitlab_enterprise", "url") GITLAB_ENTERPRISE_API_URL = get_config("gitlab_enterprise", "api_url") -SEGMENT_API_KEY = get_config("setup", "segment", "key", default=None) -SEGMENT_ENABLED = get_config("setup", "segment", "enabled", default=False) and not bool( - get_config("setup", "enterprise_license", default=False) -) CORS_ALLOW_HEADERS = ( list(default_headers) @@ -501,7 +339,7 @@ CORS_ALLOWED_ORIGIN_REGEXES = get_config( "setup", "api_cors_allowed_origin_regexes", default=[] ) -CORS_ALLOWED_ORIGINS = [] +CORS_ALLOWED_ORIGINS: list[str] = [] GRAPHQL_PLAYGROUND = True @@ -558,11 +396,9 @@ set_default_pii=True, environment=SENTRY_ENV, traces_sample_rate=SENTRY_SAMPLE_RATE, - _experiments={ - "profiles_sample_rate": float( - os.environ.get("SERVICES__SENTRY__PROFILE_SAMPLE_RATE", 0.01) - ), - }, + profiles_sample_rate=float( + os.environ.get("SERVICES__SENTRY__PROFILE_SAMPLE_RATE", 0.01) + ), ) if os.getenv("CLUSTER_ENV"): sentry_sdk.set_tag("cluster", os.getenv("CLUSTER_ENV")) @@ -583,3 +419,6 @@ "reports": "shared.django_apps.reports.migrations", "legacy_migrations": "shared.django_apps.legacy_migrations.migrations", } + +# to aid in debugging, print out this info on startup. 
If no license, prints nothing +startup_license_logging() diff --git a/codecov/settings_dev.py b/codecov/settings_dev.py index ce8e1c2a3d..e58164c69f 100644 --- a/codecov/settings_dev.py +++ b/codecov/settings_dev.py @@ -1,5 +1,3 @@ -import logging - from .settings_base import * # Remove CSP headers from local development build to allow GQL Playground diff --git a/codecov/settings_enterprise.py b/codecov/settings_enterprise.py index 964f6a6c48..7af6c9db00 100644 --- a/codecov/settings_enterprise.py +++ b/codecov/settings_enterprise.py @@ -1,7 +1,6 @@ import os -from urllib.parse import urlparse -from utils.config import get_config, get_settings_module +from utils.config import get_config from .settings_base import * diff --git a/codecov/settings_prod.py b/codecov/settings_prod.py index 6b82842220..978656c84e 100644 --- a/codecov/settings_prod.py +++ b/codecov/settings_prod.py @@ -53,7 +53,9 @@ # Redirect after authentication, update this setting with care CORS_ALLOWED_ORIGIN_REGEXES = [] -DATA_UPLOAD_MAX_MEMORY_SIZE = 15000000 +# 25MB in bytes +DATA_UPLOAD_MAX_MEMORY_SIZE = 26214400 + SILENCED_SYSTEM_CHECKS = ["urls.W002"] # Reinforcing the Cookie SameSite configuration to be sure it's Lax in prod diff --git a/codecov/settings_staging.py b/codecov/settings_staging.py index 95317aa670..96062e1efa 100644 --- a/codecov/settings_staging.py +++ b/codecov/settings_staging.py @@ -62,7 +62,8 @@ "http://localhost:3000", ] -DATA_UPLOAD_MAX_MEMORY_SIZE = 15000000 +# 25MB in bytes +DATA_UPLOAD_MAX_MEMORY_SIZE = 26214400 # Same site is set to none on Staging as we want to be able to call the API # From Netlify preview deploy diff --git a/codecov/tests/test_urls.py b/codecov/tests/test_urls.py index a1c3c9415c..2ec92c2194 100644 --- a/codecov/tests/test_urls.py +++ b/codecov/tests/test_urls.py @@ -1,24 +1,8 @@ -import json - -import pytest -from django.conf import settings from django.test import TestCase from django.test.client import Client class ViewTest(TestCase): - def 
test_redirect_app(self): - client = Client() - response = client.get( - "/redirect_app/gh/codecov/codecov.io/settings", follow=False - ) - self.assertRedirects( - response, - "http://localhost:3000/gh/codecov/codecov.io/settings", - 302, - fetch_redirect_response=False, - ) - def test_health(self): client = Client() response = client.get("") diff --git a/codecov/urls.py b/codecov/urls.py index b524a0952b..b718d4355d 100644 --- a/codecov/urls.py +++ b/codecov/urls.py @@ -34,7 +34,6 @@ views.OwnerAutoCompleteSearch.as_view(), name="admin-owner-autocomplete", ), - re_path(r"^redirect_app", views.redirect_app), # /monitoring/metrics will be a public route unless you take steps at a # higher level to null-route or redirect it. path("monitoring/", include("django_prometheus.urls")), diff --git a/codecov/views.py b/codecov/views.py index 4ee7f30fd8..08051037fc 100644 --- a/codecov/views.py +++ b/codecov/views.py @@ -1,7 +1,6 @@ from dal import autocomplete -from django.conf import settings from django.db import connection -from django.http import HttpResponse, HttpResponseRedirect +from django.http import HttpResponse from codecov_auth.models import Owner, Service from core.models import Constants, Repository @@ -17,21 +16,13 @@ def _get_version(): def health(request): - # will raise if connection cannot be estabilished + # will raise if connection cannot be established connection.ensure_connection() version = _get_version() return HttpResponse("%s is live!" 
% version.value) -def redirect_app(request): - """ - This view is intended to be used as part of the frontend migration to redirect traffic from legacy urls to urls - """ - app_domain = settings.CODECOV_DASHBOARD_URL - return HttpResponseRedirect(app_domain + request.path.replace("/redirect_app", "")) - - SERVICE_CHOICES = dict(Service.choices) diff --git a/codecov_auth/admin.py b/codecov_auth/admin.py index f2529b76a9..80575e4fda 100644 --- a/codecov_auth/admin.py +++ b/codecov_auth/admin.py @@ -1,25 +1,37 @@ -from typing import Optional +import logging +from datetime import timedelta +from typing import Optional, Sequence import django.forms as forms from django.conf import settings from django.contrib import admin, messages from django.contrib.admin.models import LogEntry +from django.db.models import OuterRef, Subquery from django.db.models.fields import BLANK_CHOICE_DASH from django.forms import CheckboxInput, Select -from django.http import HttpRequest, HttpResponseRedirect +from django.http import HttpRequest from django.shortcuts import redirect, render +from django.utils import timezone from django.utils.html import format_html +from shared.django_apps.codecov_auth.models import ( + Account, + AccountsUsers, + InvoiceBilling, + StripeBilling, +) from codecov.admin import AdminMixin from codecov.commands.exceptions import ValidationError from codecov_auth.helpers import History -from codecov_auth.models import OrganizationLevelToken, Owner, SentryUser, User +from codecov_auth.models import OrganizationLevelToken, Owner, SentryUser, Session, User from codecov_auth.services.org_level_token_service import OrgLevelTokenService from plan.constants import USER_PLAN_REPRESENTATIONS from plan.service import PlanService from services.task import TaskService from utils.services import get_short_service_name +log = logging.getLogger(__name__) + class ExtendTrialForm(forms.Form): end_date = forms.DateTimeField( @@ -96,13 +108,43 @@ def impersonate_owner(self, request, 
queryset): impersonate_owner.short_description = "Impersonate the selected owner" +class AccountsUsersInline(admin.TabularInline): + model = AccountsUsers + max_num = 10 + extra = 1 + verbose_name_plural = "Accounts Users (click save to commit changes)" + verbose_name = "Account User" + can_delete = False + can_edit = False + + +class OwnerUserInline(admin.TabularInline): + model = Owner + max_num = 5 + extra = 0 + verbose_name_plural = "Owners (read only)" + verbose_name = "Owner" + exclude = ("oauth_token",) + can_delete = False + + readonly_fields = [ + "name", + "username", + "email", + "service", + "student", + ] + + fields = [] + readonly_fields + + @admin.register(User) class UserAdmin(AdminMixin, admin.ModelAdmin): list_display = ( "name", "email", ) - readonly_fields = [] + inlines = [AccountsUsersInline, OwnerUserInline] search_fields = ( "name__iregex", "email__iregex", @@ -190,14 +232,336 @@ def has_add_permission(self, request: HttpRequest, obj: Optional[Owner]) -> bool return (not has_token) and request.user.is_staff +class InvoiceBillingInline(admin.StackedInline): + model = InvoiceBilling + extra = 0 + can_delete = False + verbose_name_plural = "Invoice Billing" + verbose_name = "Invoice Billing (click save to commit changes)" + + +@admin.register(InvoiceBilling) +class InvoiceBillingAdmin(AdminMixin, admin.ModelAdmin): + list_display = ("id", "account", "is_active") + search_fields = ( + "account__name", + "account__id__iexact", + "id__iexact", + "account_manager", + ) + search_help_text = ( + "Search by account name, account id (exact), id (exact), or account_manager" + ) + autocomplete_fields = ("account",) + + readonly_fields = [ + "id", + "created_at", + "updated_at", + ] + + fields = readonly_fields + [ + "account", + "account_manager", + "invoice_notes", + "is_active", + ] + + def get_form(self, request, obj=None, **kwargs): + form = super().get_form(request, obj, **kwargs) + field = form.base_fields["account"] + field.widget.can_add_related 
= False + field.widget.can_change_related = False + field.widget.can_delete_related = False + return form + + +class StripeBillingInline(admin.StackedInline): + can_delete = False + extra = 0 + model = StripeBilling + verbose_name_plural = "Stripe Billing" + verbose_name = "Stripe Billing (click save to commit changes)" + + +@admin.register(StripeBilling) +class StripeBillingAdmin(AdminMixin, admin.ModelAdmin): + list_display = ("id", "account", "is_active") + search_fields = ( + "account__name", + "account__id__iexact", + "id__iexact", + "customer_id__iexact", + "subscription_id__iexact", + ) + search_help_text = "Search by account name, account id (exact), id (exact), customer_id (exact), or subscription_id (exact)" + autocomplete_fields = ("account",) + + readonly_fields = [ + "id", + "created_at", + "updated_at", + ] + + fields = readonly_fields + [ + "account", + "customer_id", + "subscription_id", + "is_active", + ] + + def get_form(self, request, obj=None, **kwargs): + form = super().get_form(request, obj, **kwargs) + field = form.base_fields["account"] + field.widget.can_add_related = False + field.widget.can_change_related = False + field.widget.can_delete_related = False + return form + + +class OwnerOrgInline(admin.TabularInline): + model = Owner + max_num = 100 + extra = 0 + verbose_name_plural = "Organizations (read only)" + verbose_name = "Organization" + exclude = ("oauth_token",) + can_delete = False + + readonly_fields = [ + "name", + "username", + "plan", + "plan_activated_users", + "service", + ] + + fields = [] + readonly_fields + + +def find_and_remove_stale_users( + orgs: Sequence[Owner], date_threshold: timedelta | None = None +) -> tuple[set[int], set[int]]: + """ + This functions finds all the stale `plan_activated_users` in any of the given `orgs`. + + It then removes all those stale users from the given `orgs`, + returning the set of stale users (`ownerid`), and the set of `orgs` that were updated (`ownerid`). 
+ + A user is considered stale if it had no API or login `Session` or any opened PR within `date_threshold`. + If no `date_threshold` is given, it defaults to *90 days*. + """ + + active_users = set() + for org in orgs: + active_users.update(set(org.plan_activated_users)) + + if not active_users: + return (set(), set()) + + # NOTE: the `annotate_last_pull_timestamp` manager/queryset method does the same `annotate` with `Subquery`. + sessions = Session.objects.filter(owner=OuterRef("pk")).order_by("-lastseen") + resolved_users = list( + Owner.objects.filter(ownerid__in=active_users) + .annotate(latest_session=Subquery(sessions.values("lastseen")[:1])) + .annotate_last_pull_timestamp() + .values_list("ownerid", "latest_session", "last_pull_timestamp", named=True) + ) + + threshold = timezone.now() - (date_threshold or timedelta(days=90)) + + def is_stale(user: dict) -> bool: + return (user.latest_session is None or user.latest_session < threshold) and ( + # NOTE: `last_pull_timestamp` is not timezone-aware, so we explicitly compare without timezones here + user.last_pull_timestamp is None + or user.last_pull_timestamp.replace(tzinfo=None) + < threshold.replace(tzinfo=None) + ) + + stale_users = {user.ownerid for user in resolved_users if is_stale(user)} + + # TODO: the existing stale user cleanup script clears the `oauth_token`, though the reason for that is not clear? 
+ # Owner.objects.filter(ownerid__in=stale_users).update(oauth_token=None) + + affected_orgs = { + org for org in orgs if stale_users.intersection(set(org.plan_activated_users)) + } + + if not affected_orgs: + return (set(), set()) + + # TODO: it might make sense to run all this within a transaction and locking the `affected_orgs` for update, + # as we have a slight chance of races between querying the `orgs` at the very beginning and updating them here: + for org in affected_orgs: + org.plan_activated_users = list( + set(org.plan_activated_users).difference(stale_users) + ) + Owner.objects.bulk_update(affected_orgs, ["plan_activated_users"]) + + return (stale_users, {org.ownerid for org in affected_orgs}) + + +@admin.register(Account) +class AccountAdmin(AdminMixin, admin.ModelAdmin): + list_display = ("name", "is_active", "organizations_count", "all_user_count") + search_fields = ("name__iregex", "id") + search_help_text = "Search by name (can use regex), or id (exact)" + inlines = [OwnerOrgInline, StripeBillingInline, InvoiceBillingInline] + actions = ["seat_check", "link_users_to_account", "deactivate_stale_users"] + + readonly_fields = ["id", "created_at", "updated_at", "users"] + + fields = readonly_fields + [ + "name", + "is_active", + "plan", + "plan_seat_count", + "free_seat_count", + "plan_auto_activate", + "is_delinquent", + ] + + @admin.action(description="Deactivate all stale `plan_activated_users`") + def deactivate_stale_users(self, request, queryset): + orgs = [org for account in queryset for org in account.organizations.all()] + stale_users, updated_orgs = find_and_remove_stale_users(orgs) + + if not stale_users or not updated_orgs: + self.message_user( + request, + "No stale users found in selected accounts / organizations.", + messages.INFO, + ) + else: + self.message_user( + request, + f"Removed {len(stale_users)} stale users from {len(updated_orgs)} affected organizations.", + messages.SUCCESS, + ) + + @admin.action( + description="Count 
current plan_activated_users across all Organizations" + ) + def seat_check(self, request, queryset): + self.link_users_to_account(request, queryset, dry_run=True) + + @admin.action(description="Link Users to Account") + def link_users_to_account(self, request, queryset, dry_run=False): + for account in queryset: + account_plan_activated_user_ownerids = set() + for org in account.organizations.all(): + account_plan_activated_user_ownerids.update( + set(org.plan_activated_users) + ) + + account_plan_activated_user_owners = Owner.objects.filter( + ownerid__in=account_plan_activated_user_ownerids + ).prefetch_related("user") + + non_student_count = account_plan_activated_user_owners.exclude( + student=True + ).count() + total_seats_for_account = account.plan_seat_count + account.free_seat_count + if non_student_count > total_seats_for_account: + self.message_user( + request, + f"Request failed: Account plan does not have enough seats; " + f"current plan activated users (non-students): {non_student_count}, total seats for account: {total_seats_for_account}", + messages.ERROR, + ) + return + if dry_run: + self.message_user( + request, + f"Request succeeded: Account plan has enough seats! 
" + f"current plan activated users (non-students): {non_student_count}, total seats for account: {total_seats_for_account}", + messages.SUCCESS, + ) + return + + owners_without_user_objects = account_plan_activated_user_owners.filter( + user__isnull=True + ) + owners_with_new_user_objects = [] + for userless_owner in owners_without_user_objects: + new_user = User.objects.create( + name=userless_owner.name, email=userless_owner.email + ) + userless_owner.user = new_user + owners_with_new_user_objects.append(userless_owner) + total = Owner.objects.bulk_update(owners_with_new_user_objects, ["user"]) + self.message_user( + request, + f"Created a User for {total} Owners", + messages.INFO, + ) + if total > 0: + log.info( + f"Admin operation for {account} - Created a User for {total} Owners", + extra=dict( + owners_with_new_user_objects=[ + str(owner) for owner in owners_with_new_user_objects + ], + account_id=account.id, + ), + ) + + # redo this query to get all Owners and Users + account_plan_activated_user_owners = Owner.objects.filter( + ownerid__in=account_plan_activated_user_ownerids + ).prefetch_related("user") + + already_linked_account_users = AccountsUsers.objects.filter(account=account) + + not_yet_linked_owners = account_plan_activated_user_owners.exclude( + user_id__in=already_linked_account_users.values_list( + "user_id", flat=True + ) + ) + + account_users_that_should_be_unlinked = ( + already_linked_account_users.exclude( + user_id__in=account_plan_activated_user_owners.values_list( + "user_id", flat=True + ) + ) + ) + deleted_ids_for_log = list( + account_users_that_should_be_unlinked.values_list("id", flat=True) + ) + deleted_count, _ = account_users_that_should_be_unlinked.delete() + + new_accounts_users = [] + for owner in not_yet_linked_owners: + new_account_user = AccountsUsers( + user_id=owner.user_id, account_id=account.id + ) + new_accounts_users.append(new_account_user) + total = AccountsUsers.objects.bulk_create(new_accounts_users) + 
self.message_user( + request, + f"Created {len(total)} AccountsUsers, removed {deleted_count} AccountsUsers", + messages.SUCCESS, + ) + if len(total) > 0 or deleted_count > 0: + log.info( + f"Admin operation for {account} - Created {len(total)} AccountsUsers, removed {deleted_count} AccountsUsers", + extra=dict( + new_accounts_users=total, + removed_accounts_users_ids=deleted_ids_for_log, + account_id=account.id, + ), + ) + + @admin.register(Owner) class OwnerAdmin(AdminMixin, admin.ModelAdmin): exclude = ("oauth_token",) list_display = ("name", "username", "email", "service") readonly_fields = [] - search_fields = ("name__iregex", "username__iregex", "email__iregex") + search_fields = ("name__iregex", "username__iregex", "email__iregex", "ownerid") actions = [impersonate_owner, extend_trial] - autocomplete_fields = ("bot",) + autocomplete_fields = ("bot", "account") inlines = [OrgUploadTokenInline] readonly_fields = ( @@ -215,7 +579,6 @@ class OwnerAdmin(AdminMixin, admin.ModelAdmin): "cache", "free", "invoice_details", - "delinquent", "yaml", "updatestamp", "permission", @@ -235,12 +598,14 @@ class OwnerAdmin(AdminMixin, admin.ModelAdmin): "plan_user_count", "plan_activated_users", "uses_invoice", + "delinquent", "integration_id", "bot", "stripe_customer_id", "stripe_subscription_id", "organizations", "max_upload_limit", + "account", ) def get_form(self, request, obj=None, change=False, **kwargs): @@ -252,10 +617,14 @@ def get_form(self, request, obj=None, change=False, **kwargs): form.base_fields["uses_invoice"].widget = CheckboxInput() is_superuser = request.user.is_superuser - if not is_superuser: form.base_fields["staff"].disabled = True + field = form.base_fields["account"] + field.widget.can_add_related = False + field.widget.can_change_related = False + field.widget.can_delete_related = False + return form def has_add_permission(self, _, obj=None): @@ -319,3 +688,26 @@ def has_change_permission(self, request, obj=None): def has_delete_permission(self, 
request, obj=None): return False + + +@admin.register(AccountsUsers) +class AccountsUsersAdmin(AdminMixin, admin.ModelAdmin): + list_display = ("id", "user", "account") + search_fields = ( + "account__name", + "account__id__iexact", + "id__iexact", + "user__id__iexact", + "user__name", + "user__email", + ) + search_help_text = "Search by account name, account id (exact), id (exact), user id (exact), user's name or email" + autocomplete_fields = ("account", "user") + + readonly_fields = [ + "id", + "created_at", + "updated_at", + ] + + fields = readonly_fields + ["account", "user"] diff --git a/codecov_auth/apps.py b/codecov_auth/apps.py index 53af2f4afe..ee2fabc65d 100644 --- a/codecov_auth/apps.py +++ b/codecov_auth/apps.py @@ -5,4 +5,4 @@ class CodecovAuthConfig(AppConfig): name = "codecov_auth" def ready(self): - import codecov_auth.signals + import codecov_auth.signals # noqa: F401 diff --git a/codecov_auth/authentication/helpers.py b/codecov_auth/authentication/helpers.py new file mode 100644 index 0000000000..f5272395d0 --- /dev/null +++ b/codecov_auth/authentication/helpers.py @@ -0,0 +1,28 @@ +import re +from typing import NamedTuple + +from django.http import HttpRequest + + +class UploadInfo(NamedTuple): + service: str + encoded_slug: str + commitid: str | None + + +def get_upload_info_from_request_path(request: HttpRequest) -> UploadInfo | None: + path_info = request.get_full_path_info() + # The repo part comes from https://stackoverflow.com/a/22312124 + upload_views_prefix_regex = ( + r"\/upload\/(\w+)\/([\w\.@:_/\-~]+)\/commits(?:\/([a-f0-9]{40}))?" 
+ ) + match = re.search(upload_views_prefix_regex, path_info) + + if match is None: + return None + + service = match.group(1) + encoded_slug = match.group(2) + commitid = match.group(3) + + return UploadInfo(service, encoded_slug, commitid) diff --git a/codecov_auth/authentication/repo_auth.py b/codecov_auth/authentication/repo_auth.py index 9d475e6d0f..a50e84bfa7 100644 --- a/codecov_auth/authentication/repo_auth.py +++ b/codecov_auth/authentication/repo_auth.py @@ -1,24 +1,18 @@ import json import logging -import re -from datetime import datetime -from typing import Any, List, Tuple +from typing import List from uuid import UUID -from asgiref.sync import async_to_sync from django.core.exceptions import ObjectDoesNotExist from django.db.models import QuerySet -from django.http.request import HttpRequest +from django.http import HttpRequest from django.utils import timezone from jwt import PyJWTError from rest_framework import authentication, exceptions -from rest_framework.exceptions import AuthenticationFailed, NotAuthenticated -from rest_framework.response import Response +from rest_framework.exceptions import NotAuthenticated from rest_framework.views import exception_handler -from sentry_sdk import metrics as sentry_metrics -from shared.metrics import metrics -from shared.torngit.exceptions import TorngitObjectNotFoundError, TorngitRateLimitError +from codecov_auth.authentication.helpers import get_upload_info_from_request_path from codecov_auth.authentication.types import RepositoryAsUser, RepositoryAuthInterface from codecov_auth.models import ( OrganizationLevelToken, @@ -27,7 +21,6 @@ TokenTypeChoices, ) from core.models import Commit, Repository -from services.repo_providers import RepoProviderService from upload.helpers import get_global_tokens, get_repo_with_github_actions_oidc_token from upload.views.helpers import get_repository_from_string from utils import is_uuid @@ -42,10 +35,17 @@ def repo_auth_custom_exception_handler(exc, context): give the 
user something better than "Invalid Token" or "Authentication credentials were not provided." """ response = exception_handler(exc, context) - if response is not None: + # we were having issues with this block, I made it super cautions. + # Re-evaluate later whether this is overly cautious. + if ( + response is not None + and hasattr(response, "status_code") + and response.status_code == 401 + and hasattr(response, "data") + ): try: - exc_code = response.data["detail"].code - except TypeError: + exc_code = response.data.get("detail").code + except (TypeError, AttributeError): return response if exc_code == NotAuthenticated.default_code: response.data["detail"] = ( @@ -181,40 +181,37 @@ class GlobalTokenAuthentication(authentication.TokenAuthentication): def authenticate(self, request): global_tokens = get_global_tokens() token = self.get_token(request) - repoid = self.get_repoid(request) - owner = self.get_owner(request) - using_global_token = True if token in global_tokens else False - service = global_tokens[token] if using_global_token else None + using_global_token = token in global_tokens + if not using_global_token: + return None # continue to next auth class - if using_global_token: - try: - repository = Repository.objects.get( - author__service=service, - repoid=repoid, - author__username=owner.username, - ) - except ObjectDoesNotExist: - raise exceptions.AuthenticationFailed( - "Could not find a repository, try using repo upload token" - ) - else: + service = global_tokens[token] + upload_info = get_upload_info_from_request_path(request) + if upload_info is None: return None # continue to next auth class + # It's important NOT to use the service returned in upload_info + # To avoid someone uploading with GlobalUploadToken to a different service + # Than what it configured + repository = get_repository_from_string( + Service(service), upload_info.encoded_slug + ) + if repository is None: + raise exceptions.AuthenticationFailed( + "Could not find a 
repository, try using repo upload token" + ) return ( RepositoryAsUser(repository), LegacyTokenRepositoryAuth(repository, {"token": token}), ) - def get_token(self, request): - # TODO - pass - - def get_repoid(self, request): - # TODO - pass - - def get_owner(self, request): - # TODO - pass + def get_token(self, request: HttpRequest) -> str | None: + auth_header = request.headers.get("Authorization") + if not auth_header: + return None + if " " in auth_header: + _, token = auth_header.split(" ", 1) + return token + return auth_header class OrgLevelTokenAuthentication(authentication.TokenAuthentication): @@ -241,7 +238,7 @@ def authenticate_credentials(self, token): try: repository = get_repo_with_github_actions_oidc_token(token) - except (ObjectDoesNotExist, PyJWTError) as e: + except (ObjectDoesNotExist, PyJWTError): return None # continue to next auth class log.info( @@ -260,22 +257,13 @@ class TokenlessAuthentication(authentication.TokenAuthentication): auth_failed_message = "Not valid tokenless upload" def _get_info_from_request_path( - self, request - ) -> tuple[Repository, str | None] | None: - path_info = request.get_full_path_info() - # The repo part comes from https://stackoverflow.com/a/22312124 - upload_views_prefix_regex = ( - r"\/upload\/(\w+)\/([\w\.@:_/\-~]+)\/commits(?:\/([a-f0-9]{40}))?" 
- ) - match = re.search(upload_views_prefix_regex, path_info) + self, request: HttpRequest + ) -> tuple[Repository, str | None]: + upload_info = get_upload_info_from_request_path(request) - if match is None: + if upload_info is None: raise exceptions.AuthenticationFailed(self.auth_failed_message) - - service = match.group(1) - encoded_slug = match.group(2) - commitid = match.group(3) - + service, encoded_slug, commitid = upload_info # Validate provider try: service_enum = Service(service) @@ -291,9 +279,11 @@ def _get_info_from_request_path( return repo, commitid - def get_branch(self, request, commitid=None): - if commitid: - commit = Commit.objects.filter(commitid=commitid).first() + def get_branch(self, request, repoid=None, commitid=None): + if repoid and commitid: + commit = Commit.objects.filter( + repository_id=repoid, commitid=commitid + ).first() if not commit: raise exceptions.AuthenticationFailed(self.auth_failed_message) return commit.branch @@ -311,7 +301,7 @@ def authenticate(self, request): if repository is None or repository.private: raise exceptions.AuthenticationFailed(self.auth_failed_message) - branch = self.get_branch(request, commitid) + branch = self.get_branch(request, repository.repoid, commitid) if (branch and ":" in branch) or request.method == "GET": return ( @@ -320,3 +310,41 @@ def authenticate(self, request): ) else: raise exceptions.AuthenticationFailed(self.auth_failed_message) + + +class BundleAnalysisTokenlessAuthentication(TokenlessAuthentication): + def _get_info_from_request_path( + self, request: HttpRequest + ) -> tuple[Repository, str | None]: + try: + body = json.loads(str(request.body, "utf8")) + + # Validate provider + service_enum = Service(body.get("git_service")) + + # Validate that next group exists and decode slug + repo = get_repository_from_string(service_enum, body.get("slug")) + if repo is None: + # Purposefully using the generic message so that we don't tell that + # we don't have a certain repo + raise 
exceptions.AuthenticationFailed(self.auth_failed_message) + + return repo, body.get("commit") + except json.JSONDecodeError: + # Validate request body format + raise exceptions.AuthenticationFailed(self.auth_failed_message) + except ValueError: + # Validate provider + raise exceptions.AuthenticationFailed(self.auth_failed_message) + + def get_branch(self, request, repoid=None, commitid=None): + body = json.loads(str(request.body, "utf8")) + + # If commit is not created yet (ie first upload for this commit), we just validate branch format. + # However if a commit exists already (ie not the first upload for this commit), we must additionally + # validate the saved commit branch matches what is requested in this upload call. + commit = Commit.objects.filter(repository_id=repoid, commitid=commitid).first() + if commit and commit.branch != body.get("branch"): + raise exceptions.AuthenticationFailed(self.auth_failed_message) + + return body.get("branch") diff --git a/codecov_auth/commands/owner/__init__.py b/codecov_auth/commands/owner/__init__.py index 8873985889..826743c705 100644 --- a/codecov_auth/commands/owner/__init__.py +++ b/codecov_auth/commands/owner/__init__.py @@ -1 +1,3 @@ from .owner import OwnerCommands + +__all__ = ["OwnerCommands"] diff --git a/codecov_auth/commands/owner/interactors/cancel_trial.py b/codecov_auth/commands/owner/interactors/cancel_trial.py index a2c950ae1c..714e10d185 100644 --- a/codecov_auth/commands/owner/interactors/cancel_trial.py +++ b/codecov_auth/commands/owner/interactors/cancel_trial.py @@ -1,14 +1,17 @@ from codecov.commands.base import BaseInteractor -from codecov.commands.exceptions import ValidationError +from codecov.commands.exceptions import Unauthorized, ValidationError from codecov.db import sync_to_async +from codecov_auth.helpers import current_user_part_of_org from codecov_auth.models import Owner from plan.service import PlanService class CancelTrialInteractor(BaseInteractor): - def validate(self, owner: Owner): + 
def validate(self, owner: Owner | None): if not owner: raise ValidationError("Cannot find owner record in the database") + if not current_user_part_of_org(self.current_owner, owner): + raise Unauthorized() def _cancel_trial(self, owner: Owner): plan_service = PlanService(current_org=owner) diff --git a/codecov_auth/commands/owner/interactors/get_uploads_number_per_user.py b/codecov_auth/commands/owner/interactors/get_uploads_number_per_user.py index 461951fd01..a18b337e97 100644 --- a/codecov_auth/commands/owner/interactors/get_uploads_number_per_user.py +++ b/codecov_auth/commands/owner/interactors/get_uploads_number_per_user.py @@ -1,6 +1,3 @@ -from datetime import datetime, timedelta - -from django.db.models import Q from shared.upload.utils import query_monthly_coverage_measurements from codecov.commands.base import BaseInteractor diff --git a/codecov_auth/commands/owner/interactors/onboard_user.py b/codecov_auth/commands/owner/interactors/onboard_user.py index 3b9951d1f2..fb6928df20 100644 --- a/codecov_auth/commands/owner/interactors/onboard_user.py +++ b/codecov_auth/commands/owner/interactors/onboard_user.py @@ -1,6 +1,3 @@ -import html - -import yaml from django import forms from codecov.commands.base import BaseInteractor diff --git a/codecov_auth/commands/owner/interactors/save_okta_config.py b/codecov_auth/commands/owner/interactors/save_okta_config.py new file mode 100644 index 0000000000..e6b51ac7c8 --- /dev/null +++ b/codecov_auth/commands/owner/interactors/save_okta_config.py @@ -0,0 +1,92 @@ +from dataclasses import dataclass + +from shared.django_apps.codecov_auth.models import AccountsUsers, User + +from codecov.commands.base import BaseInteractor +from codecov.commands.exceptions import Unauthenticated, Unauthorized, ValidationError +from codecov.db import sync_to_async +from codecov_auth.models import Account, OktaSettings, Owner + + +@dataclass +class SaveOktaConfigInput: + enabled: bool | None + enforced: bool | None + client_id: str | None = 
None + client_secret: str | None = None + url: str | None = None + org_username: str | None = None + + +class SaveOktaConfigInteractor(BaseInteractor): + def validate(self, owner: Owner) -> None: + if not self.current_user.is_authenticated: + raise Unauthenticated() + if not owner: + raise ValidationError("Cannot find owner record in the database") + if not owner.is_admin(self.current_owner): + raise Unauthorized() + + @sync_to_async + def execute(self, input: dict) -> None: + typed_input = SaveOktaConfigInput( + client_id=input.get("client_id"), + client_secret=input.get("client_secret"), + url=input.get("url"), + enabled=input.get("enabled"), + enforced=input.get("enforced"), + org_username=input.get("org_username"), + ) + + owner = Owner.objects.filter( + username=typed_input.org_username, service=self.service + ).first() + self.validate(owner=owner) + + account = owner.account + if not account: + account = Account.objects.create( + name=owner.username, + plan=owner.plan, + plan_seat_count=owner.plan_user_count, + free_seat_count=owner.free, + plan_auto_activate=owner.plan_auto_activate, + ) + owner.account = account + owner.save() + + # Update the activated users to be added to the account + plan_activated_user_owners: list[int] = owner.plan_activated_users + activated_connections: list[AccountsUsers] = [] + for activated_user_owner in plan_activated_user_owners: + user_owner: Owner = Owner.objects.select_related("user").get( + pk=activated_user_owner + ) + user = user_owner.user + if user is None: + user = User(name=user_owner.name, email=user_owner.email) + user_owner.user = user + user.save() + user_owner.save() + + activated_connections.append(AccountsUsers(account=account, user=user)) + + # Batch the user creation in batches of 50 users + if len(activated_connections) > 50: + AccountsUsers.objects.bulk_create(activated_connections) + activated_connections = [] + + if activated_connections: + AccountsUsers.objects.bulk_create(activated_connections) + + 
okta_config, created = OktaSettings.objects.get_or_create(account=account) + + for field in ["client_id", "client_secret", "url", "enabled", "enforced"]: + value = getattr(typed_input, field) + if value is not None: + # Strip the URL of any trailing spaces and slashes before saving it + if field == "url": + value = value.strip("/ ") + setattr(okta_config, field, value) + + okta_config.save() diff --git a/codecov_auth/commands/owner/interactors/save_terms_agreement.py b/codecov_auth/commands/owner/interactors/save_terms_agreement.py index 070cfea425..113c899c8c 100644 --- a/codecov_auth/commands/owner/interactors/save_terms_agreement.py +++ b/codecov_auth/commands/owner/interactors/save_terms_agreement.py @@ -54,10 +54,10 @@ def send_data_to_marketo(self): @sync_to_async def execute(self, input): typed_input = TermsAgreementInput( - business_email=input.get("businessEmail"), - terms_agreement=input.get("termsAgreement"), - marketing_consent=input.get("marketingConsent"), - customer_intent=input.get("customerIntent"), + business_email=input.get("business_email"), + terms_agreement=input.get("terms_agreement"), + marketing_consent=input.get("marketing_consent"), + customer_intent=input.get("customer_intent"), ) self.validate(typed_input) return self.update_terms_agreement(typed_input) diff --git a/codecov_auth/commands/owner/interactors/start_trial.py b/codecov_auth/commands/owner/interactors/start_trial.py index feade7e2df..3058f230cc 100644 --- a/codecov_auth/commands/owner/interactors/start_trial.py +++ b/codecov_auth/commands/owner/interactors/start_trial.py @@ -1,14 +1,17 @@ from codecov.commands.base import BaseInteractor -from codecov.commands.exceptions import ValidationError +from codecov.commands.exceptions import Unauthorized, ValidationError from codecov.db import sync_to_async +from codecov_auth.helpers import current_user_part_of_org from codecov_auth.models import Owner from plan.service import PlanService class StartTrialInteractor(BaseInteractor): - 
def validate(self, current_org: Owner): + def validate(self, current_org: Owner | None): if not current_org: raise ValidationError("Cannot find owner record in the database") + if not current_user_part_of_org(self.current_owner, current_org): + raise Unauthorized() def _start_trial(self, current_org: Owner) -> None: plan_service = PlanService(current_org=current_org) diff --git a/codecov_auth/commands/owner/interactors/store_codecov_metric.py b/codecov_auth/commands/owner/interactors/store_codecov_metric.py new file mode 100644 index 0000000000..aa2ac0b69e --- /dev/null +++ b/codecov_auth/commands/owner/interactors/store_codecov_metric.py @@ -0,0 +1,31 @@ +import json + +from shared.django_apps.codecov_metrics.service.codecov_metrics import ( + UserOnboardingMetricsService, +) + +from codecov.commands.base import BaseInteractor +from codecov.commands.exceptions import ValidationError +from codecov.db import sync_to_async +from codecov_auth.models import Owner + + +class StoreCodecovMetricInteractor(BaseInteractor): + @sync_to_async + def execute(self, org_username: str, event: str, json_string: str) -> None: + current_org = Owner.objects.filter( + username=org_username, service=self.service + ).first() + if not current_org: + raise ValidationError("Cannot find owner record in the database") + + try: + payload = json.loads(json_string) + except json.JSONDecodeError: + raise ValidationError("Invalid JSON string") + + UserOnboardingMetricsService.create_user_onboarding_metric( + org_id=current_org.pk, + event=event, + payload=payload, + ) diff --git a/codecov_auth/commands/owner/interactors/tests/test_cancel_trial.py b/codecov_auth/commands/owner/interactors/tests/test_cancel_trial.py index 650826cdad..2ed619968d 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_cancel_trial.py +++ b/codecov_auth/commands/owner/interactors/tests/test_cancel_trial.py @@ -5,10 +5,10 @@ from django.test import TransactionTestCase from freezegun import freeze_time -from 
codecov.commands.exceptions import ValidationError +from codecov.commands.exceptions import Unauthorized, ValidationError from codecov_auth.models import Owner from codecov_auth.tests.factories import OwnerFactory -from plan.constants import PlanName, TrialDaysAmount, TrialStatus +from plan.constants import PlanName, TrialStatus from ..cancel_trial import CancelTrialInteractor @@ -29,6 +29,18 @@ def test_cancel_trial_raises_exception_when_owner_is_not_in_db(self): with pytest.raises(ValidationError): self.execute(current_user=current_user, org_username="some-other-username") + def test_cancel_trial_raises_exception_when_current_user_not_part_of_org(self): + current_user = OwnerFactory( + username="random-user-123", + service="github", + ) + OwnerFactory( + username="random-user-456", + service="github", + ) + with pytest.raises(Unauthorized): + self.execute(current_user=current_user, org_username="random-user-456") + @freeze_time("2022-01-01T00:00:00") def test_cancel_trial_raises_exception_when_owners_trial_status_is_not_started( self, @@ -46,7 +58,7 @@ def test_cancel_trial_raises_exception_when_owners_trial_status_is_not_started( @freeze_time("2022-01-01T00:00:00") def test_cancel_trial_raises_exception_when_owners_trial_status_is_expired(self): - now = datetime.utcnow() + now = datetime.now() trial_start_date = now + timedelta(days=-10) trial_end_date = now + timedelta(days=-4) current_user = OwnerFactory( @@ -60,7 +72,7 @@ def test_cancel_trial_raises_exception_when_owners_trial_status_is_expired(self) @freeze_time("2022-01-01T00:00:00") def test_cancel_trial_starts_trial_for_org_that_has_trial_ongoing(self): - now = datetime.utcnow() + now = datetime.now() trial_start_date = now trial_end_date = now + timedelta(days=3) current_user: Owner = OwnerFactory( @@ -74,10 +86,10 @@ def test_cancel_trial_starts_trial_for_org_that_has_trial_ongoing(self): self.execute(current_user=current_user, org_username=current_user.username) current_user.refresh_from_db() - now = 
datetime.utcnow() + now = datetime.now() assert current_user.trial_end_date == now assert current_user.trial_status == TrialStatus.EXPIRED.value assert current_user.plan == PlanName.BASIC_PLAN_NAME.value - assert current_user.plan_activated_users == None + assert current_user.plan_activated_users is None assert current_user.plan_user_count == 1 - assert current_user.stripe_subscription_id == None + assert current_user.stripe_subscription_id is None diff --git a/codecov_auth/commands/owner/interactors/tests/test_create_api_token.py b/codecov_auth/commands/owner/interactors/tests/test_create_api_token.py index 4d99ab543e..6b310bf07d 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_create_api_token.py +++ b/codecov_auth/commands/owner/interactors/tests/test_create_api_token.py @@ -1,9 +1,7 @@ import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov.commands.exceptions import Unauthenticated, ValidationError -from codecov_auth.models import Session from codecov_auth.tests.factories import OwnerFactory from ..create_api_token import CreateApiTokenInteractor diff --git a/codecov_auth/commands/owner/interactors/tests/test_create_user_token.py b/codecov_auth/commands/owner/interactors/tests/test_create_user_token.py index 45be10d67e..bda35a018e 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_create_user_token.py +++ b/codecov_auth/commands/owner/interactors/tests/test_create_user_token.py @@ -1,5 +1,4 @@ import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov.commands.exceptions import Unauthenticated, ValidationError diff --git a/codecov_auth/commands/owner/interactors/tests/test_delete_session.py b/codecov_auth/commands/owner/interactors/tests/test_delete_session.py index 1fdfb12da6..d2369d4771 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_delete_session.py +++ 
b/codecov_auth/commands/owner/interactors/tests/test_delete_session.py @@ -1,7 +1,5 @@ import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase -from django.utils import timezone from codecov.commands.exceptions import Unauthenticated from codecov.db import sync_to_async diff --git a/codecov_auth/commands/owner/interactors/tests/test_get_is_current_user_an_admin.py b/codecov_auth/commands/owner/interactors/tests/test_get_is_current_user_an_admin.py index b0e23cc1d7..79e6a7e0a7 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_get_is_current_user_an_admin.py +++ b/codecov_auth/commands/owner/interactors/tests/test_get_is_current_user_an_admin.py @@ -1,8 +1,6 @@ from unittest.mock import patch -import pytest from asgiref.sync import async_to_sync -from distutils.util import execute from django.test import TransactionTestCase, override_settings from codecov_auth.tests.factories import GetAdminProviderAdapter, OwnerFactory diff --git a/codecov_auth/commands/owner/interactors/tests/test_get_org_upload_token.py b/codecov_auth/commands/owner/interactors/tests/test_get_org_upload_token.py index fc28e381a9..8067d84c05 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_get_org_upload_token.py +++ b/codecov_auth/commands/owner/interactors/tests/test_get_org_upload_token.py @@ -17,7 +17,7 @@ async def test_owner_with_no_org_upload_token(self): token = await GetOrgUploadToken( self.owner_with_no_upload_token, "github" ).execute(self.owner_with_no_upload_token) - assert token == None + assert token is None async def test_owner_with_org_upload_token(self): token = await GetOrgUploadToken(self.owner_with_upload_token, "github").execute( @@ -32,7 +32,7 @@ async def test_owner_with_org_upload_token_and_anonymous_user(self): self.owner_with_upload_token ) - assert token == None + assert token is None async def test_owner_with_org_upload_token_and_unauthorized_user(self): with 
pytest.raises(Unauthorized): @@ -40,4 +40,4 @@ async def test_owner_with_org_upload_token_and_unauthorized_user(self): self.owner_with_upload_token, "github" ).execute(self.owner_with_no_upload_token) - assert token == None + assert token is None diff --git a/codecov_auth/commands/owner/interactors/tests/test_get_uploads_number_per_user.py b/codecov_auth/commands/owner/interactors/tests/test_get_uploads_number_per_user.py index fec5fffbb4..d29f643eba 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_get_uploads_number_per_user.py +++ b/codecov_auth/commands/owner/interactors/tests/test_get_uploads_number_per_user.py @@ -43,8 +43,8 @@ def setUp(self): # Trial Data self.trial_owner = OwnerFactory( trial_status=TrialStatus.EXPIRED.value, - trial_start_date=datetime.utcnow() + timedelta(days=-10), - trial_end_date=datetime.utcnow() + timedelta(days=-2), + trial_start_date=datetime.now() + timedelta(days=-10), + trial_end_date=datetime.now() + timedelta(days=-2), ) trial_repo = RepositoryFactory.create(author=self.trial_owner, private=True) trial_commit = CommitFactory.create(repository=trial_repo) diff --git a/codecov_auth/commands/owner/interactors/tests/test_onboard_user.py b/codecov_auth/commands/owner/interactors/tests/test_onboard_user.py index a203e92d8c..c852c722e1 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_onboard_user.py +++ b/codecov_auth/commands/owner/interactors/tests/test_onboard_user.py @@ -1,5 +1,4 @@ import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov.commands.exceptions import Unauthenticated, Unauthorized, ValidationError diff --git a/codecov_auth/commands/owner/interactors/tests/test_regenerate_org_upload_token.py b/codecov_auth/commands/owner/interactors/tests/test_regenerate_org_upload_token.py index f1b9caab23..ebc304562c 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_regenerate_org_upload_token.py +++ 
b/codecov_auth/commands/owner/interactors/tests/test_regenerate_org_upload_token.py @@ -1,6 +1,4 @@ import pytest -from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov.commands.exceptions import Unauthenticated, Unauthorized, ValidationError diff --git a/codecov_auth/commands/owner/interactors/tests/test_revoke_user_token.py b/codecov_auth/commands/owner/interactors/tests/test_revoke_user_token.py index c2957d42dc..ab71a6c8c1 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_revoke_user_token.py +++ b/codecov_auth/commands/owner/interactors/tests/test_revoke_user_token.py @@ -1,11 +1,10 @@ import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov.commands.exceptions import Unauthenticated from codecov.db import sync_to_async -from codecov_auth.models import Session, UserToken -from codecov_auth.tests.factories import OwnerFactory, SessionFactory, UserTokenFactory +from codecov_auth.models import UserToken +from codecov_auth.tests.factories import OwnerFactory, UserTokenFactory from ..revoke_user_token import RevokeUserTokenInteractor diff --git a/codecov_auth/commands/owner/interactors/tests/test_save_okta_config.py b/codecov_auth/commands/owner/interactors/tests/test_save_okta_config.py new file mode 100644 index 0000000000..8a333629be --- /dev/null +++ b/codecov_auth/commands/owner/interactors/tests/test_save_okta_config.py @@ -0,0 +1,249 @@ +import pytest +from asgiref.sync import async_to_sync +from django.contrib.auth.models import AnonymousUser +from django.test import TransactionTestCase + +from codecov.commands.exceptions import Unauthenticated, Unauthorized, ValidationError +from codecov_auth.models import OktaSettings +from codecov_auth.tests.factories import ( + AccountFactory, + OktaSettingsFactory, + OwnerFactory, +) + +from ..save_okta_config import SaveOktaConfigInteractor + 
+ +class SaveOktaConfigInteractorTest(TransactionTestCase): + def setUp(self): + self.current_user = OwnerFactory(username="codecov-user") + self.service = "github" + user1 = OwnerFactory() + user2 = OwnerFactory() + self.owner = OwnerFactory( + username=self.current_user.username, + service=self.service, + account=AccountFactory(), + ) + + self.owner_with_admins = OwnerFactory( + username=self.current_user.username, + service=self.service, + admins=[self.current_user.ownerid], + plan_activated_users=[user1.ownerid, user2.ownerid], + account=None, + ) + + self.interactor = SaveOktaConfigInteractor( + current_owner=self.owner, + service=self.service, + current_user=self.current_user, + ) + + @async_to_sync + def execute( + self, + interactor: SaveOktaConfigInteractor | None = None, + input: dict | None = None, + ): + if not interactor and self.interactor: + interactor = self.interactor + + if not interactor: + return + return interactor.execute(input) + + def test_user_is_not_authenticated(self): + with pytest.raises(Unauthenticated): + self.execute( + interactor=SaveOktaConfigInteractor( + current_owner=None, + service=self.service, + current_user=AnonymousUser(), + ), + input={ + "client_id": "some-client-id", + "client_secret": "some-client-secret", + "url": "https://okta.example.com", + "enabled": True, + "enforced": True, + "org_username": self.owner.username, + }, + ) + + def test_validation_error_when_owner_not_found(self): + with pytest.raises(ValidationError): + self.execute( + input={ + "client_id": "some-client-id", + "client_secret": "some-client-secret", + "url": "https://okta.example.com", + "enabled": True, + "enforced": True, + "org_username": "non-existent-user", + }, + ) + + def test_unauthorized_error_when_user_is_not_admin(self): + with pytest.raises(Unauthorized): + self.execute( + input={ + "client_id": "some-client-id", + "client_secret": "some-client-secret", + "url": "https://okta.example.com", + "enabled": True, + "enforced": True, + 
"org_username": self.owner.username, + }, + ) + + def test_create_okta_settings_when_account_does_not_exist(self): + plan_activated_users = [] + for _ in range(100): + user_owner = OwnerFactory(user=None) + plan_activated_users.append(user_owner.ownerid) + + org_with_lots_of_users = OwnerFactory( + service=self.service, + admins=[self.current_user.ownerid], + plan_activated_users=plan_activated_users, + ) + + input_data = { + "client_id": "some-client-id", + "client_secret": "some-client-secret", + "url": "https://okta.example.com", + "enabled": True, + "enforced": True, + "org_username": org_with_lots_of_users.username, + } + + interactor = SaveOktaConfigInteractor( + current_owner=self.current_user, service=self.service + ) + self.execute(interactor=interactor, input=input_data) + + org_with_lots_of_users.refresh_from_db() + account = org_with_lots_of_users.account + + assert account.name == org_with_lots_of_users.username + assert account.plan == org_with_lots_of_users.plan + assert account.plan_seat_count == org_with_lots_of_users.plan_user_count + assert account.free_seat_count == org_with_lots_of_users.free + + assert account.users.count() == 100 + assert account.users.count() == len(org_with_lots_of_users.plan_activated_users) + + okta_config = OktaSettings.objects.get(account=org_with_lots_of_users.account) + + assert okta_config.client_id == input_data["client_id"] + assert okta_config.client_secret == input_data["client_secret"] + assert okta_config.url == input_data["url"] + assert okta_config.enabled == input_data["enabled"] + assert okta_config.enforced == input_data["enforced"] + + def test_update_okta_settings_when_account_exists(self): + input_data = { + "client_id": "some-client-id", + "client_secret": "some-client-secret", + "url": "https://okta.example.com", + "enabled": True, + "enforced": True, + "org_username": self.owner_with_admins.username, + } + + account = AccountFactory() + self.owner_with_admins.account = account + 
self.owner_with_admins.save() + + interactor = SaveOktaConfigInteractor( + current_owner=self.current_user, service=self.service + ) + self.execute(interactor=interactor, input=input_data) + + okta_config = OktaSettings.objects.get(account=self.owner_with_admins.account) + + assert okta_config.client_id == input_data["client_id"] + assert okta_config.client_secret == input_data["client_secret"] + assert okta_config.url == input_data["url"] + assert okta_config.enabled == input_data["enabled"] + assert okta_config.enforced == input_data["enforced"] + + def test_update_okta_settings_url_remove_trailing_slashes(self): + input_data = { + "client_id": "some-client-id", + "client_secret": "some-client-secret", + "url": "https://okta.example.com/", + "enabled": True, + "enforced": True, + "org_username": self.owner_with_admins.username, + } + + account = AccountFactory() + self.owner_with_admins.account = account + self.owner_with_admins.save() + + interactor = SaveOktaConfigInteractor( + current_owner=self.current_user, service=self.service + ) + self.execute(interactor=interactor, input=input_data) + + okta_config = OktaSettings.objects.get(account=self.owner_with_admins.account) + + assert okta_config.url == "https://okta.example.com" + + def test_update_okta_settings_when_okta_settings_exists(self): + input_data = { + "client_id": "some-client-id", + "client_secret": "some-client-secret", + "url": "https://okta.example.com", + "enabled": True, + "enforced": True, + "org_username": self.owner_with_admins.username, + } + + account = AccountFactory() + OktaSettingsFactory(account=account) + self.owner_with_admins.account = account + self.owner_with_admins.save() + + interactor = SaveOktaConfigInteractor( + current_owner=self.current_user, service=self.service + ) + self.execute(interactor=interactor, input=input_data) + + okta_config = OktaSettings.objects.get(account=self.owner_with_admins.account) + + assert okta_config.client_id == input_data["client_id"] + assert 
okta_config.client_secret == input_data["client_secret"] + assert okta_config.url == input_data["url"] + assert okta_config.enabled == input_data["enabled"] + assert okta_config.enforced == input_data["enforced"] + + def test_update_okta_settings_when_some_fields_are_none(self): + input_data = { + "client_id": "some-client-id", + "client_secret": None, + "url": None, + "enabled": True, + "enforced": True, + "org_username": self.owner_with_admins.username, + } + + account = AccountFactory() + OktaSettingsFactory(account=account) + self.owner_with_admins.account = account + self.owner_with_admins.save() + + interactor = SaveOktaConfigInteractor( + current_owner=self.current_user, service=self.service + ) + self.execute(interactor=interactor, input=input_data) + + okta_config = OktaSettings.objects.get(account=self.owner_with_admins.account) + + assert okta_config.client_id == input_data["client_id"] + assert okta_config.client_secret is not None + assert okta_config.url is not None + assert okta_config.enabled == input_data["enabled"] + assert okta_config.enforced == input_data["enforced"] diff --git a/codecov_auth/commands/owner/interactors/tests/test_save_terms_agreement.py b/codecov_auth/commands/owner/interactors/tests/test_save_terms_agreement.py index c73c3fd412..257e649764 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_save_terms_agreement.py +++ b/codecov_auth/commands/owner/interactors/tests/test_save_terms_agreement.py @@ -22,8 +22,8 @@ def execute( self, current_user, input={ - "businessEmail": None, - "termsAgreement": False, + "business_email": None, + "terms_agreement": False, }, ): return SaveTermsAgreementInteractor(None, "github", current_user).execute( @@ -34,7 +34,7 @@ def execute( def test_update_user_when_agreement_is_false(self): self.execute( current_user=self.current_user, - input={"termsAgreement": False, "customerIntent": "Business"}, + input={"terms_agreement": False, "customer_intent": "Business"}, ) 
before_refresh_business_email = self.current_user.email @@ -42,13 +42,13 @@ def test_update_user_when_agreement_is_false(self): assert self.current_user.terms_agreement_at == self.updated_at self.current_user.refresh_from_db() - self.current_user.email == before_refresh_business_email + assert self.current_user.email == before_refresh_business_email @freeze_time("2022-01-01T00:00:00") def test_update_user_when_agreement_is_true(self): self.execute( current_user=self.current_user, - input={"termsAgreement": True, "customerIntent": "Business"}, + input={"terms_agreement": True, "customer_intent": "Business"}, ) before_refresh_business_email = self.current_user.email @@ -56,16 +56,16 @@ def test_update_user_when_agreement_is_true(self): assert self.current_user.terms_agreement_at == self.updated_at self.current_user.refresh_from_db() - self.current_user.email == before_refresh_business_email + assert self.current_user.email == before_refresh_business_email @freeze_time("2022-01-01T00:00:00") def test_update_owner_and_user_when_email_is_not_empty(self): self.execute( current_user=self.current_user, input={ - "businessEmail": "something@email.com", - "termsAgreement": True, - "customerIntent": "Business", + "business_email": "something@email.com", + "terms_agreement": True, + "customer_intent": "Business", }, ) @@ -73,29 +73,29 @@ def test_update_owner_and_user_when_email_is_not_empty(self): assert self.current_user.terms_agreement_at == self.updated_at self.current_user.refresh_from_db() - self.current_user.email == "something@email.com" + assert self.current_user.email == "something@email.com" def test_validation_error_when_terms_is_none(self): with pytest.raises(ValidationError): self.execute( current_user=self.current_user, - input={"termsAgreement": None, "customerIntent": "Business"}, + input={"terms_agreement": None, "customer_intent": "Business"}, ) def test_validation_error_when_customer_intent_invalid(self): with pytest.raises(ValidationError): self.execute( 
current_user=self.current_user, - input={"termsAgreement": None, "customerIntent": "invalid"}, + input={"terms_agreement": None, "customer_intent": "invalid"}, ) def test_user_is_not_authenticated(self): - with pytest.raises(Unauthenticated) as e: + with pytest.raises(Unauthenticated): self.execute( current_user=AnonymousUser(), input={ - "businessEmail": "something@email.com", - "termsAgreement": True, - "customerIntent": "Business", + "business_email": "something@email.com", + "terms_agreement": True, + "customer_intent": "Business", }, ) diff --git a/codecov_auth/commands/owner/interactors/tests/test_set_yaml_on_owner.py b/codecov_auth/commands/owner/interactors/tests/test_set_yaml_on_owner.py index de858b71a0..f24014a66f 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_set_yaml_on_owner.py +++ b/codecov_auth/commands/owner/interactors/tests/test_set_yaml_on_owner.py @@ -1,5 +1,4 @@ import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov.commands.exceptions import ( @@ -86,14 +85,12 @@ async def test_user_is_part_of_org_and_yaml_is_empty(self): async def test_user_is_part_of_org_and_yaml_is_not_dict(self): with pytest.raises(ValidationError) as e: - owner_updated = await self.execute( - self.current_owner, self.org.username, bad_yaml_not_dict - ) + await self.execute(self.current_owner, self.org.username, bad_yaml_not_dict) assert str(e.value) == "Error at []: Yaml needs to be a dict" async def test_user_is_part_of_org_and_yaml_is_not_codecov_valid(self): - with pytest.raises(ValidationError) as e: - owner_updated = await self.execute( + with pytest.raises(ValidationError): + await self.execute( self.current_owner, self.org.username, bad_yaml_wrong_keys ) diff --git a/codecov_auth/commands/owner/interactors/tests/test_start_trial.py b/codecov_auth/commands/owner/interactors/tests/test_start_trial.py index 002cb8601d..e19d0e29d4 100644 --- 
a/codecov_auth/commands/owner/interactors/tests/test_start_trial.py +++ b/codecov_auth/commands/owner/interactors/tests/test_start_trial.py @@ -1,12 +1,11 @@ from datetime import datetime, timedelta -from unittest.mock import patch import pytest from asgiref.sync import async_to_sync from django.test import TransactionTestCase from freezegun import freeze_time -from codecov.commands.exceptions import ValidationError +from codecov.commands.exceptions import Unauthorized, ValidationError from codecov_auth.models import Owner from codecov_auth.tests.factories import OwnerFactory from plan.constants import TRIAL_PLAN_SEATS, PlanName, TrialDaysAmount, TrialStatus @@ -30,9 +29,21 @@ def test_start_trial_raises_exception_when_owner_is_not_in_db(self): with pytest.raises(ValidationError): self.execute(current_user=current_user, org_username="some-other-username") + def test_cancel_trial_raises_exception_when_current_user_not_part_of_org(self): + current_user = OwnerFactory( + username="random-user-123", + service="github", + ) + OwnerFactory( + username="random-user-456", + service="github", + ) + with pytest.raises(Unauthorized): + self.execute(current_user=current_user, org_username="random-user-456") + @freeze_time("2022-01-01T00:00:00") def test_start_trial_raises_exception_when_owners_trial_status_is_ongoing(self): - now = datetime.utcnow() + now = datetime.now() trial_start_date = now trial_end_date = now + timedelta(days=3) current_user = OwnerFactory( @@ -47,7 +58,7 @@ def test_start_trial_raises_exception_when_owners_trial_status_is_ongoing(self): @freeze_time("2022-01-01T00:00:00") def test_start_trial_raises_exception_when_owners_trial_status_is_expired(self): - now = datetime.utcnow() + now = datetime.now() trial_start_date = now + timedelta(days=-10) trial_end_date = now + timedelta(days=-4) current_user = OwnerFactory( @@ -64,7 +75,7 @@ def test_start_trial_raises_exception_when_owners_trial_status_is_expired(self): def 
test_start_trial_raises_exception_when_owners_trial_status_cannot_trial( self, ): - now = datetime.utcnow() + now = datetime.now() trial_start_date = now trial_end_date = now current_user = OwnerFactory( @@ -91,7 +102,7 @@ def test_start_trial_starts_trial_for_org_that_has_not_started_trial_before_and_ self.execute(current_user=current_user, org_username=current_user.username) current_user.refresh_from_db() - now = datetime.utcnow() + now = datetime.now() assert current_user.trial_start_date == now assert current_user.trial_end_date == now + timedelta( days=TrialDaysAmount.CODECOV_SENTRY.value diff --git a/codecov_auth/commands/owner/interactors/tests/test_trigger_sync.py b/codecov_auth/commands/owner/interactors/tests/test_trigger_sync.py index 162a200755..0c169f11e8 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_trigger_sync.py +++ b/codecov_auth/commands/owner/interactors/tests/test_trigger_sync.py @@ -1,10 +1,9 @@ from unittest.mock import patch import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase -from codecov.commands.exceptions import Unauthenticated, ValidationError +from codecov.commands.exceptions import Unauthenticated from codecov_auth.tests.factories import OwnerFactory from ..trigger_sync import TriggerSyncInteractor diff --git a/codecov_auth/commands/owner/interactors/tests/test_update_default_organization.py b/codecov_auth/commands/owner/interactors/tests/test_update_default_organization.py index 9f0dbda6fd..2e9636bbde 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_update_default_organization.py +++ b/codecov_auth/commands/owner/interactors/tests/test_update_default_organization.py @@ -2,7 +2,6 @@ import pytest from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov.commands.exceptions import Unauthenticated, ValidationError @@ -48,7 +47,7 @@ def 
test_update_org_when_default_org_username_is_none(self): owner_profile: OwnerProfile = OwnerProfile.objects.filter( owner_id=self.owner.ownerid ).first() - assert owner_profile.default_org == None + assert owner_profile.default_org is None def test_update_owners_default_org(self): username = self.execute( diff --git a/codecov_auth/commands/owner/interactors/tests/test_update_profile.py b/codecov_auth/commands/owner/interactors/tests/test_update_profile.py index 582efb8ef5..0d7c7e6e3a 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_update_profile.py +++ b/codecov_auth/commands/owner/interactors/tests/test_update_profile.py @@ -1,5 +1,4 @@ import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov.commands.exceptions import Unauthenticated, ValidationError diff --git a/codecov_auth/commands/owner/interactors/tests/test_update_self_hosted_settings.py b/codecov_auth/commands/owner/interactors/tests/test_update_self_hosted_settings.py index 436c8ebf15..e731f8d77a 100644 --- a/codecov_auth/commands/owner/interactors/tests/test_update_self_hosted_settings.py +++ b/codecov_auth/commands/owner/interactors/tests/test_update_self_hosted_settings.py @@ -1,5 +1,3 @@ -from unittest.mock import patch - import pytest from asgiref.sync import async_to_sync from django.contrib.auth.models import AnonymousUser @@ -18,7 +16,7 @@ def execute( self, current_user, input={ - "shouldAutoActivate": None, + "should_auto_activate": None, }, ): return UpdateSelfHostedSettingsInteractor(None, "github", current_user).execute( @@ -28,14 +26,14 @@ def execute( @override_settings(IS_ENTERPRISE=True) def test_update_self_hosted_settings_when_auto_activate_is_true(self): owner = OwnerFactory(plan_auto_activate=False) - self.execute(current_user=owner, input={"shouldAutoActivate": True}) + self.execute(current_user=owner, input={"should_auto_activate": True}) owner.refresh_from_db() assert owner.plan_auto_activate == 
True @override_settings(IS_ENTERPRISE=True) def test_update_self_hosted_settings_when_auto_activate_is_false(self): owner = OwnerFactory(plan_auto_activate=True) - self.execute(current_user=owner, input={"shouldAutoActivate": False}) + self.execute(current_user=owner, input={"should_auto_activate": False}) owner.refresh_from_db() assert owner.plan_auto_activate == False @@ -46,16 +44,16 @@ def test_validation_error_when_not_self_hosted_instance(self): self.execute( current_user=owner, input={ - "shouldAutoActivate": False, + "should_auto_activate": False, }, ) @override_settings(IS_ENTERPRISE=True) def test_user_is_not_authenticated(self): - with pytest.raises(Unauthenticated) as e: + with pytest.raises(Unauthenticated): self.execute( current_user=AnonymousUser(), input={ - "shouldAutoActivate": False, + "should_auto_activate": False, }, ) diff --git a/codecov_auth/commands/owner/interactors/trigger_sync.py b/codecov_auth/commands/owner/interactors/trigger_sync.py index f87797a689..66f2381a85 100644 --- a/codecov_auth/commands/owner/interactors/trigger_sync.py +++ b/codecov_auth/commands/owner/interactors/trigger_sync.py @@ -5,12 +5,12 @@ class TriggerSyncInteractor(BaseInteractor): - def validate(self): + def validate(self) -> None: if not self.current_user.is_authenticated: raise Unauthenticated() @sync_to_async - def execute(self): + def execute(self) -> None: self.validate() RefreshService().trigger_refresh( self.current_owner.ownerid, diff --git a/codecov_auth/commands/owner/interactors/update_self_hosted_settings.py b/codecov_auth/commands/owner/interactors/update_self_hosted_settings.py index 080cfe6d72..174acee4c6 100644 --- a/codecov_auth/commands/owner/interactors/update_self_hosted_settings.py +++ b/codecov_auth/commands/owner/interactors/update_self_hosted_settings.py @@ -6,7 +6,6 @@ from codecov.commands.base import BaseInteractor from codecov.commands.exceptions import Unauthenticated, ValidationError from codecov.db import sync_to_async -from 
services.refresh import RefreshService @dataclass @@ -28,7 +27,7 @@ def validate(self) -> None: def execute(self, input: UpdateSelfHostedSettingsInput) -> None: self.validate() typed_input = UpdateSelfHostedSettingsInput( - auto_activate_members=input.get("shouldAutoActivate"), + auto_activate_members=input.get("should_auto_activate"), ) should_auto_activate = typed_input.auto_activate_members diff --git a/codecov_auth/commands/owner/owner.py b/codecov_auth/commands/owner/owner.py index af866ddcb1..c98a4b7ded 100644 --- a/codecov_auth/commands/owner/owner.py +++ b/codecov_auth/commands/owner/owner.py @@ -12,9 +12,11 @@ from .interactors.onboard_user import OnboardUserInteractor from .interactors.regenerate_org_upload_token import RegenerateOrgUploadTokenInteractor from .interactors.revoke_user_token import RevokeUserTokenInteractor +from .interactors.save_okta_config import SaveOktaConfigInteractor from .interactors.save_terms_agreement import SaveTermsAgreementInteractor from .interactors.set_yaml_on_owner import SetYamlOnOwnerInteractor from .interactors.start_trial import StartTrialInteractor +from .interactors.store_codecov_metric import StoreCodecovMetricInteractor from .interactors.trigger_sync import TriggerSyncInteractor from .interactors.update_default_organization import UpdateDefaultOrganizationInteractor from .interactors.update_profile import UpdateProfileInteractor @@ -86,3 +88,13 @@ def cancel_trial(self, org_username: str) -> None: def update_self_hosted_settings(self, input) -> None: return self.get_interactor(UpdateSelfHostedSettingsInteractor).execute(input) + + def store_codecov_metric( + self, org_username: str, event: str, json_string: str + ) -> None: + return self.get_interactor(StoreCodecovMetricInteractor).execute( + org_username, event, json_string + ) + + def save_okta_config(self, input) -> None: + return self.get_interactor(SaveOktaConfigInteractor).execute(input) diff --git a/codecov_auth/commands/owner/tests/test_owner.py 
b/codecov_auth/commands/owner/tests/test_owner.py index 3a7a1ac32e..04aca26a2f 100644 --- a/codecov_auth/commands/owner/tests/test_owner.py +++ b/codecov_auth/commands/owner/tests/test_owner.py @@ -121,3 +121,16 @@ def test_regenerate_org_upload_token_delegate_to_interactor(self, interactor_moc owner = {} self.command.regenerate_org_upload_token(owner) interactor_mock.assert_called_once_with(owner) + + @patch("codecov_auth.commands.owner.owner.SaveOktaConfigInteractor.execute") + def test_save_okta_config_delegate_to_interactor(self, interactor_mock): + input_dict = { + "client_id": "123", + "client_secret": "123", + "url": "http://example.com", + "enabled": True, + "enforced": False, + "org_username": "codecov-user", + } + self.command.save_okta_config(input_dict) + interactor_mock.assert_called_once_with(input_dict) diff --git a/codecov_auth/helpers.py b/codecov_auth/helpers.py index 7ba6519f53..3998d7529f 100644 --- a/codecov_auth/helpers.py +++ b/codecov_auth/helpers.py @@ -2,7 +2,6 @@ import requests from django.contrib.admin.models import CHANGE, LogEntry -from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from codecov_auth.constants import GITLAB_BASE_URL diff --git a/codecov_auth/management/commands/set_trial_status_values.py b/codecov_auth/management/commands/set_trial_status_values.py index fa30a62a90..5a1033f5b8 100644 --- a/codecov_auth/management/commands/set_trial_status_values.py +++ b/codecov_auth/management/commands/set_trial_status_values.py @@ -35,7 +35,7 @@ def handle(self, *args, **options) -> None: if trial_status_type == "all" or trial_status_type == "ongoing": Owner.objects.filter( plan__in=SENTRY_PAID_USER_PLAN_REPRESENTATIONS, - trial_end_date__gt=datetime.utcnow(), + trial_end_date__gt=datetime.now(), ).update(trial_status=TrialStatus.ONGOING.value) # EXPIRED @@ -46,7 +46,7 @@ def handle(self, *args, **options) -> None: plan__in=SENTRY_PAID_USER_PLAN_REPRESENTATIONS, 
stripe_customer_id__isnull=False, stripe_subscription_id__isnull=False, - trial_end_date__lte=datetime.utcnow(), + trial_end_date__lte=datetime.now(), ) # Currently paying sentry customer without trial_end_date | Q( diff --git a/codecov_auth/management/commands/tests/test_set_trial_status_values.py b/codecov_auth/management/commands/tests/test_set_trial_status_values.py index 4fff05001d..57fb43a313 100644 --- a/codecov_auth/management/commands/tests/test_set_trial_status_values.py +++ b/codecov_auth/management/commands/tests/test_set_trial_status_values.py @@ -14,7 +14,7 @@ class OwnerCommandTestCase(TestCase): def setUp(self): self.command_instance = BaseCommand() - now = datetime.utcnow() + now = datetime.now() later = now + timedelta(days=3) yesterday = now + timedelta(days=-1) much_before = now + timedelta(days=-20) diff --git a/codecov_auth/middleware.py b/codecov_auth/middleware.py index 0e6fd683b9..6664c600bd 100644 --- a/codecov_auth/middleware.py +++ b/codecov_auth/middleware.py @@ -8,8 +8,7 @@ ACCESS_CONTROL_ALLOW_ORIGIN, ) from corsheaders.middleware import CorsMiddleware as BaseCorsMiddleware -from django.conf import settings -from django.http import HttpRequest, JsonResponse +from django.http import HttpRequest, HttpResponse from django.urls import resolve from django.utils.deprecation import MiddlewareMixin from rest_framework import exceptions @@ -47,7 +46,7 @@ class CurrentOwnerMiddleware(MiddlewareMixin): additional database queries). """ - def process_request(self, request): + def process_request(self, request: HttpRequest) -> None: if not request.user or request.user.is_anonymous: request.current_owner = None return @@ -73,18 +72,24 @@ class ImpersonationMiddleware(MiddlewareMixin): Allows staff users to impersonate other users for debugging. """ - def process_request(self, request): + def process_request(self, request: HttpRequest) -> None: + """Log and ensure that the impersonating user is authenticated. 
+ The `current user` is the staff user that is impersonating the + user owner at `impersonating_ownerid`. + """ current_user = request.user if current_user and not current_user.is_anonymous: impersonating_ownerid = request.COOKIES.get("staff_user") if impersonating_ownerid is None: + request.impersonation = False return log.info( "Impersonation attempted", extra=dict( current_user_id=current_user.pk, + current_user_email=current_user.email, impersonating_ownerid=impersonating_ownerid, ), ) @@ -94,6 +99,7 @@ def process_request(self, request): extra=dict( reason="must be a staff user", current_user_id=current_user.pk, + current_user_email=current_user.email, impersonating_ownerid=impersonating_ownerid, ), ) @@ -110,6 +116,7 @@ def process_request(self, request): extra=dict( reason="no such owner", current_user_id=current_user.pk, + current_user_email=current_user.email, impersonating_ownerid=impersonating_ownerid, ), ) @@ -119,13 +126,19 @@ def process_request(self, request): "Impersonation successful", extra=dict( current_user_id=current_user.pk, + current_user_email=current_user.email, impersonating_ownerid=impersonating_ownerid, ), ) + request.impersonation = True + else: + request.impersonation = False class CorsMiddleware(BaseCorsMiddleware): - def process_response(self, request, response): + def process_response( + self, request: HttpRequest, response: HttpResponse + ) -> HttpResponse: response = super().process_response(request, response) if not self.is_enabled(request): return response diff --git a/codecov_auth/migrations/0001_initial.py b/codecov_auth/migrations/0001_initial.py index 512e7c2580..5318fbac90 100644 --- a/codecov_auth/migrations/0001_initial.py +++ b/codecov_auth/migrations/0001_initial.py @@ -6,7 +6,7 @@ import django.contrib.postgres.fields import django.contrib.postgres.fields.citext import django.db.models.deletion -from django.conf import settings +from django.conf import settings # noqa: F401 from django.contrib.postgres.operations 
import CITextExtension from django.db import migrations, models diff --git a/codecov_auth/migrations/0003_auto_20210924_1003.py b/codecov_auth/migrations/0003_auto_20210924_1003.py index d51a5732c7..6328372ba6 100644 --- a/codecov_auth/migrations/0003_auto_20210924_1003.py +++ b/codecov_auth/migrations/0003_auto_20210924_1003.py @@ -4,7 +4,7 @@ import django.contrib.postgres.fields import django.db.models.deletion -from django.conf import settings +from django.conf import settings # noqa: F401 from django.db import migrations, models diff --git a/codecov_auth/migrations/0006_auto_20211123_1535.py b/codecov_auth/migrations/0006_auto_20211123_1535.py index f3e9582cc5..d53dbc6769 100644 --- a/codecov_auth/migrations/0006_auto_20211123_1535.py +++ b/codecov_auth/migrations/0006_auto_20211123_1535.py @@ -1,7 +1,7 @@ # Generated by Django 3.1.13 on 2021-11-23 15:35 import django.db.models.deletion -from django.conf import settings +from django.conf import settings # noqa: F401 from django.db import migrations, models import codecov_auth.models diff --git a/codecov_auth/migrations/0009_auto_20220511_1313.py b/codecov_auth/migrations/0009_auto_20220511_1313.py index dbcdbac6ce..5f63bca6a1 100644 --- a/codecov_auth/migrations/0009_auto_20220511_1313.py +++ b/codecov_auth/migrations/0009_auto_20220511_1313.py @@ -2,7 +2,6 @@ import django.contrib.postgres.fields import django.db.models.deletion -from django.conf import settings from django.db import migrations, models diff --git a/codecov_auth/migrations/0015_organizationleveltoken.py b/codecov_auth/migrations/0015_organizationleveltoken.py index a0f1558e84..5fe0a0c7bc 100644 --- a/codecov_auth/migrations/0015_organizationleveltoken.py +++ b/codecov_auth/migrations/0015_organizationleveltoken.py @@ -3,7 +3,7 @@ import uuid import django.db.models.deletion -from django.conf import settings +from django.conf import settings # noqa: F401 from django.db import migrations, models diff --git 
a/codecov_auth/migrations/0018_usertoken.py b/codecov_auth/migrations/0018_usertoken.py index 679554a435..7178dc8414 100644 --- a/codecov_auth/migrations/0018_usertoken.py +++ b/codecov_auth/migrations/0018_usertoken.py @@ -3,7 +3,7 @@ import uuid import django.db.models.deletion -from django.conf import settings +from django.conf import settings # noqa: F401 from django.db import migrations, models diff --git a/codecov_auth/migrations/0020_ownerprofile_default_org.py b/codecov_auth/migrations/0020_ownerprofile_default_org.py index d517a42f07..de9fdb881e 100644 --- a/codecov_auth/migrations/0020_ownerprofile_default_org.py +++ b/codecov_auth/migrations/0020_ownerprofile_default_org.py @@ -1,7 +1,7 @@ # Generated by Django 3.2.12 on 2023-01-19 19:06 import django.db.models.deletion -from django.conf import settings +from django.conf import settings # noqa: F401 from django.db import migrations, models diff --git a/codecov_auth/migrations/0031_user_owner_user.py b/codecov_auth/migrations/0031_user_owner_user.py index 649e154329..e80b4c91cc 100644 --- a/codecov_auth/migrations/0031_user_owner_user.py +++ b/codecov_auth/migrations/0031_user_owner_user.py @@ -1,6 +1,5 @@ # Generated by Django 4.1.7 on 2023-05-22 17:53 -import uuid import django.contrib.postgres.fields.citext import django.db.models.deletion diff --git a/codecov_auth/models.py b/codecov_auth/models.py index 0a965b9d0e..cd96cb672d 100644 --- a/codecov_auth/models.py +++ b/codecov_auth/models.py @@ -1,2 +1,2 @@ from shared.django_apps.codecov_auth.models import * -from shared.django_apps.codecov_auth.models import _generate_key +from shared.django_apps.codecov_auth.models import _generate_key # noqa: F401 diff --git a/codecov_auth/services/org_level_token_service.py b/codecov_auth/services/org_level_token_service.py index 1589ca2468..b1811d35f6 100644 --- a/codecov_auth/services/org_level_token_service.py +++ b/codecov_auth/services/org_level_token_service.py @@ -1,8 +1,7 @@ import logging import uuid 
-from secrets import token_bytes -from django.db.models.signals import post_save, pre_save +from django.db.models.signals import post_save from django.dispatch import receiver from django.forms import ValidationError diff --git a/codecov_auth/signals.py b/codecov_auth/signals.py index 37f333fb49..6a92548a70 100644 --- a/codecov_auth/signals.py +++ b/codecov_auth/signals.py @@ -1,6 +1,4 @@ import json -import logging -from datetime import datetime from django.conf import settings from django.db.models.signals import post_save @@ -31,7 +29,7 @@ def _get_pubsub_publisher(): @receiver( post_save, sender=OrganizationLevelToken, dispatch_uid="shelter_sync_org_token" ) -def update_repository(sender, instance: OrganizationLevelToken, **kwargs): +def update_org_token(sender, instance: OrganizationLevelToken, **kwargs): pubsub_project_id = settings.SHELTER_PUBSUB_PROJECT_ID topic_id = settings.SHELTER_PUBSUB_SYNC_REPO_TOPIC_ID if pubsub_project_id and topic_id: diff --git a/codecov_auth/tests/factories.py b/codecov_auth/tests/factories.py index c5c5838c4f..3ea584d09f 100644 --- a/codecov_auth/tests/factories.py +++ b/codecov_auth/tests/factories.py @@ -5,14 +5,16 @@ from factory.django import DjangoModelFactory from codecov_auth.models import ( + Account, DjangoSession, + OktaSettings, OktaUser, OrganizationLevelToken, Owner, OwnerProfile, - RepositoryToken, + RepositoryToken, # noqa: F401 SentryUser, - Service, + Service, # noqa: F401 Session, TokenTypeChoices, User, @@ -33,6 +35,25 @@ class Meta: customer_intent = "Business" +class AccountFactory(DjangoModelFactory): + class Meta: + model = Account + + name = factory.Faker("name") + + +class OktaSettingsFactory(DjangoModelFactory): + class Meta: + model = OktaSettings + + account = factory.SubFactory(AccountFactory) + client_id = factory.Faker("uuid4") + client_secret = factory.Faker("uuid4") + url = factory.Faker("url") + enabled = True + enforced = False + + class OwnerFactory(DjangoModelFactory): class Meta: model = 
Owner diff --git a/codecov_auth/tests/test_admin.py b/codecov_auth/tests/test_admin.py index 812100a4a6..d61e013ea0 100644 --- a/codecov_auth/tests/test_admin.py +++ b/codecov_auth/tests/test_admin.py @@ -1,25 +1,48 @@ +from datetime import timedelta from unittest.mock import MagicMock, patch +import pytest from django.contrib.admin.helpers import ACTION_CHECKBOX_NAME from django.contrib.admin.sites import AdminSite from django.test import RequestFactory, TestCase from django.urls import reverse +from django.utils import timezone +from shared.django_apps.codecov_auth.models import ( + Account, + AccountsUsers, + InvoiceBilling, + StripeBilling, +) +from shared.django_apps.codecov_auth.tests.factories import ( + AccountFactory, + InvoiceBillingFactory, + StripeBillingFactory, +) from codecov.commands.exceptions import ValidationError -from codecov_auth.admin import OrgUploadTokenInline, OwnerAdmin, UserAdmin +from codecov_auth.admin import ( + AccountAdmin, + InvoiceBillingAdmin, + OrgUploadTokenInline, + OwnerAdmin, + StripeBillingAdmin, + UserAdmin, + find_and_remove_stale_users, +) from codecov_auth.models import OrganizationLevelToken, Owner, SentryUser, User from codecov_auth.tests.factories import ( OrganizationLevelTokenFactory, OwnerFactory, SentryUserFactory, + SessionFactory, UserFactory, ) +from core.models import Pull +from core.tests.factories import PullFactory, RepositoryFactory from plan.constants import ( ENTERPRISE_CLOUD_USER_PLAN_REPRESENTATIONS, PlanName, - TrialStatus, ) -from utils.test_utils import APIClient class OwnerAdminTest(TestCase): @@ -227,7 +250,7 @@ def test_org_token_refresh_request_calls_service_to_refresh_token( "organization_tokens-0-REFRESH": "on", "_continue": ["Save and continue editing"], } - response = self.client.post(request_url, data=fake_data) + self.client.post(request_url, data=fake_data) mock_refresh.assert_called_with(str(org_token.id)) @patch( @@ -263,7 +286,7 @@ def 
test_org_token_request_doesnt_call_service_to_refresh_token(self, mock_refre "organization_tokens-0-token_type": ["upload"], "_continue": ["Save and continue editing"], } - response = self.client.post(request_url, data=fake_data) + self.client.post(request_url, data=fake_data) mock_refresh.assert_not_called() def test_start_trial_ui_display(self): @@ -336,6 +359,21 @@ def test_start_trial_paid_plan(self, mock_start_trial_service): assert res.status_code == 302 assert mock_start_trial_service.called + def test_account_widget(self): + owner = OwnerFactory(user=UserFactory(), plan="users-enterprisey") + rf = RequestFactory() + get_request = rf.get(f"/admin/codecov_auth/owner/{owner.ownerid}/change/") + get_request.user = self.staff_user + sample_input = { + "change": True, + "fields": ["account", "plan", "uses_invoice", "staff"], + } + form = self.owner_admin.get_form(request=get_request, obj=owner, **sample_input) + # admin user cannot create, edit, or delete Account objects from the OwnerAdmin + self.assertFalse(form.base_fields["account"].widget.can_add_related) + self.assertFalse(form.base_fields["account"].widget.can_change_related) + self.assertFalse(form.base_fields["account"].widget.can_delete_related) + class UserAdminTest(TestCase): def setUp(self): @@ -374,7 +412,7 @@ def test_user_admin_list_page(self): sentry_user = SentryUserFactory() res = self.client.get(reverse("admin:codecov_auth_sentryuser_changelist")) assert res.status_code == 200 - content = res.content.decode("utf-8") + res.content.decode("utf-8") assert sentry_user.name in res.content.decode("utf-8") assert sentry_user.email in res.content.decode("utf-8") @@ -388,3 +426,460 @@ def test_user_admin_detail_page(self): assert sentry_user.email in res.content.decode("utf-8") assert sentry_user.access_token not in res.content.decode("utf-8") assert sentry_user.refresh_token not in res.content.decode("utf-8") + + +def create_stale_users( + account: Account | None = None, +) -> tuple[list[Owner], 
list[Owner]]: + org_1 = OwnerFactory(account=account) + org_2 = OwnerFactory(account=account) + + now = timezone.now() + + user_1 = OwnerFactory(user=UserFactory()) # stale, neither session nor PR + + user_2 = OwnerFactory(user=UserFactory()) # semi-stale, semi-old session + SessionFactory(owner=user_2, lastseen=now - timedelta(days=45)) + + user_3 = OwnerFactory(user=UserFactory()) # stale, old session + SessionFactory(owner=user_3, lastseen=now - timedelta(days=120)) + + user_4 = OwnerFactory(user=UserFactory()) # semi-stale, semi-old PR + pull = PullFactory( + repository=RepositoryFactory(), + author=user_4, + ) + pull.updatestamp = now - timedelta(days=45) + super(Pull, pull).save() # `Pull` overrides the `updatestamp` on each save + + user_5 = OwnerFactory(user=UserFactory()) # stale, old PR + pull = PullFactory( + repository=RepositoryFactory(), + author=user_5, + ) + pull.updatestamp = now - timedelta(days=120) + super(Pull, pull).save() # `Pull` overrides the `updatestamp` on each save + + org_1.plan_activated_users = [ + user_1.ownerid, + user_2.ownerid, + ] + org_1.save() + + org_2.plan_activated_users = [ + user_3.ownerid, + user_4.ownerid, + user_5.ownerid, + ] + org_2.save() + + return ([org_1, org_2], [user_1, user_2, user_3, user_4, user_5]) + + +@pytest.mark.django_db() +def test_stale_user_cleanup(): + orgs, users = create_stale_users() + + # remove stale users with default > 90 days + removed_users, affected_orgs = find_and_remove_stale_users(orgs) + assert removed_users == set([users[0].ownerid, users[2].ownerid, users[4].ownerid]) + assert affected_orgs == set([orgs[0].ownerid, orgs[1].ownerid]) + + orgs = list( + Owner.objects.filter(ownerid__in=[org.ownerid for org in orgs]) + ) # re-fetch orgs + # all good, nothing to do + removed_users, affected_orgs = find_and_remove_stale_users(orgs) + assert removed_users == set() + assert affected_orgs == set() + + # remove even more stale users + removed_users, affected_orgs = 
find_and_remove_stale_users(orgs, timedelta(days=30)) + assert removed_users == set([users[1].ownerid, users[3].ownerid]) + assert affected_orgs == set([orgs[0].ownerid, orgs[1].ownerid]) + + orgs = list( + Owner.objects.filter(ownerid__in=[org.ownerid for org in orgs]) + ) # re-fetch orgs + # all the users have been deactivated by now + for org in orgs: + assert len(org.plan_activated_users) == 0 + + +class AccountAdminTest(TestCase): + def setUp(self): + staff_user = UserFactory(is_staff=True) + self.client.force_login(user=staff_user) + admin_site = AdminSite() + admin_site.register(Account) + admin_site.register(StripeBilling) + admin_site.register(InvoiceBilling) + admin_site.register(AccountsUsers) + self.account_admin = AccountAdmin(Account, admin_site) + + self.account = AccountFactory(plan_seat_count=4, free_seat_count=2) + self.org_1 = OwnerFactory(account=self.account) + self.org_2 = OwnerFactory(account=self.account) + self.owner_with_user_1 = OwnerFactory(user=UserFactory()) + self.owner_with_user_2 = OwnerFactory(user=UserFactory()) + self.owner_with_user_3 = OwnerFactory(user=UserFactory()) + self.owner_without_user_1 = OwnerFactory(user=None) + self.owner_without_user_2 = OwnerFactory(user=None) + self.student = OwnerFactory(user=UserFactory(), student=True) + self.org_1.plan_activated_users = [ + self.owner_with_user_2.ownerid, + self.owner_with_user_3.ownerid, + self.owner_without_user_1.ownerid, + self.student.ownerid, + self.owner_without_user_2.ownerid, + ] + self.org_2.plan_activated_users = [ + self.owner_with_user_2.ownerid, + self.owner_with_user_3.ownerid, + self.owner_without_user_1.ownerid, + self.student.ownerid, + self.owner_with_user_1.ownerid, + ] + self.org_1.save() + self.org_2.save() + + def test_list_page(self): + res = self.client.get(reverse("admin:codecov_auth_account_changelist")) + self.assertEqual(res.status_code, 200) + decoded_res = res.content.decode("utf-8") + self.assertIn("column-name", decoded_res) + 
self.assertIn("column-is_active", decoded_res) + self.assertIn( + '', decoded_res + ) + self.assertIn( + '', + decoded_res, + ) + + def test_detail_page(self): + res = self.client.get( + reverse("admin:codecov_auth_account_change", args=[self.account.pk]) + ) + self.assertEqual(res.status_code, 200) + decoded_res = res.content.decode("utf-8") + self.assertIn( + '', decoded_res + ) + self.assertIn("Organizations (read only)", decoded_res) + self.assertIn("Stripe Billing (click save to commit changes)", decoded_res) + self.assertIn("Invoice Billing (click save to commit changes)", decoded_res) + + def test_link_users_to_account(self): + self.assertEqual(AccountsUsers.objects.all().count(), 0) + self.assertEqual(self.account.accountsusers_set.all().count(), 0) + + res = self.client.post( + reverse("admin:codecov_auth_account_changelist"), + { + "action": "link_users_to_account", + ACTION_CHECKBOX_NAME: [self.account.pk], + }, + ) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.url, "/admin/codecov_auth/account/") + messages = list(res.wsgi_request._messages) + self.assertEqual(messages[0].message, "Created a User for 2 Owners") + self.assertEqual( + messages[1].message, "Created 6 AccountsUsers, removed 0 AccountsUsers" + ) + + self.assertEqual(AccountsUsers.objects.all().count(), 6) + self.assertEqual( + AccountsUsers.objects.filter(account_id=self.account.id).count(), 6 + ) + + for org in [self.org_1, self.org_2]: + for active_owner_id in org.plan_activated_users: + owner_obj = Owner.objects.get(pk=active_owner_id) + self.assertTrue( + AccountsUsers.objects.filter( + account=self.account, user_id=owner_obj.user_id + ).exists() + ) + + # another user joins + another_owner_with_user = OwnerFactory(user=UserFactory()) + self.org_1.plan_activated_users.append(another_owner_with_user.ownerid) + self.org_1.save() + # rerun action to re-sync + res = self.client.post( + reverse("admin:codecov_auth_account_changelist"), + { + "action": 
"link_users_to_account", + ACTION_CHECKBOX_NAME: [self.account.pk], + }, + ) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.url, "/admin/codecov_auth/account/") + messages = list(res.wsgi_request._messages) + self.assertEqual(messages[2].message, "Created a User for 0 Owners") + self.assertEqual( + messages[3].message, "Created 1 AccountsUsers, removed 0 AccountsUsers" + ) + + self.assertEqual(AccountsUsers.objects.all().count(), 7) + self.assertEqual( + AccountsUsers.objects.filter(account_id=self.account.id).count(), 7 + ) + self.assertIn( + another_owner_with_user.user_id, + self.account.accountsusers_set.all().values_list("user_id", flat=True), + ) + + def test_link_users_to_account_not_enough_seats(self): + self.assertEqual(AccountsUsers.objects.all().count(), 0) + self.account.plan_seat_count = 1 + self.account.save() + res = self.client.post( + reverse("admin:codecov_auth_account_changelist"), + { + "action": "link_users_to_account", + ACTION_CHECKBOX_NAME: [self.account.pk], + }, + ) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.url, "/admin/codecov_auth/account/") + messages = list(res.wsgi_request._messages) + self.assertEqual( + messages[0].message, + "Request failed: Account plan does not have enough seats; current plan activated users (non-students): 5, total seats for account: 3", + ) + self.assertEqual(AccountsUsers.objects.all().count(), 0) + + def test_seat_check(self): + # edge case: User has multiple Owners, one of which is a Student, but should still count as 1 seat on this Account + user = self.owner_with_user_1.user + OwnerFactory( + service="gitlab", user=user, student=False + ) # another owner on this user + OwnerFactory( + service="bitbucket", user=user, student=True + ) # student owner on this user + + self.assertEqual(AccountsUsers.objects.all().count(), 0) + self.account.plan_seat_count = 1 + self.account.save() + res = self.client.post( + reverse("admin:codecov_auth_account_changelist"), + { + 
"action": "seat_check", + ACTION_CHECKBOX_NAME: [self.account.pk], + }, + ) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.url, "/admin/codecov_auth/account/") + messages = list(res.wsgi_request._messages) + self.assertEqual( + messages[0].message, + "Request failed: Account plan does not have enough seats; current plan activated users (non-students): 5, total seats for account: 3", + ) + + self.account.plan_seat_count = 10 + self.account.save() + res = self.client.post( + reverse("admin:codecov_auth_account_changelist"), + { + "action": "seat_check", + ACTION_CHECKBOX_NAME: [self.account.pk], + }, + ) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.url, "/admin/codecov_auth/account/") + messages = list(res.wsgi_request._messages) + self.assertEqual( + messages[1].message, + "Request succeeded: Account plan has enough seats! current plan activated users (non-students): 5, total seats for account: 12", + ) + self.assertEqual(AccountsUsers.objects.all().count(), 0) + + def test_link_users_to_account_remove_unneeded_account_users(self): + res = self.client.post( + reverse("admin:codecov_auth_account_changelist"), + { + "action": "link_users_to_account", + ACTION_CHECKBOX_NAME: [self.account.pk], + }, + ) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.url, "/admin/codecov_auth/account/") + messages = list(res.wsgi_request._messages) + self.assertEqual(messages[0].message, "Created a User for 2 Owners") + self.assertEqual( + messages[1].message, "Created 6 AccountsUsers, removed 0 AccountsUsers" + ) + + self.assertEqual(AccountsUsers.objects.all().count(), 6) + self.assertEqual( + AccountsUsers.objects.filter(account_id=self.account.id).count(), 6 + ) + + for org in [self.org_1, self.org_2]: + for active_owner_id in org.plan_activated_users: + owner_obj = Owner.objects.get(pk=active_owner_id) + self.assertTrue( + AccountsUsers.objects.filter( + account=self.account, user_id=owner_obj.user_id + ).exists() + ) + + # 
disconnect one of the orgs + self.org_2.account = None + self.org_2.save() + + # re-sync to remove Account users from org 2 that are not connected to other account orgs (just owner_with_user_1) + res = self.client.post( + reverse("admin:codecov_auth_account_changelist"), + { + "action": "link_users_to_account", + ACTION_CHECKBOX_NAME: [self.account.pk], + }, + ) + self.assertEqual(res.status_code, 302) + self.assertEqual(res.url, "/admin/codecov_auth/account/") + messages = list(res.wsgi_request._messages) + self.assertEqual(messages[2].message, "Created a User for 0 Owners") + self.assertEqual( + messages[3].message, "Created 0 AccountsUsers, removed 1 AccountsUsers" + ) + + self.assertEqual(AccountsUsers.objects.all().count(), 5) + self.assertEqual( + AccountsUsers.objects.filter(account_id=self.account.id).count(), 5 + ) + still_connected = [ + self.owner_with_user_2, + self.owner_with_user_3, + self.owner_without_user_1, + self.owner_without_user_2, + self.student, + ] + for owner in still_connected: + owner.refresh_from_db() + self.assertTrue( + AccountsUsers.objects.filter( + account=self.account, user_id=owner.user_id + ).exists() + ) + + self.owner_with_user_1.refresh_from_db() # removed user + # no longer connected to account + self.assertFalse( + AccountsUsers.objects.filter( + account=self.account, user_id=self.owner_with_user_1.user_id + ).exists() + ) + # still connected to org + self.assertIn( + self.owner_with_user_1.ownerid, + Owner.objects.get(pk=self.org_2.pk).plan_activated_users, + ) + # user object still exists, with no account connections + self.assertIsNotNone(self.owner_with_user_1.user_id) + self.assertFalse( + AccountsUsers.objects.filter(user=self.owner_with_user_1.user).exists() + ) + + def test_deactivate_stale_users(self): + account = AccountFactory() + orgs, users = create_stale_users(account) + + res = self.client.post( + reverse("admin:codecov_auth_account_changelist"), + { + "action": "deactivate_stale_users", + 
ACTION_CHECKBOX_NAME: [account.pk], + }, + ) + messages = list(res.wsgi_request._messages) + self.assertEqual( + messages[-1].message, "Removed 3 stale users from 2 affected organizations." + ) + + res = self.client.post( + reverse("admin:codecov_auth_account_changelist"), + { + "action": "deactivate_stale_users", + ACTION_CHECKBOX_NAME: [account.pk], + }, + ) + messages = list(res.wsgi_request._messages) + self.assertEqual( + messages[-1].message, + "No stale users found in selected accounts / organizations.", + ) + + +class StripeBillingAdminTest(TestCase): + def setUp(self): + self.staff_user = UserFactory(is_staff=True) + self.client.force_login(user=self.staff_user) + admin_site = AdminSite() + admin_site.register(StripeBilling) + self.stripe_admin = StripeBillingAdmin(StripeBilling, admin_site) + self.account = AccountFactory() + self.obj = StripeBillingFactory(account=self.account) + + def test_account_widget(self): + rf = RequestFactory() + get_request = rf.get(f"/admin/codecov_auth/stripebilling/{self.obj.id}/change/") + sample_input = { + "change": True, + "fields": [ + "id", + "created_at", + "updated_at", + "account", + "customer_id", + "subscription_id", + "is_active", + ], + } + form = self.stripe_admin.get_form( + request=get_request, obj=self.obj, **sample_input + ) + # admin user cannot create, edit, or delete Account objects from the StripeBillingAdmin + self.assertFalse(form.base_fields["account"].widget.can_add_related) + self.assertFalse(form.base_fields["account"].widget.can_change_related) + self.assertFalse(form.base_fields["account"].widget.can_delete_related) + + +class InvoiceBillingAdminTest(TestCase): + def setUp(self): + self.staff_user = UserFactory(is_staff=True) + self.client.force_login(user=self.staff_user) + admin_site = AdminSite() + admin_site.register(InvoiceBilling) + self.invoice_admin = InvoiceBillingAdmin(InvoiceBilling, admin_site) + self.account = AccountFactory() + self.obj = InvoiceBillingFactory(account=self.account) 
+ + def test_account_widget(self): + rf = RequestFactory() + get_request = rf.get( + f"/admin/codecov_auth/invoicebilling/{self.obj.id}/change/" + ) + sample_input = { + "change": True, + "fields": [ + "id", + "created_at", + "updated_at", + "account", + "account_manager", + "invoice_notes", + "is_active", + ], + } + form = self.invoice_admin.get_form( + request=get_request, obj=self.obj, **sample_input + ) + # admin user cannot create, edit, or delete Account objects from the InvoiceBillingAdmin + self.assertFalse(form.base_fields["account"].widget.can_add_related) + self.assertFalse(form.base_fields["account"].widget.can_change_related) + self.assertFalse(form.base_fields["account"].widget.can_delete_related) diff --git a/codecov_auth/tests/test_migrations.py b/codecov_auth/tests/test_migrations.py index d5ddba1b14..adeebbaf2a 100644 --- a/codecov_auth/tests/test_migrations.py +++ b/codecov_auth/tests/test_migrations.py @@ -1,5 +1,3 @@ -from datetime import datetime - import pytest from utils.test_utils import TestMigrations diff --git a/codecov_auth/tests/unit/test_authentication.py b/codecov_auth/tests/unit/test_authentication.py index 0db718bf10..6342d7a42b 100644 --- a/codecov_auth/tests/unit/test_authentication.py +++ b/codecov_auth/tests/unit/test_authentication.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta from http.cookies import SimpleCookie +from unittest.mock import call, patch import pytest from django.conf import settings @@ -7,6 +8,7 @@ from django.urls import ResolverMatch from rest_framework.exceptions import AuthenticationFailed, PermissionDenied from rest_framework.test import APIRequestFactory +from shared.django_apps.core.tests.factories import RepositoryFactory from codecov_auth.authentication import ( InternalTokenAuthentication, @@ -103,7 +105,7 @@ def test_bearer_token_auth_invalid_super_token(self): authenticator = SuperTokenAuthentication() result = authenticator.authenticate(request) - assert result == None + assert result 
is None def test_bearer_token_default_token_envar(self): super_token = "0ae68e58-79f8-4341-9531-55aada05a251" @@ -111,7 +113,7 @@ def test_bearer_token_default_token_envar(self): request = request_factory.get("", HTTP_AUTHORIZATION=f"Bearer {super_token}") authenticator = SuperTokenAuthentication() result = authenticator.authenticate(request) - assert result == None + assert result is None def test_bearer_token_default_token_envar_and_same_string_as_header(self): super_token = settings.SUPER_API_TOKEN @@ -168,7 +170,7 @@ def test_bearer_token_default_token_envar_and_same_string_as_header(self): class ImpersonationTests(TransactionTestCase): def setUp(self): self.owner_to_impersonate = OwnerFactory( - username="impersonateme", service="github" + username="impersonateme", service="github", user=UserFactory(is_staff=False) ) self.staff_user = UserFactory(is_staff=True) self.non_staff_user = UserFactory(is_staff=False) @@ -184,6 +186,47 @@ def test_impersonation(self): ) assert res.json()["data"]["me"] == {"user": {"username": "impersonateme"}} + @patch("core.commands.repository.repository.RepositoryCommands.fetch_repository") + def test_impersonation_with_okta(self, mock_call_to_fetch_repository): + repo = RepositoryFactory(author=self.owner_to_impersonate, private=True) + query_repositories = """{ owner(username: "%s") { repository(name: "%s") { ... 
on Repository { name } } } }""" + query = query_repositories % (repo.author.username, repo.name) + + # not impersonating + del self.client.cookies["staff_user"] + self.client.force_login(user=self.owner_to_impersonate.user) + self.client.post( + "/graphql/gh", + {"query": query}, + content_type="application/json", + ) + + # impersonating, same query + self.client.cookies = SimpleCookie({"staff_user": self.owner_to_impersonate.pk}) + self.client.force_login(user=self.staff_user) + self.client.post( + "/graphql/gh", + {"query": query}, + content_type="application/json", + ) + + mock_call_to_fetch_repository.assert_has_calls( + [ + call( + self.owner_to_impersonate, + repo.name, + [], + exclude_okta_enforced_repos=True, + ), + call( + self.owner_to_impersonate, + repo.name, + [], + exclude_okta_enforced_repos=False, + ), + ] + ) + def test_impersonation_non_staff(self): self.client.force_login(user=self.non_staff_user) with pytest.raises(PermissionDenied): diff --git a/codecov_auth/tests/unit/test_helpers.py b/codecov_auth/tests/unit/test_helpers.py index 34773ee3c3..6f98602f05 100644 --- a/codecov_auth/tests/unit/test_helpers.py +++ b/codecov_auth/tests/unit/test_helpers.py @@ -1,11 +1,9 @@ -from pprint import pprint from unittest.mock import patch import pytest -from django.contrib.admin.models import CHANGE, LogEntry +from django.contrib.admin.models import LogEntry from codecov_auth.helpers import History, current_user_part_of_org -from codecov_auth.models import Owner, User from ..factories import OwnerFactory diff --git a/codecov_auth/tests/unit/test_managers.py b/codecov_auth/tests/unit/test_managers.py index c009eba4ee..44452e0dbe 100644 --- a/codecov_auth/tests/unit/test_managers.py +++ b/codecov_auth/tests/unit/test_managers.py @@ -1,8 +1,7 @@ from django.test import TestCase from codecov_auth.models import Owner -from codecov_auth.tests.factories import OwnerFactory, SessionFactory -from core.tests.factories import PullFactory, RepositoryFactory +from 
codecov_auth.tests.factories import OwnerFactory class OwnerManagerTests(TestCase): diff --git a/codecov_auth/tests/unit/test_middleware.py b/codecov_auth/tests/unit/test_middleware.py index 83a7f600fe..edc5b8179a 100644 --- a/codecov_auth/tests/unit/test_middleware.py +++ b/codecov_auth/tests/unit/test_middleware.py @@ -1,7 +1,5 @@ from django.test import TestCase, override_settings -from django.urls import reverse -from codecov_auth.tests.factories import OwnerFactory from utils.test_utils import Client diff --git a/codecov_auth/tests/unit/test_repo_authentication.py b/codecov_auth/tests/unit/test_repo_authentication.py index 83d0def0b9..e24772ade2 100644 --- a/codecov_auth/tests/unit/test_repo_authentication.py +++ b/codecov_auth/tests/unit/test_repo_authentication.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime, timedelta -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import patch import pytest from django.core.exceptions import ObjectDoesNotExist @@ -10,7 +10,6 @@ from jwt import PyJWTError from rest_framework import exceptions from rest_framework.test import APIRequestFactory -from shared.torngit.exceptions import TorngitObjectNotFoundError, TorngitRateLimitError from codecov_auth.authentication.repo_auth import ( GitHubOIDCTokenAuthentication, @@ -168,58 +167,69 @@ def get_mocked_global_tokens(self): "bitbucketserveruploadtoken": "bitbucket_server", } - def test_authentication_for_non_enterprise(self): + @patch("codecov_auth.authentication.repo_auth.get_global_tokens") + def test_authentication_no_global_token_available(self, mocked_get_global_tokens): + mocked_get_global_tokens.return_value = {} authentication = GlobalTokenAuthentication() - request = APIRequestFactory().post("/endpoint") + request = APIRequestFactory().post("/upload/service/owner::::repo/commits") res = authentication.authenticate(request) assert res is None @patch("codecov_auth.authentication.repo_auth.get_global_tokens") - 
@patch("codecov_auth.authentication.repo_auth.GlobalTokenAuthentication.get_token") - def test_authentication_for_enterprise_wrong_token( - self, mocked_token, mocked_get_global_tokens - ): + def test_authentication_for_enterprise_wrong_token(self, mocked_get_global_tokens): mocked_get_global_tokens.return_value = self.get_mocked_global_tokens() - mocked_token.return_value = "random_token" authentication = GlobalTokenAuthentication() - request = APIRequestFactory().post("/endpoint") + request = APIRequestFactory().post( + "/upload/service/owner::::repo/commits", + headers={"Authorization": "token GUT"}, + ) res = authentication.authenticate(request) assert res is None @patch("codecov_auth.authentication.repo_auth.get_global_tokens") - @patch("codecov_auth.authentication.repo_auth.GlobalTokenAuthentication.get_token") - @patch("codecov_auth.authentication.repo_auth.GlobalTokenAuthentication.get_owner") def test_authentication_for_enterprise_correct_token_repo_not_exists( - self, mocked_owner, mocked_token, mocked_get_global_tokens, db + self, mocked_get_global_tokens, db ): mocked_get_global_tokens.return_value = self.get_mocked_global_tokens() - mocked_token.return_value = "githubuploadtoken" - mocked_owner.return_value = OwnerFactory.create() authentication = GlobalTokenAuthentication() - request = APIRequestFactory().post("/endpoint") + request = APIRequestFactory().post( + "/upload/service/owner::::repo/commits", + headers={"Authorization": "token githubuploadtoken"}, + ) with pytest.raises(exceptions.AuthenticationFailed) as exc: authentication.authenticate(request) assert exc.value.args == ( "Could not find a repository, try using repo upload token", ) + @pytest.mark.parametrize( + "owner_service, owner_name, token", + [ + pytest.param("github", "username", "githubuploadtoken", id="github"), + pytest.param( + "gitlab", "username", "gitlabuploadtoken", id="gitlab_single_user" + ), + pytest.param( + "gitlab", + "usergroup:username", + "gitlabuploadtoken", + 
id="gitlab_subgroup_user", + ), + ], + ) @patch("codecov_auth.authentication.repo_auth.get_global_tokens") - @patch("codecov_auth.authentication.repo_auth.GlobalTokenAuthentication.get_token") - @patch("codecov_auth.authentication.repo_auth.GlobalTokenAuthentication.get_owner") - @patch("codecov_auth.authentication.repo_auth.GlobalTokenAuthentication.get_repoid") def test_authentication_for_enterprise_correct_token_repo_exists( - self, mocked_repoid, mocked_owner, mocked_token, mocked_get_global_tokens, db + self, mocked_get_global_tokens, owner_service, owner_name, token, db ): mocked_get_global_tokens.return_value = self.get_mocked_global_tokens() - mocked_token.return_value = "githubuploadtoken" - owner = OwnerFactory.create(service="github") - repoid = 123 - mocked_repoid.return_value = repoid - mocked_owner.return_value = owner - - repository = RepositoryFactory.create(author=owner, repoid=repoid) + owner = OwnerFactory.create(service=owner_service, username=owner_name) + owner_name.replace(":", ":::") # encode name to test GL subgroups + repository = RepositoryFactory.create(author=owner) authentication = GlobalTokenAuthentication() - request = APIRequestFactory().post("/endpoint") + request = APIRequestFactory().post( + f"/upload/{owner_service}/{owner_name}::::{repository.name}/commits", + headers={"Authorization": f"token {token}"}, + ) res = authentication.authenticate(request) assert res is not None user, auth = res @@ -290,7 +300,7 @@ def test_owner_has_no_token_return_none(self, db, mocker): token = uuid.uuid4() authentication = OrgLevelTokenAuthentication() res = authentication.authenticate_credentials(token) - assert res == None + assert res is None @override_settings(IS_ENTERPRISE=False) def test_owner_has_token_but_wrong_one_sent_return_none(self, db, mocker): @@ -305,7 +315,7 @@ def test_owner_has_token_but_wrong_one_sent_return_none(self, db, mocker): ) authentication = OrgLevelTokenAuthentication() res = authentication.authenticate(request) - 
assert res == None + assert res is None assert OrganizationLevelToken.objects.filter(owner=owner).count() == 1 @override_settings(IS_ENTERPRISE=False) @@ -512,6 +522,7 @@ def test_tokenless_success( if existing_commit: commit = CommitFactory() commit.branch = commit_branch + commit.repository = repo commit.save() request = APIRequestFactory().post( diff --git a/codecov_auth/tests/unit/views/test_base.py b/codecov_auth/tests/unit/views/test_base.py index e82164b2f0..ec4e326ccb 100644 --- a/codecov_auth/tests/unit/views/test_base.py +++ b/codecov_auth/tests/unit/views/test_base.py @@ -4,7 +4,7 @@ import pytest from django.conf import settings from django.contrib.sessions.backends.cache import SessionStore -from django.core.exceptions import PermissionDenied, SuspiciousOperation +from django.core.exceptions import PermissionDenied from django.http import HttpResponse from django.test import RequestFactory, TestCase, override_settings from freezegun import freeze_time @@ -190,7 +190,7 @@ def test_get_or_create_calls_analytics_user_signed_in_when_owner_not_created( @override_settings(IS_ENTERPRISE=False) @patch("services.analytics.AnalyticsService.user_signed_in") def test_set_marketing_tags_on_cookies(self, user_signed_in_mock): - owner = OwnerFactory(service="github") + OwnerFactory(service="github") self.request = RequestFactory().get( "", { @@ -402,7 +402,7 @@ def test_check_user_account_limitations_enterprise_user_new_not_pr_billing( == 0 ) # If the number of users is larger than the limit, raise error - with pytest.raises(PermissionDenied) as exp: + with pytest.raises(PermissionDenied): OwnerFactory(service="github", ownerid=12, oauth_token="very-fake-token") OwnerFactory(service="github", ownerid=13, oauth_token=None) OwnerFactory(service="github", ownerid=14, oauth_token="very-fake-token") @@ -425,7 +425,7 @@ def test_check_user_account_limitations_enterprise_pr_billing( ) mock_get_current_license.return_value = license # User doesn't exist, and existing users 
will raise error - with pytest.raises(PermissionDenied) as exp: + with pytest.raises(PermissionDenied): OwnerFactory(ownerid=1, service="github", plan_activated_users=[1, 2, 3]) OwnerFactory( ownerid=2, diff --git a/codecov_auth/tests/unit/views/test_bitbucket.py b/codecov_auth/tests/unit/views/test_bitbucket.py index 31288ea6f5..a854c86cf0 100644 --- a/codecov_auth/tests/unit/views/test_bitbucket.py +++ b/codecov_auth/tests/unit/views/test_bitbucket.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import call, patch from django.core import signing from django.http.cookie import SimpleCookie @@ -27,6 +27,7 @@ def test_get_bitbucket_redirect(client, settings, mocker): url = reverse("bitbucket-login") res = client.get(url, SERVER_NAME="localhost:8000") assert res.status_code == 302 + assert "_oauth_request_token" in res.cookies cookie = res.cookies["_oauth_request_token"] assert cookie.value @@ -128,6 +129,10 @@ async def fake_list_teams(): ).sign(encryptor.encode(oauth_request_token).decode()) } ) + mock_create_user_onboarding_metric = mocker.patch( + "shared.django_apps.codecov_metrics.service.codecov_metrics.UserOnboardingMetricsService.create_user_onboarding_metric" + ) + res = client.get( url, {"oauth_verifier": 8519288973, "oauth_token": "test1daxl4jnhegoh4"}, @@ -143,6 +148,13 @@ async def fake_list_teams(): "test6tl3evq7c8vuyn", "testdm61tppb5x0tam7nae3qajhcepzz", "8519288973" ) owner = Owner.objects.get(username="ThiagoCodecov", service="bitbucket") + expected_call = call( + org_id=owner.ownerid, + event="INSTALLED_APP", + payload={"login": "bitbucket"}, + ) + assert mock_create_user_onboarding_metric.call_args_list == [expected_call] + assert ( encryptor.decode(owner.oauth_token) == "test6tl3evq7c8vuyn:testdm61tppb5x0tam7nae3qajhcepzz" diff --git a/codecov_auth/tests/unit/views/test_bitbucket_server.py b/codecov_auth/tests/unit/views/test_bitbucket_server.py index 394ccd61b2..e7eb7cc982 100644 --- 
a/codecov_auth/tests/unit/views/test_bitbucket_server.py +++ b/codecov_auth/tests/unit/views/test_bitbucket_server.py @@ -1,5 +1,3 @@ -from unittest.mock import patch - import pytest from django.core import signing from django.http.cookie import SimpleCookie @@ -9,7 +7,6 @@ from codecov_auth.models import Owner from codecov_auth.views.bitbucket_server import ( BitbucketServer, - BitbucketServerLoginView, ) from utils.encryption import encryptor @@ -44,7 +41,7 @@ def faulty_response(*args, **kwargs): # This is the error class that BitbucketServer.api generates raise TorngitClientGeneralError(500, "data data", "BBS unavailable") - client_request_mock = mocker.patch.object( + mocker.patch.object( BitbucketServer, "api", side_effect=faulty_response, @@ -54,7 +51,7 @@ def faulty_response(*args, **kwargs): settings.BITBUCKET_CLIENT_ID = "testqmo19ebdkseoby" settings.BITBUCKET_CLIENT_SECRET = "testfi8hzehvz453qj8mhv21ca4rf83f" with pytest.raises(TorngitClientGeneralError): - res = client.get(reverse("bbs-login"), SERVER_NAME="localhost:8000") + client.get(reverse("bbs-login"), SERVER_NAME="localhost:8000") def test_get_bbs_already_token(client, settings, mocker, db, mock_redis): diff --git a/codecov_auth/tests/unit/views/test_github.py b/codecov_auth/tests/unit/views/test_github.py index d01c6d7311..57935b83c6 100644 --- a/codecov_auth/tests/unit/views/test_github.py +++ b/codecov_auth/tests/unit/views/test_github.py @@ -1,5 +1,6 @@ import re from datetime import datetime +from unittest.mock import call import pytest from django.http.cookie import SimpleCookie @@ -54,7 +55,7 @@ def fake_config(*path, default=None): return default return curr - mock_get_config = mocker.patch( + mocker.patch( "shared.torngit.github.get_config", side_effect=fake_config, ) @@ -312,10 +313,6 @@ async def is_student(*args, **kwargs): assert owner.email == "thiago@codecov.io" assert owner.private_access is True assert res.url == "http://localhost:3000/gh" - assert "session_expiry" in res.cookies 
- session_expiry_cookie = res.cookies["session_expiry"] - assert session_expiry_cookie.value == "2023-02-01T08:00:00Z" - assert session_expiry_cookie.get("domain") == ".simple.site" @freeze_time("2023-01-01T00:00:00") @@ -355,6 +352,9 @@ async def is_student(*args, **kwargs): as_tuple=mocker.MagicMock(return_value=("a", "b")) ), ) + mock_create_user_onboarding_metric = mocker.patch( + "shared.django_apps.codecov_metrics.service.codecov_metrics.UserOnboardingMetricsService.create_user_onboarding_metric" + ) session = client.session session["github_oauth_state"] = "abc" @@ -362,6 +362,13 @@ async def is_student(*args, **kwargs): mock_redis.setex("oauth-state-abc", 300, "http://localhost:3000/gh") url = reverse("github-login") res = client.get(url, {"code": "aaaaaaa", "state": "abc"}) + expected_call = call( + org_id=client.session["current_owner_id"], + event="INSTALLED_APP", + payload={"login": "github"}, + ) + assert mock_create_user_onboarding_metric.call_args_list == [expected_call] + assert res.status_code == 302 owner = Owner.objects.get(pk=client.session["current_owner_id"]) @@ -371,10 +378,6 @@ async def is_student(*args, **kwargs): assert owner.private_access is True assert res.url == "http://localhost:3000/gh" assert owner.student is True - assert "session_expiry" in res.cookies - session_expiry_cookie = res.cookies["session_expiry"] - assert session_expiry_cookie.value == "2023-01-01T08:00:00Z" - assert session_expiry_cookie.get("domain") == ".simple.site" @freeze_time("2023-01-01T00:00:00") @@ -437,10 +440,6 @@ async def is_student(*args, **kwargs): assert owner.service_id == "44376991" assert owner.private_access is True assert res.url == "http://localhost:3000/gh" - assert "session_expiry" in res.cookies - session_expiry_cookie = res.cookies["session_expiry"] - assert session_expiry_cookie.value == "2023-01-01T08:00:00Z" - assert session_expiry_cookie.get("domain") == ".simple.site" @pytest.mark.asyncio diff --git 
a/codecov_auth/tests/unit/views/test_github_enterprise.py b/codecov_auth/tests/unit/views/test_github_enterprise.py index c83c9773a4..545152ecdd 100644 --- a/codecov_auth/tests/unit/views/test_github_enterprise.py +++ b/codecov_auth/tests/unit/views/test_github_enterprise.py @@ -37,7 +37,7 @@ def test_get_ghe_redirect(client, mocker, mock_redis, settings): @pytest.mark.django_db def test_get_ghe_redirect_with_ghpr_cookie(client, mocker, mock_redis, settings): - mock_get_config = mocker.patch( + mocker.patch( "shared.torngit.github_enterprise.get_config", side_effect=lambda *args: "https://my.githubenterprise.com", ) @@ -61,7 +61,7 @@ def test_get_ghe_redirect_with_ghpr_cookie(client, mocker, mock_redis, settings) @pytest.mark.django_db def test_get_github_redirect_with_private_url(client, mocker, mock_redis, settings): - mock_get_config = mocker.patch( + mocker.patch( "shared.torngit.github_enterprise.get_config", side_effect=lambda *args: "https://my.githubenterprise.com", ) @@ -83,7 +83,7 @@ def test_get_github_redirect_with_private_url(client, mocker, mock_redis, settin def test_get_ghe_already_with_code(client, mocker, db, mock_redis, settings): - mock_get_config = mocker.patch( + mocker.patch( "shared.torngit.github_enterprise.get_config", side_effect=lambda *args: "https://my.githubenterprise.com", ) @@ -214,7 +214,7 @@ async def is_student(*args, **kwargs): def test_get_ghe_already_with_code_github_error( client, mocker, db, mock_redis, settings ): - mock_get_config = mocker.patch( + mocker.patch( "shared.torngit.github_enterprise.get_config", side_effect=lambda *args: "https://my.githubenterprise.com", ) @@ -242,7 +242,7 @@ async def helper_func(*args, **kwargs): def test_state_not_known(client, mocker, db, mock_redis, settings): - mock_get_config = mocker.patch( + mocker.patch( "shared.torngit.github_enterprise.get_config", side_effect=lambda *args: "https://my.githubenterprise.com", ) @@ -255,7 +255,7 @@ def test_state_not_known(client, mocker, db, 
mock_redis, settings): def test_get_ghe_already_with_code_with_email(client, mocker, db, mock_redis, settings): - mock_get_config = mocker.patch( + mocker.patch( "shared.torngit.github_enterprise.get_config", side_effect=lambda *args: "https://my.githubenterprise.com", ) @@ -315,7 +315,7 @@ async def is_student(*args, **kwargs): def test_get_ghe_already_owner_already_exist(client, mocker, db, mock_redis, settings): - mock_get_config = mocker.patch( + mocker.patch( "shared.torngit.github_enterprise.get_config", side_effect=lambda *args: "https://my.githubenterprise.com", ) diff --git a/codecov_auth/tests/unit/views/test_gitlab.py b/codecov_auth/tests/unit/views/test_gitlab.py index ee8144b9a0..c6a0e4e50b 100644 --- a/codecov_auth/tests/unit/views/test_gitlab.py +++ b/codecov_auth/tests/unit/views/test_gitlab.py @@ -1,3 +1,4 @@ +from unittest.mock import call from uuid import UUID import pytest @@ -84,7 +85,9 @@ async def helper_list_teams_func(*args, **kwargs): session = client.session session["gitlab_oauth_state"] = "abc" session.save() - + mock_create_user_onboarding_metric = mocker.patch( + "shared.django_apps.codecov_metrics.service.codecov_metrics.UserOnboardingMetricsService.create_user_onboarding_metric" + ) url = reverse("gitlab-login") mock_redis.setex("oauth-state-abc", 300, "http://localhost:3000/gl") res = client.get(url, {"code": "aaaaaaa", "state": "abc"}) @@ -94,6 +97,14 @@ async def helper_list_teams_func(*args, **kwargs): assert owner.username == "ThiagoCodecov" assert owner.service_id == "3124507" assert res.url == "http://localhost:3000/gl" + + expected_call = call( + org_id=owner.ownerid, + event="INSTALLED_APP", + payload={"login": "gitlab"}, + ) + assert mock_create_user_onboarding_metric.call_args_list == [expected_call] + assert encryptor.decode(owner.oauth_token) == f"{access_token}: :{refresh_token}" diff --git a/codecov_auth/tests/unit/views/test_logout.py b/codecov_auth/tests/unit/views/test_logout.py index f9a484e2b4..95b679f9a3 100644 
--- a/codecov_auth/tests/unit/views/test_logout.py +++ b/codecov_auth/tests/unit/views/test_logout.py @@ -1,6 +1,3 @@ -from unittest.mock import patch - -from django.core.exceptions import ObjectDoesNotExist from django.test import TransactionTestCase from codecov_auth.tests.factories import OwnerFactory diff --git a/codecov_auth/tests/unit/views/test_okta.py b/codecov_auth/tests/unit/views/test_okta.py index 1a9b7bc0ca..a8e549a7ab 100644 --- a/codecov_auth/tests/unit/views/test_okta.py +++ b/codecov_auth/tests/unit/views/test_okta.py @@ -1,7 +1,3 @@ -from http.cookies import SimpleCookie -from unittest.mock import MagicMock, patch - -import jwt import pytest from django.conf import settings from django.contrib import auth @@ -10,14 +6,14 @@ from codecov_auth.models import OktaUser from codecov_auth.tests.factories import OktaUserFactory, OwnerFactory, UserFactory -from codecov_auth.views.okta import OktaLoginView, validate_id_token -from codecov_auth.views.okta import auth as okta_basic_auth +from codecov_auth.views.okta import OKTA_BASIC_AUTH +from codecov_auth.views.okta_mixin import OktaIdTokenPayload @pytest.fixture def mocked_okta_token_request(mocker): return mocker.patch( - "codecov_auth.views.okta.requests.post", + "codecov_auth.views.okta_mixin.requests.post", return_value=mocker.MagicMock( status_code=200, json=mocker.MagicMock( @@ -26,7 +22,7 @@ def mocked_okta_token_request(mocker): "refresh_token": "test-refresh-token", "id_token": "test-id-token", "state": "test-state", - } + }, ), ), ) @@ -36,11 +32,13 @@ def mocked_okta_token_request(mocker): def mocked_validate_id_token(mocker): return mocker.patch( "codecov_auth.views.okta.validate_id_token", - return_value={ - "sub": "test-id", - "email": "test@example.com", - "name": "Some User", - }, + return_value=OktaIdTokenPayload( + sub="test-id", + email="test@example.com", + name="Some User", + iss="https://example.com", + aud="test-client-id", + ), ) @@ -55,91 +53,6 @@ def 
mocked_validate_id_token(mocker): mwIDAQAB -----END PUBLIC KEY----- """ -private_key = """-----BEGIN PRIVATE KEY----- -MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQC7VJTUt9Us8cKj -MzEfYyjiWA4R4/M2bS1GB4t7NXp98C3SC6dVMvDuictGeurT8jNbvJZHtCSuYEvu -NMoSfm76oqFvAp8Gy0iz5sxjZmSnXyCdPEovGhLa0VzMaQ8s+CLOyS56YyCFGeJZ -qgtzJ6GR3eqoYSW9b9UMvkBpZODSctWSNGj3P7jRFDO5VoTwCQAWbFnOjDfH5Ulg -p2PKSQnSJP3AJLQNFNe7br1XbrhV//eO+t51mIpGSDCUv3E0DDFcWDTH9cXDTTlR -ZVEiR2BwpZOOkE/Z0/BVnhZYL71oZV34bKfWjQIt6V/isSMahdsAASACp4ZTGtwi -VuNd9tybAgMBAAECggEBAKTmjaS6tkK8BlPXClTQ2vpz/N6uxDeS35mXpqasqskV -laAidgg/sWqpjXDbXr93otIMLlWsM+X0CqMDgSXKejLS2jx4GDjI1ZTXg++0AMJ8 -sJ74pWzVDOfmCEQ/7wXs3+cbnXhKriO8Z036q92Qc1+N87SI38nkGa0ABH9CN83H -mQqt4fB7UdHzuIRe/me2PGhIq5ZBzj6h3BpoPGzEP+x3l9YmK8t/1cN0pqI+dQwY -dgfGjackLu/2qH80MCF7IyQaseZUOJyKrCLtSD/Iixv/hzDEUPfOCjFDgTpzf3cw -ta8+oE4wHCo1iI1/4TlPkwmXx4qSXtmw4aQPz7IDQvECgYEA8KNThCO2gsC2I9PQ -DM/8Cw0O983WCDY+oi+7JPiNAJwv5DYBqEZB1QYdj06YD16XlC/HAZMsMku1na2T -N0driwenQQWzoev3g2S7gRDoS/FCJSI3jJ+kjgtaA7Qmzlgk1TxODN+G1H91HW7t -0l7VnL27IWyYo2qRRK3jzxqUiPUCgYEAx0oQs2reBQGMVZnApD1jeq7n4MvNLcPv -t8b/eU9iUv6Y4Mj0Suo/AU8lYZXm8ubbqAlwz2VSVunD2tOplHyMUrtCtObAfVDU -AhCndKaA9gApgfb3xw1IKbuQ1u4IF1FJl3VtumfQn//LiH1B3rXhcdyo3/vIttEk -48RakUKClU8CgYEAzV7W3COOlDDcQd935DdtKBFRAPRPAlspQUnzMi5eSHMD/ISL -DY5IiQHbIH83D4bvXq0X7qQoSBSNP7Dvv3HYuqMhf0DaegrlBuJllFVVq9qPVRnK -xt1Il2HgxOBvbhOT+9in1BzA+YJ99UzC85O0Qz06A+CmtHEy4aZ2kj5hHjECgYEA -mNS4+A8Fkss8Js1RieK2LniBxMgmYml3pfVLKGnzmng7H2+cwPLhPIzIuwytXywh -2bzbsYEfYx3EoEVgMEpPhoarQnYPukrJO4gwE2o5Te6T5mJSZGlQJQj9q4ZB2Dfz -et6INsK0oG8XVGXSpQvQh3RUYekCZQkBBFcpqWpbIEsCgYAnM3DQf3FJoSnXaMhr -VBIovic5l0xFkEHskAjFTevO86Fsz1C2aSeRKSqGFoOQ0tmJzBEs1R6KqnHInicD -TQrKhArgLXX4v3CddjfTRJkFWDbE/CkvKZNOrcf1nhaGCPspRJj2KUkj1Fhl9Cnc -dn/RsYEONbwQSjIfMPkvxF+8HQ== ------END PRIVATE KEY----- -""" - - -@override_settings( - OKTA_OAUTH_CLIENT_ID="test-client-id", -) -def test_validate_id_token(mocker): - data = { - "sub": "test-okta-id", - "name": "Some User", - 
"email": "test@example.com", - "iss": "https://example.okta.com", - "aud": "test-client-id", - } - id_token = jwt.encode( - data, private_key, algorithm="RS256", headers={"kid": "test-kid"} - ) - - # did this offline so as not to need an additional dependency - here's the code - # if we ever need to regenerate this: - # - # from Crypto.PublicKey import RSA - # import base64 - # pub = RSA.importKey(public_key) - # modulus = base64.b64encode(pub.n.to_bytes(256, "big")).decode("ascii") - # exponent = base64.b64encode(pub.e.to_bytes(3, "big")).decode("ascii") - - exponent = "AQAB" - modulus = "u1SU1LfVLPHCozMxH2Mo4lgOEePzNm0tRgeLezV6ffAt0gunVTLw7onLRnrq0/IzW7yWR7QkrmBL7jTKEn5u+qKhbwKfBstIs+bMY2Zkp18gnTxKLxoS2tFczGkPLPgizskuemMghRniWaoLcyehkd3qqGElvW/VDL5AaWTg0nLVkjRo9z+40RQzuVaE8AkAFmxZzow3x+VJYKdjykkJ0iT9wCS0DRTXu269V264Vf/3jvredZiKRkgwlL9xNAwxXFg0x/XFw005UWVRIkdgcKWTjpBP2dPwVZ4WWC+9aGVd+Gyn1o0CLelf4rEjGoXbAAEgAqeGUxrcIlbjXfbcmw==" - - get_keys = mocker.patch( - "codecov_auth.views.okta.requests.get", - return_value=mocker.MagicMock( - status_code=200, - json=mocker.MagicMock( - return_value={ - "keys": [ - { - "kty": "RSA", - "alg": "RS256", - "kid": "test-kid", - "use": "sig", - "e": exponent, - "n": modulus, - } - ] - } - ), - ), - ) - - id_payload = validate_id_token(iss="https://example.okta.com", id_token=id_token) - assert id_payload["sub"] == "test-okta-id" - assert id_payload["name"] == "Some User" - assert id_payload["email"] == "test@example.com" - - get_keys.assert_called_once_with("https://example.okta.com/oauth2/v1/keys") @override_settings( @@ -188,7 +101,7 @@ def test_okta_redirect_to_authorize_invalid_iss(client): @override_settings( OKTA_OAUTH_CLIENT_ID="test-client-id", - OKTA_OAUTH_CLIENT_SECRE="test-client-secret", + OKTA_OAUTH_CLIENT_SECRET="test-client-secret", OKTA_OAUTH_REDIRECT_URL="https://localhost:8000/login/okta", ) def test_okta_perform_login( @@ -209,7 +122,7 @@ def test_okta_perform_login( 
mocked_okta_token_request.assert_called_once_with( "https://example.okta.com/oauth2/v1/token", - auth=okta_basic_auth, + auth=OKTA_BASIC_AUTH, data={ "grant_type": "authorization_code", "code": "test-code", @@ -219,8 +132,7 @@ def test_okta_perform_login( ) mocked_validate_id_token.assert_called_once_with( - "https://example.okta.com", - "test-id-token", + "https://example.okta.com", "test-id-token", "test-client-id" ) assert res.status_code == 302 @@ -243,7 +155,7 @@ def test_okta_perform_login( @override_settings( OKTA_OAUTH_CLIENT_ID="test-client-id", - OKTA_OAUTH_CLIENT_SECRE="test-client-secret", + OKTA_OAUTH_CLIENT_SECRET="test-client-secret", OKTA_OAUTH_REDIRECT_URL="https://localhost:8000/login/okta", ) def test_okta_perform_login_authenticated( @@ -282,7 +194,7 @@ def test_okta_perform_login_authenticated( @override_settings( OKTA_OAUTH_CLIENT_ID="test-client-id", - OKTA_OAUTH_CLIENT_SECRE="test-client-secret", + OKTA_OAUTH_CLIENT_SECRET="test-client-secret", OKTA_OAUTH_REDIRECT_URL="https://localhost:8000/login/okta", ) def test_okta_perform_login_existing_okta_user( @@ -313,7 +225,7 @@ def test_okta_perform_login_existing_okta_user( @override_settings( OKTA_OAUTH_CLIENT_ID="test-client-id", - OKTA_OAUTH_CLIENT_SECRE="test-client-secret", + OKTA_OAUTH_CLIENT_SECRET="test-client-secret", OKTA_OAUTH_REDIRECT_URL="https://localhost:8000/login/okta", ) def test_okta_perform_login_authenticated_existing_okta_user( @@ -347,11 +259,12 @@ def test_okta_perform_login_authenticated_existing_okta_user( @override_settings( OKTA_OAUTH_CLIENT_ID="test-client-id", - OKTA_OAUTH_CLIENT_SECRE="test-client-secret", + OKTA_OAUTH_CLIENT_SECRET="test-client-secret", OKTA_OAUTH_REDIRECT_URL="https://localhost:8000/login/okta", ) +@pytest.mark.django_db def test_okta_perform_login_existing_okta_user_existing_owner( - client, mocked_okta_token_request, mocked_validate_id_token, db + client, mocked_okta_token_request, mocked_validate_id_token ): okta_user = 
OktaUserFactory(okta_id="test-id") OwnerFactory(service="github", user=okta_user.user) @@ -380,12 +293,12 @@ def test_okta_perform_login_existing_okta_user_existing_owner( @override_settings( OKTA_OAUTH_CLIENT_ID="test-client-id", - OKTA_OAUTH_CLIENT_SECRE="test-client-secret", + OKTA_OAUTH_CLIENT_SECRET="test-client-secret", OKTA_OAUTH_REDIRECT_URL="https://localhost:8000/login/okta", ) def test_okta_perform_login_error(client, mocker, db): mocker.patch( - "codecov_auth.views.okta.requests.post", + "codecov_auth.views.okta_mixin.requests.post", return_value=mocker.MagicMock( status_code=401, ), @@ -409,7 +322,7 @@ def test_okta_perform_login_error(client, mocker, db): @override_settings( OKTA_OAUTH_CLIENT_ID="test-client-id", - OKTA_OAUTH_CLIENT_SECRE="test-client-secret", + OKTA_OAUTH_CLIENT_SECRET="test-client-secret", OKTA_OAUTH_REDIRECT_URL="https://localhost:8000/login/okta", ) def test_okta_perform_login_state_mismatch(client, mocker, db): @@ -427,25 +340,3 @@ def test_okta_perform_login_state_mismatch(client, mocker, db): # does not login user current_user = auth.get_user(client) assert current_user.is_anonymous - - -@override_settings( - OKTA_OAUTH_CLIENT_ID="test-client-id", - OKTA_OAUTH_CLIENT_SECRE="test-client-secret", - OKTA_OAUTH_REDIRECT_URL="https://localhost:8000/login/okta", -) -def test_okta_fetch_user_data_invalid_state(client, db): - with patch("codecov_auth.views.okta.requests.post") as mock_post: - mock_response = MagicMock() - mock_response.status_code = 200 - mock_post.return_value = mock_response - - with patch.object(OktaLoginView, "verify_state", return_value=False): - view = OktaLoginView() - res = view._fetch_user_data( - "https://example.okta.com", - "test-code", - "invalid-state", - ) - - assert res is None diff --git a/codecov_auth/tests/unit/views/test_okta_cloud.py b/codecov_auth/tests/unit/views/test_okta_cloud.py new file mode 100644 index 0000000000..94e8131ca9 --- /dev/null +++ 
b/codecov_auth/tests/unit/views/test_okta_cloud.py @@ -0,0 +1,458 @@ +from logging import LogRecord +from typing import Any +from unittest.mock import ANY +from urllib.parse import unquote, urlparse + +import pytest +from django.test import override_settings +from pytest import LogCaptureFixture +from pytest_mock import MockerFixture +from shared.django_apps.codecov_auth.models import Account, OktaSettings, Owner +from shared.django_apps.codecov_auth.tests.factories import ( + AccountFactory, + OktaSettingsFactory, +) + +from codecov_auth.tests.factories import OwnerFactory +from codecov_auth.views.okta_cloud import ( + OKTA_CURRENT_SESSION, + OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY, +) +from codecov_auth.views.okta_mixin import OktaIdTokenPayload +from utils.test_utils import Client as TestClient + + +@pytest.fixture +def signed_in_client() -> TestClient: + new_client = TestClient() + new_client.force_login_owner(OwnerFactory()) + return new_client + + +@pytest.fixture +def okta_org_name() -> str: + return "foo-bar-organization" + + +@pytest.fixture +def okta_org(okta_org_name: str) -> Owner: + org: Owner = OwnerFactory.create(username=okta_org_name, service="github") + org.save() + return org + + +@pytest.fixture +def okta_account(okta_org: Owner): + account = AccountFactory() + okta_org.account = account + okta_org.save() + + okta_settings: OktaSettings = OktaSettingsFactory(account=account) + okta_settings.url = "https://foo-bar.okta.com/" + okta_settings.save() + return account + + +@pytest.fixture +def mocked_okta_token_request(mocker): + return mocker.patch( + "codecov_auth.views.okta_mixin.requests.post", + return_value=mocker.MagicMock( + status_code=200, + json=mocker.MagicMock( + return_value={ + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "id_token": "test-id-token", + "state": "test-state", + }, + ), + ), + ) + + +@pytest.fixture +def mocked_validate_id_token(mocker): + return mocker.patch( + 
"codecov_auth.views.okta_cloud.validate_id_token", + return_value=OktaIdTokenPayload( + sub="test-id", + email="test@example.com", + name="Some User", + iss="https://example.com", + aud="test-client-id", + ), + ) + + +def log_message_exists(message: str, logs: list[LogRecord]) -> bool: + """Helper method to check that a particular log record was emitted""" + for log in logs: + if log.message == message: + return True + return False + + +@pytest.mark.django_db +def test_okta_login_unauthenticated_user( + client: TestClient, + caplog: LogCaptureFixture, +): + res = client.get("/login/okta/github/some-unknown-service") + assert log_message_exists( + "User needs to be signed in before authenticating organization with Okta.", + caplog.records, + ) + assert res.status_code == 403 + + +@pytest.mark.django_db +def test_okta_login_invalid_organization( + signed_in_client: TestClient, + caplog: LogCaptureFixture, +): + res = signed_in_client.get("/login/okta/github/some-unknown-service") + assert log_message_exists("The organization doesn't exist.", caplog.records) + assert res.status_code == 404 + + +@pytest.mark.django_db +def test_okta_login_no_account(signed_in_client: TestClient, caplog: LogCaptureFixture): + org: Owner = OwnerFactory.create(username="org-no-account", service="github") + org.save() + res = signed_in_client.get("/login/okta/github/org-no-account") + assert log_message_exists( + "Okta settings not found. Cannot sign into Okta", caplog.records + ) + assert res.status_code == 404 + + +@pytest.mark.django_db +def test_okta_login_no_okta_settings( + signed_in_client: TestClient, caplog: LogCaptureFixture +): + org: Owner = OwnerFactory.create(username="account-no-okta", service="github") + org.account = AccountFactory() + org.save() + res = signed_in_client.get("/login/okta/github/account-no-okta") + assert log_message_exists( + "Okta settings not found. 
Cannot sign into Okta", caplog.records + ) + assert res.status_code == 404 + + +@pytest.mark.django_db +def test_okta_login_already_signed_into_okta( + signed_in_client: TestClient, + okta_org_name: str, + okta_account: Account, +): + session = signed_in_client.session + session[OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY] = [okta_account.id] + session.save() + res = signed_in_client.get(f"/login/okta/gh/{okta_org_name}") + assert res.status_code == 302 + assert res.url == f"http://localhost:3000/github/{okta_org_name}" + + +@override_settings( + CODECOV_API_URL="http://localhost:8000", +) +@pytest.mark.django_db +def test_okta_login_redirect_to_okta_issuer( + signed_in_client: TestClient, okta_org_name: str, okta_account: Account +): + res = signed_in_client.get(f"/login/okta/gh/{okta_org_name}") + assert res.status_code == 302 + parsed_url = urlparse(res.url) + assert parsed_url.hostname == "foo-bar.okta.com" + assert parsed_url.path == "/oauth2/v1/authorize" + + parsed_query = parsed_url.query.split("&") + raw_redirect_url = next(x for x in parsed_query if x.startswith("redirect_uri=")) + assert raw_redirect_url + assert ( + unquote(raw_redirect_url.split("=")[1]) + == "http://localhost:8000/login/okta/callback" + ) + + +@pytest.mark.django_db +def test_okta_callback_login_success( + signed_in_client: TestClient, + okta_account: Account, + okta_org: Owner, + mocked_validate_id_token: Any, + mocked_okta_token_request: Any, +): + state = "test-state" + session = signed_in_client.session + assert session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + + session["okta_cloud_oauth_state"] = state + session[OKTA_CURRENT_SESSION] = { + "org_ownerid": okta_org.ownerid, + "okta_settings_id": okta_account.okta_settings.first().id, + } + + session.save() + + res = signed_in_client.get( + "/login/okta/callback", + data={ + "code": "random-code", + "state": state, + }, + ) + + assert res.status_code == 302 + assert res.url == f"http://localhost:3000/github/{okta_org.username}" + 
+ updated_session = signed_in_client.session + assert updated_session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) == [okta_account.id] + + mocked_validate_id_token.assert_called_with("https://foo-bar.okta.com", ANY, ANY) + + +@pytest.mark.django_db +def test_okta_callback_login_success_multiple_accounts( + signed_in_client: TestClient, + okta_account: Account, + okta_org: Owner, + mocked_validate_id_token: Any, + mocked_okta_token_request: Any, +): + state = "test-state" + session = signed_in_client.session + # Put in a random account that's not current okta_account + session[OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY] = [okta_account.id + 1] + + session["okta_cloud_oauth_state"] = state + session[OKTA_CURRENT_SESSION] = { + "org_ownerid": okta_org.ownerid, + "okta_settings_id": okta_account.okta_settings.first().id, + } + + session.save() + + res = signed_in_client.get( + "/login/okta/callback", + data={ + "code": "random-code", + "state": state, + }, + ) + + assert res.status_code == 302 + assert res.url == f"http://localhost:3000/github/{okta_org.username}" + + updated_session = signed_in_client.session + assert updated_session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) == [ + okta_account.id + 1, + okta_account.id, + ] + + mocked_validate_id_token.assert_called_with("https://foo-bar.okta.com", ANY, ANY) + + +@pytest.mark.django_db +def test_okta_callback_missing_session( + signed_in_client: TestClient, + caplog: LogCaptureFixture, + okta_org: Owner, + okta_account: Account, +): + session = signed_in_client.session + state = "test-state" + assert session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + session["okta_cloud_oauth_state"] = state + session.save() + + res = signed_in_client.get( + "/login/okta/callback", + data={ + "code": "random-code", + "state": state, + }, + ) + assert res.status_code == 403 + + assert log_message_exists( + "Trying to sign into Okta with no existing sign-in session.", caplog.records + ) + + updated_session = signed_in_client.session + 
assert updated_session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + + +@pytest.mark.django_db +def test_okta_callback_missing_user( + client: TestClient, + caplog: LogCaptureFixture, + okta_org: Owner, + okta_account: Account, +): + session = client.session + state = "test-state" + assert session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + session["okta_cloud_oauth_state"] = state + session[OKTA_CURRENT_SESSION] = { + "org_ownerid": okta_org.ownerid, + "okta_settings_id": okta_account.okta_settings.first().id, + } + session.save() + + res = client.get( + "/login/okta/callback", + data={ + "code": "random-code", + "state": state, + }, + ) + assert res.status_code == 403 + + assert log_message_exists("User not logged in for Okta callback.", caplog.records) + + updated_session = client.session + assert updated_session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + + +@pytest.mark.django_db +def test_okta_callback_missing_okta_settings( + signed_in_client: TestClient, + caplog: LogCaptureFixture, + okta_org: Owner, + okta_account: Account, +): + session = signed_in_client.session + state = "test-state" + assert session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + session["okta_cloud_oauth_state"] = state + session[OKTA_CURRENT_SESSION] = { + "org_ownerid": okta_org.ownerid, + "okta_settings_id": 12345, + } + session.save() + + res = signed_in_client.get( + "/login/okta/callback", + data={ + "code": "random-code", + "state": state, + }, + ) + assert res.status_code == 404 + + assert log_message_exists( + "Okta settings not found. 
Cannot sign into Okta", caplog.records + ) + + updated_session = signed_in_client.session + assert updated_session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + + +@pytest.mark.django_db +def test_okta_callback_no_code( + signed_in_client: TestClient, + caplog: LogCaptureFixture, + okta_org: Owner, + okta_account: Account, +): + session = signed_in_client.session + state = "test-state" + assert session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + session["okta_cloud_oauth_state"] = state + session[OKTA_CURRENT_SESSION] = { + "org_ownerid": okta_org.ownerid, + "okta_settings_id": okta_account.okta_settings.first().id, + } + session.save() + + res = signed_in_client.get( + "/login/okta/callback", + data={ + "state": state, + }, + ) + assert res.status_code == 400 + + assert log_message_exists( + "No code is passed. Invalid callback. Cannot sign into Okta", caplog.records + ) + + updated_session = signed_in_client.session + assert updated_session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + + +@pytest.mark.django_db +def test_okta_callback_perform_login_invalid_state( + signed_in_client: TestClient, + caplog: LogCaptureFixture, + okta_org: Owner, + okta_account: Account, +): + session = signed_in_client.session + assert session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + session["okta_cloud_oauth_state"] = "random-state" + + session[OKTA_CURRENT_SESSION] = { + "org_ownerid": okta_org.ownerid, + "okta_settings_id": okta_account.okta_settings.first().id, + } + session.save() + + res = signed_in_client.get( + "/login/okta/callback", + data={ + "code": "random-code", + "state": "different-state", + }, + ) + assert res.status_code == 302 + assert res.url == f"http://localhost:3000/github/{okta_org.username}" + + assert log_message_exists("Invalid state during Okta login", caplog.records) + + updated_session = signed_in_client.session + assert updated_session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + + +@pytest.mark.django_db +def 
test_okta_callback_perform_login_no_user_data( + mocker: MockerFixture, + signed_in_client: TestClient, + caplog: LogCaptureFixture, + okta_org: Owner, + okta_account: Account, + mocked_okta_token_request: Any, +): + state = "test-state" + session = signed_in_client.session + assert session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None + session["okta_cloud_oauth_state"] = state + + session[OKTA_CURRENT_SESSION] = { + "org_ownerid": okta_org.ownerid, + "okta_settings_id": okta_account.okta_settings.first().id, + } + session.save() + + mocked_okta_token_request.return_value = mocker.MagicMock( + status_code=400, + ) + + res = signed_in_client.get( + "/login/okta/callback", + data={ + "code": "random-code", + "state": state, + }, + ) + assert res.status_code == 400 + + assert log_message_exists( + "Can't log in. Invalid Okta Token Response", caplog.records + ) + + updated_session = signed_in_client.session + assert updated_session.get(OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY) is None diff --git a/codecov_auth/tests/unit/views/test_okta_mixin.py b/codecov_auth/tests/unit/views/test_okta_mixin.py new file mode 100644 index 0000000000..32f612901a --- /dev/null +++ b/codecov_auth/tests/unit/views/test_okta_mixin.py @@ -0,0 +1,110 @@ +from unittest.mock import MagicMock, patch + +import jwt + +from codecov_auth.views.okta import OKTA_BASIC_AUTH, OktaLoginView +from codecov_auth.views.okta_mixin import validate_id_token + +private_key = """-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQC7VJTUt9Us8cKj +MzEfYyjiWA4R4/M2bS1GB4t7NXp98C3SC6dVMvDuictGeurT8jNbvJZHtCSuYEvu +NMoSfm76oqFvAp8Gy0iz5sxjZmSnXyCdPEovGhLa0VzMaQ8s+CLOyS56YyCFGeJZ +qgtzJ6GR3eqoYSW9b9UMvkBpZODSctWSNGj3P7jRFDO5VoTwCQAWbFnOjDfH5Ulg +p2PKSQnSJP3AJLQNFNe7br1XbrhV//eO+t51mIpGSDCUv3E0DDFcWDTH9cXDTTlR +ZVEiR2BwpZOOkE/Z0/BVnhZYL71oZV34bKfWjQIt6V/isSMahdsAASACp4ZTGtwi +VuNd9tybAgMBAAECggEBAKTmjaS6tkK8BlPXClTQ2vpz/N6uxDeS35mXpqasqskV 
+laAidgg/sWqpjXDbXr93otIMLlWsM+X0CqMDgSXKejLS2jx4GDjI1ZTXg++0AMJ8 +sJ74pWzVDOfmCEQ/7wXs3+cbnXhKriO8Z036q92Qc1+N87SI38nkGa0ABH9CN83H +mQqt4fB7UdHzuIRe/me2PGhIq5ZBzj6h3BpoPGzEP+x3l9YmK8t/1cN0pqI+dQwY +dgfGjackLu/2qH80MCF7IyQaseZUOJyKrCLtSD/Iixv/hzDEUPfOCjFDgTpzf3cw +ta8+oE4wHCo1iI1/4TlPkwmXx4qSXtmw4aQPz7IDQvECgYEA8KNThCO2gsC2I9PQ +DM/8Cw0O983WCDY+oi+7JPiNAJwv5DYBqEZB1QYdj06YD16XlC/HAZMsMku1na2T +N0driwenQQWzoev3g2S7gRDoS/FCJSI3jJ+kjgtaA7Qmzlgk1TxODN+G1H91HW7t +0l7VnL27IWyYo2qRRK3jzxqUiPUCgYEAx0oQs2reBQGMVZnApD1jeq7n4MvNLcPv +t8b/eU9iUv6Y4Mj0Suo/AU8lYZXm8ubbqAlwz2VSVunD2tOplHyMUrtCtObAfVDU +AhCndKaA9gApgfb3xw1IKbuQ1u4IF1FJl3VtumfQn//LiH1B3rXhcdyo3/vIttEk +48RakUKClU8CgYEAzV7W3COOlDDcQd935DdtKBFRAPRPAlspQUnzMi5eSHMD/ISL +DY5IiQHbIH83D4bvXq0X7qQoSBSNP7Dvv3HYuqMhf0DaegrlBuJllFVVq9qPVRnK +xt1Il2HgxOBvbhOT+9in1BzA+YJ99UzC85O0Qz06A+CmtHEy4aZ2kj5hHjECgYEA +mNS4+A8Fkss8Js1RieK2LniBxMgmYml3pfVLKGnzmng7H2+cwPLhPIzIuwytXywh +2bzbsYEfYx3EoEVgMEpPhoarQnYPukrJO4gwE2o5Te6T5mJSZGlQJQj9q4ZB2Dfz +et6INsK0oG8XVGXSpQvQh3RUYekCZQkBBFcpqWpbIEsCgYAnM3DQf3FJoSnXaMhr +VBIovic5l0xFkEHskAjFTevO86Fsz1C2aSeRKSqGFoOQ0tmJzBEs1R6KqnHInicD +TQrKhArgLXX4v3CddjfTRJkFWDbE/CkvKZNOrcf1nhaGCPspRJj2KUkj1Fhl9Cnc +dn/RsYEONbwQSjIfMPkvxF+8HQ== +-----END PRIVATE KEY----- +""" + + +def test_validate_id_token(mocker): + data = { + "sub": "test-okta-id", + "name": "Some User", + "email": "test@example.com", + "iss": "https://example.okta.com", + "aud": "test-client-id", + } + id_token = jwt.encode( + data, private_key, algorithm="RS256", headers={"kid": "test-kid"} + ) + + # did this offline so as not to need an additional dependency - here's the code + # if we ever need to regenerate this: + # + # from Crypto.PublicKey import RSA + # import base64 + # pub = RSA.importKey(public_key) + # modulus = base64.b64encode(pub.n.to_bytes(256, "big")).decode("ascii") + # exponent = base64.b64encode(pub.e.to_bytes(3, "big")).decode("ascii") + + exponent = "AQAB" + modulus = 
"u1SU1LfVLPHCozMxH2Mo4lgOEePzNm0tRgeLezV6ffAt0gunVTLw7onLRnrq0/IzW7yWR7QkrmBL7jTKEn5u+qKhbwKfBstIs+bMY2Zkp18gnTxKLxoS2tFczGkPLPgizskuemMghRniWaoLcyehkd3qqGElvW/VDL5AaWTg0nLVkjRo9z+40RQzuVaE8AkAFmxZzow3x+VJYKdjykkJ0iT9wCS0DRTXu269V264Vf/3jvredZiKRkgwlL9xNAwxXFg0x/XFw005UWVRIkdgcKWTjpBP2dPwVZ4WWC+9aGVd+Gyn1o0CLelf4rEjGoXbAAEgAqeGUxrcIlbjXfbcmw==" + + get_keys = mocker.patch( + "codecov_auth.views.okta_mixin.requests.get", + return_value=mocker.MagicMock( + status_code=200, + json=mocker.MagicMock( + return_value={ + "keys": [ + { + "kty": "RSA", + "alg": "RS256", + "kid": "test-kid", + "use": "sig", + "e": exponent, + "n": modulus, + } + ] + } + ), + ), + ) + + id_payload = validate_id_token( + iss="https://example.okta.com", id_token=id_token, client_id="test-client-id" + ) + assert id_payload.sub == "test-okta-id" + assert id_payload.name == "Some User" + assert id_payload.email == "test@example.com" + + get_keys.assert_called_once_with("https://example.okta.com/oauth2/v1/keys") + + +def test_okta_fetch_user_data_invalid_state(client, db): + with patch("codecov_auth.views.okta_mixin.requests.post") as mock_post: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + with patch.object(OktaLoginView, "verify_state", return_value=False): + view = OktaLoginView() + res = view._fetch_user_data( + "https://example.okta.com", + "test-code", + "invalid-state", + "https://localhost:8000/login/okta", + OKTA_BASIC_AUTH, + ) + + assert res is None diff --git a/codecov_auth/urls.py b/codecov_auth/urls.py index 91d7730692..47081a2b96 100644 --- a/codecov_auth/urls.py +++ b/codecov_auth/urls.py @@ -9,6 +9,7 @@ from .views.gitlab_enterprise import GitlabEnterpriseLoginView from .views.logout import logout_view from .views.okta import OktaLoginView +from .views.okta_cloud import OktaCloudCallbackView, OktaCloudLoginView from .views.sentry import SentryLoginView urlpatterns = [ @@ -40,6 +41,16 @@ path("login/bbs", 
BitbucketServerLoginView.as_view(), name="bbs-login"), path("login/stash", BitbucketServerLoginView.as_view(), name="stash-login"), path("login/sentry", SentryLoginView.as_view(), name="sentry-login"), + path( + "login/okta/<str:service>/<str:org_username>", + OktaCloudLoginView.as_view(), + name="okta-cloud-login", + ), + path( + "login/okta/callback", + OktaCloudCallbackView.as_view(), + name="okta-cloud-callback", + ), ] if settings.OKTA_ISS is not None: urlpatterns += [path("login/okta", OktaLoginView.as_view(), name="okta-login")] diff --git a/codecov_auth/views/__init__.py b/codecov_auth/views/__init__.py index 84b2cbcc03..7992bf6347 100644 --- a/codecov_auth/views/__init__.py +++ b/codecov_auth/views/__init__.py @@ -1 +1,3 @@ from codecov_auth.views.github import GithubLoginView + +__all__ = ["GithubLoginView"] diff --git a/codecov_auth/views/base.py b/codecov_auth/views/base.py index 254189eed0..22c255d5b6 100644 --- a/codecov_auth/views/base.py +++ b/codecov_auth/views/base.py @@ -7,7 +7,7 @@ from django.conf import settings from django.contrib.auth import login, logout from django.contrib.sessions.models import Session as DjangoSession -from django.core.exceptions import PermissionDenied, SuspiciousOperation +from django.core.exceptions import PermissionDenied from django.http.request import HttpRequest from django.http.response import HttpResponse from django.utils import timezone @@ -262,9 +262,7 @@ def get_and_modify_owner(self, user_dict, request) -> Owner: ] self._check_enterprise_organizations_membership(user_dict, formatted_orgs) - upserted_orgs = [] - for org in formatted_orgs: - upserted_orgs.append(self.get_or_create_org(org)) + upserted_orgs = [self.get_or_create_org(org) for org in formatted_orgs] self._check_user_count_limitations(user_dict["user"]) owner, is_new_user = self._get_or_create_owner(user_dict, request) diff --git a/codecov_auth/views/bitbucket.py b/codecov_auth/views/bitbucket.py index e90f7da9b2..a204e237cd 100644 --- a/codecov_auth/views/bitbucket.py +++ 
b/codecov_auth/views/bitbucket.py @@ -1,4 +1,3 @@ -import asyncio import base64 import logging from urllib.parse import urlencode @@ -8,6 +7,9 @@ from django.shortcuts import redirect from django.urls import reverse from django.views import View +from shared.django_apps.codecov_metrics.service.codecov_metrics import ( + UserOnboardingMetricsService, +) from shared.torngit import Bitbucket from shared.torngit.exceptions import TorngitServerFailureError @@ -103,6 +105,9 @@ def actual_login_step(self, request): response.delete_cookie("_oauth_request_token", domain=settings.COOKIES_DOMAIN) self.login_owner(user, request, response) log.info("User successfully logged in", extra=dict(ownerid=user.ownerid)) + UserOnboardingMetricsService.create_user_onboarding_metric( + org_id=user.ownerid, event="INSTALLED_APP", payload={"login": "bitbucket"} + ) return response def get(self, request): diff --git a/codecov_auth/views/github.py b/codecov_auth/views/github.py index f957a38f59..cdb39d1a62 100644 --- a/codecov_auth/views/github.py +++ b/codecov_auth/views/github.py @@ -1,6 +1,4 @@ -import asyncio import logging -from datetime import datetime, timedelta from typing import Optional from urllib.parse import urlencode, urljoin @@ -8,6 +6,9 @@ from django.conf import settings from django.shortcuts import redirect from django.views import View +from shared.django_apps.codecov_metrics.service.codecov_metrics import ( + UserOnboardingMetricsService, +) from shared.torngit import Github from shared.torngit.exceptions import TorngitError @@ -125,7 +126,9 @@ def actual_login_step(self, request): response = redirect(redirection_url) self.login_owner(owner, request, response) self.remove_state(state) - self.store_access_token_expiry_to_cookie(response) + UserOnboardingMetricsService.create_user_onboarding_metric( + org_id=owner.ownerid, event="INSTALLED_APP", payload={"login": "github"} + ) return response def get(self, request): @@ -160,13 +163,3 @@ def get(self, request): response = 
redirect(url_to_redirect_to) self.store_to_cookie_utm_tags(response) return response - - # Set a session expiry of 8 hours for github logins. GH access tokens expire after 8 hours by default - # https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/token-expiration-and-revocation#user-token-revoked-due-to-github-app-configuration - def store_access_token_expiry_to_cookie(self, response): - domain_to_use = settings.COOKIES_DOMAIN - eight_hours_later = datetime.utcnow() + timedelta(hours=8) - eight_hours_later_iso = eight_hours_later.isoformat() + "Z" - response.set_cookie( - "session_expiry", eight_hours_later_iso, domain=domain_to_use, secure=True - ) diff --git a/codecov_auth/views/gitlab.py b/codecov_auth/views/gitlab.py index c48271a897..399e7c5ac6 100644 --- a/codecov_auth/views/gitlab.py +++ b/codecov_auth/views/gitlab.py @@ -1,18 +1,18 @@ -import asyncio import logging from urllib.parse import urlencode, urljoin -from uuid import uuid4 +from uuid import uuid4 # noqa: F401 from asgiref.sync import async_to_sync from django.conf import settings from django.shortcuts import redirect -from django.urls import reverse from django.views import View +from shared.django_apps.codecov_metrics.service.codecov_metrics import ( + UserOnboardingMetricsService, +) from shared.torngit import Gitlab from shared.torngit.exceptions import TorngitError from codecov_auth.views.base import LoginMixin, StateMixin -from utils.config import get_config log = logging.getLogger(__name__) @@ -85,6 +85,9 @@ def actual_login_step(self, request): response = redirect(redirection_url) self.login_owner(user, request, response) self.remove_state(state, delay=5) + UserOnboardingMetricsService.create_user_onboarding_metric( + org_id=user.ownerid, event="INSTALLED_APP", payload={"login": "gitlab"} + ) return response def get(self, request): diff --git a/codecov_auth/views/okta.py b/codecov_auth/views/okta.py index e01ef3d1ff..76931c75b6 100644 --- 
a/codecov_auth/views/okta.py +++ b/codecov_auth/views/okta.py @@ -1,11 +1,5 @@ -import json import logging -import re -from typing import Dict, Optional -from urllib.parse import urlencode -import jwt -import requests from django.conf import settings from django.contrib.auth import login, logout from django.http import HttpRequest, HttpResponse @@ -14,81 +8,26 @@ from requests.auth import HTTPBasicAuth from codecov_auth.models import OktaUser, User -from codecov_auth.views.base import LoginMixin, StateMixin +from codecov_auth.views.base import LoginMixin +from codecov_auth.views.okta_mixin import ( + ISS_REGEX, + OktaLoginMixin, + OktaTokenResponse, + validate_id_token, +) from utils.services import get_short_service_name log = logging.getLogger(__name__) -iss_regex = re.compile(r"https://[\w\d\-\_]+.okta.com/?") +OKTA_BASIC_AUTH = HTTPBasicAuth( + settings.OKTA_OAUTH_CLIENT_ID, settings.OKTA_OAUTH_CLIENT_SECRET +) -def validate_id_token(iss: str, id_token: str) -> dict: - res = requests.get(f"{iss}/oauth2/v1/keys") - jwks = res.json() - public_keys = {} - for jwk in jwks["keys"]: - kid = jwk["kid"] - public_keys[kid] = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(jwk)) - - kid = jwt.get_unverified_header(id_token)["kid"] - key = public_keys[kid] - - id_payload = jwt.decode( - id_token, - key=key, - algorithms=["RS256"], - audience=settings.OKTA_OAUTH_CLIENT_ID, - ) - assert id_payload["iss"] == iss - assert id_payload["aud"] == settings.OKTA_OAUTH_CLIENT_ID - - return id_payload - - -auth = HTTPBasicAuth(settings.OKTA_OAUTH_CLIENT_ID, settings.OKTA_OAUTH_CLIENT_SECRET) - - -class OktaLoginView(LoginMixin, StateMixin, View): +class OktaLoginView(LoginMixin, OktaLoginMixin, View): service = "okta" - def _fetch_user_data(self, iss: str, code: str, state: str) -> Optional[Dict]: - res = requests.post( - f"{iss}/oauth2/v1/token", - auth=auth, - data={ - "grant_type": "authorization_code", - "code": code, - "redirect_uri": settings.OKTA_OAUTH_REDIRECT_URL, - "state": 
state, - }, - ) - - if not self.verify_state(state): - log.warning("Invalid state during Okta OAuth") - return None - - if res.status_code >= 400: - return None - return res.json() - - def _redirect_to_consent(self, iss: str) -> HttpResponse: - state = self.generate_state() - qs = urlencode( - dict( - response_type="code", - client_id=settings.OKTA_OAUTH_CLIENT_ID, - scope="openid email profile", - redirect_uri=settings.OKTA_OAUTH_REDIRECT_URL, - state=state, - ) - ) - redirect_url = f"{iss}/oauth2/v1/authorize?{qs}" - response = redirect(redirect_url) - self.store_to_cookie_utm_tags(response) - - return response - - def _perform_login(self, request: HttpRequest) -> HttpResponse: + def _perform_login(self, request: HttpRequest, iss: str) -> HttpResponse: code = request.GET.get("code") state = request.GET.get("state") @@ -96,12 +35,9 @@ def _perform_login(self, request: HttpRequest) -> HttpResponse: log.warning("Invalid state during Okta login") return redirect(f"{settings.CODECOV_DASHBOARD_URL}/login") - iss = settings.OKTA_ISS - if iss is None: - log.warning("Unable to log in due to missing Okta issuer", exc_info=True) - return redirect(f"{settings.CODECOV_DASHBOARD_URL}/login") - - user_data = self._fetch_user_data(iss, code, state) + user_data: OktaTokenResponse | None = self._fetch_user_data( + iss, code, state, settings.OKTA_OAUTH_REDIRECT_URL, OKTA_BASIC_AUTH + ) if user_data is None: log.warning("Unable to log in due to problem on Okta", exc_info=True) return redirect(f"{settings.CODECOV_DASHBOARD_URL}/login") @@ -121,19 +57,18 @@ def _perform_login(self, request: HttpRequest) -> HttpResponse: return response - def _login_user(self, request: HttpRequest, iss: str, user_data: dict): - id_token = user_data[ - "id_token" - ] # this will be present since we requested the `oidc` scope - id_payload = validate_id_token(iss, id_token) + def _login_user( + self, request: HttpRequest, iss: str, user_data: OktaTokenResponse + ) -> User: + id_token = user_data.id_token 
+ id_payload = validate_id_token(iss, id_token, settings.OKTA_OAUTH_CLIENT_ID) - okta_id = id_payload["sub"] - user_email = id_payload["email"] - user_name = id_payload["name"] + okta_id = id_payload.sub + user_email = id_payload.email + user_name = id_payload.name okta_user = OktaUser.objects.filter(okta_id=okta_id).first() - current_user = None if request.user is not None and not request.user.is_anonymous: # we're already authenticated current_user = request.user @@ -169,7 +104,7 @@ def _login_user(self, request: HttpRequest, iss: str, user_data: dict): okta_id=okta_id, name=user_name, email=user_email, - access_token=user_data["access_token"], + access_token=user_data.access_token, ) log.info( "Created Okta user", @@ -179,15 +114,29 @@ def _login_user(self, request: HttpRequest, iss: str, user_data: dict): login(request, current_user) return current_user - def get(self, request): + def validate_issuer(self) -> str | None: + """Checks that the issuer is valid. If not, it returns None.""" + iss = settings.OKTA_ISS + if iss is None: + log.warning("Unable to log in due to missing Okta issuer", exc_info=True) + return None + if not ISS_REGEX.match(iss): + log.warning("Invalid Okta issuer") + return None + return iss + + def get(self, request: HttpRequest) -> HttpResponse: + iss = self.validate_issuer() + if iss is None: + return redirect(f"{settings.CODECOV_DASHBOARD_URL}/login") + if request.GET.get("code"): - return self._perform_login(request) + return self._perform_login(request, iss) else: - iss = settings.OKTA_ISS - if not iss: - log.warning("Missing Okta issuer") - return redirect(f"{settings.CODECOV_DASHBOARD_URL}/login") - if not iss_regex.match(iss): - log.warning("Invalid Okta issuer") - return redirect(f"{settings.CODECOV_DASHBOARD_URL}/login") - return self._redirect_to_consent(iss=iss) + response = self._redirect_to_consent( + iss=iss, + client_id=settings.OKTA_OAUTH_CLIENT_ID, + oauth_redirect_url=settings.OKTA_OAUTH_REDIRECT_URL, + ) + 
self.store_to_cookie_utm_tags(response) + return response diff --git a/codecov_auth/views/okta_cloud.py b/codecov_auth/views/okta_cloud.py new file mode 100644 index 0000000000..93de883168 --- /dev/null +++ b/codecov_auth/views/okta_cloud.py @@ -0,0 +1,197 @@ +import logging + +from django.conf import settings +from django.http import HttpRequest, HttpResponse +from django.shortcuts import redirect +from django.views import View +from requests.auth import HTTPBasicAuth +from shared.django_apps.codecov_auth.models import Account, OktaSettings, Owner + +from codecov_auth.views.okta_mixin import ( + OktaLoginMixin, + OktaTokenResponse, + validate_id_token, +) + +# The key for accessing the Okta signed in accounts list in the session +OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY = "okta_signed_in_accounts" + +# The key for the currently signing in session in Okta. +# This is so that the callback can reference the orgs/accounts that we're +# signing in for. +OKTA_CURRENT_SESSION = "okta_current_session" + +log = logging.getLogger(__name__) + + +def get_app_redirect_url(org_username: str, service: str) -> str: + """The Codecov app page we redirect users to.""" + return f"{settings.CODECOV_DASHBOARD_URL}/{service}/{org_username}" + + +def get_oauth_redirect_url() -> str: + """The Okta callback URL for us to finish the authentication.""" + return f"{settings.CODECOV_API_URL}/login/okta/callback" + + +def get_okta_settings(organization: Owner) -> OktaSettings | None: + account: Account | None = organization.account + if account: + okta_settings: OktaSettings | None = account.okta_settings.first() + if okta_settings: + return okta_settings + return None + + +class OktaCloudLoginView(OktaLoginMixin, View): + service = "okta_cloud" + + def get( + self, request: HttpRequest, service: str, org_username: str + ) -> HttpResponse: + log_context: dict = {"service": service, "username": org_username} + if not request.session.get("current_owner_id"): + log.warning( + "User needs to be signed 
in before authenticating organization with Okta.", + extra=log_context, + ) + return HttpResponse(status=403) + + try: + organization: Owner = Owner.objects.get( + service=service, username=org_username + ) + except Owner.DoesNotExist: + log.warning("The organization doesn't exist.", extra=log_context) + return HttpResponse(status=404) + + okta_settings = get_okta_settings(organization) + if not okta_settings: + log.warning( + "Okta settings not found. Cannot sign into Okta", extra=log_context + ) + return HttpResponse(status=404) + + app_redirect_url = get_app_redirect_url( + organization.username, organization.service + ) + oauth_redirect_url = get_oauth_redirect_url() + + # User is already logged in, redirect them to the org page + if organization.account.id in request.session.get( + OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY, [] + ): + return redirect(app_redirect_url) + + # Otherwise start the process redirect them to the Issuer page to authenticate + else: + consent = self._redirect_to_consent( + iss=okta_settings.url.strip("/ "), + client_id=okta_settings.client_id, + oauth_redirect_url=oauth_redirect_url, + ) + request.session[OKTA_CURRENT_SESSION] = { + "org_ownerid": organization.ownerid, + "okta_settings_id": okta_settings.id, + } + return consent + + +class OktaCloudCallbackView(OktaLoginMixin, View): + service = "okta_cloud" + + def get(self, request: HttpRequest) -> HttpResponse: + current_okta_session: dict[str, int] | None = request.session.get( + OKTA_CURRENT_SESSION + ) + if not current_okta_session: + log.warning("Trying to sign into Okta with no existing sign-in session.") + return HttpResponse(status=403) + + org_owner = Owner.objects.get(ownerid=current_okta_session["org_ownerid"]) + log_context: dict = { + "service": org_owner.service, + "username": org_owner.username, + } + + if not request.session.get("current_owner_id"): + log.warning( + "User not logged in for Okta callback.", + extra=log_context, + ) + return HttpResponse(status=403) + + try: + 
okta_settings = OktaSettings.objects.get( + id=current_okta_session["okta_settings_id"] + ) + except OktaSettings.DoesNotExist: + log.warning( + "Okta settings not found. Cannot sign into Okta", extra=log_context + ) + return HttpResponse(status=404) + + app_redirect_url = get_app_redirect_url(org_owner.username, org_owner.service) + oauth_redirect_url = get_oauth_redirect_url() + + # Redirect URL, need to validate and mark user as logged in + if request.GET.get("code"): + return self._perform_login( + request, + org_owner, + okta_settings, + app_redirect_url, + oauth_redirect_url, + ) + else: + log.warning( + "No code is passed. Invalid callback. Cannot sign into Okta", + extra=log_context, + ) + + return HttpResponse(status=400) + + def _perform_login( + self, + request: HttpRequest, + organization: Owner, + okta_settings: OktaSettings, + app_redirect_url: str, + oauth_redirect_url: str, + ) -> HttpResponse: + code = request.GET.get("code") + state = request.GET.get("state") + + if not self.verify_state(state): + log.warning("Invalid state during Okta login") + return redirect(app_redirect_url) + + issuer: str = okta_settings.url.strip("/ ") + user_data: OktaTokenResponse | None = self._fetch_user_data( + issuer, + code, + state, + oauth_redirect_url, + HTTPBasicAuth(okta_settings.client_id, okta_settings.client_secret), + ) + + if user_data is None: + log.warning("Can't log in. Invalid Okta Token Response", exc_info=True) + return HttpResponse(status=400) + + _ = validate_id_token(issuer, user_data.id_token, okta_settings.client_id) + + self._login_user(request, organization) + + return redirect(app_redirect_url) + + def _login_user(self, request: HttpRequest, organization: Owner) -> None: + """Logging in the user will just mean adding the account to the user's + okta_logged_in_accounts session. 
+ """ + okta_signed_in_accounts: list[int] = request.session.get( + OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY, [] + ) + okta_signed_in_accounts.append(organization.account.id) + request.session[OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY] = okta_signed_in_accounts + return diff --git a/codecov_auth/views/okta_mixin.py b/codecov_auth/views/okta_mixin.py new file mode 100644 index 0000000000..052dc54f4b --- /dev/null +++ b/codecov_auth/views/okta_mixin.py @@ -0,0 +1,112 @@ +import json +import logging +import re +from urllib.parse import urlencode + +import jwt +import pydantic +import requests +from django.http import HttpResponse +from django.shortcuts import redirect +from requests.auth import HTTPBasicAuth + +from codecov_auth.views.base import StateMixin + +log = logging.getLogger(__name__) + +ISS_REGEX = re.compile(r"https://[\w\d\-\_]+.okta.com/?") + + +class OktaTokenResponse(pydantic.BaseModel): + """This model serializes the response from Okta's oauth/v1/token endpoint. + ref: https://developer.okta.com/docs/reference/api/oidc/#token + + Keeping reference to only the fields that are used. + """ + + access_token: str + id_token: str # this will be present since we requested the `oidc` scope + + +class OktaIdTokenPayload(pydantic.BaseModel): + """Serializes the ID Payload from Okta's id_token deserialization. 
+ ref: https://developer.okta.com/docs/reference/api/oidc/#id-token + """ + + aud: str + iss: str + sub: str + email: str + name: str + + +def validate_id_token(iss: str, id_token: str, client_id: str) -> OktaIdTokenPayload: + res = requests.get(f"{iss}/oauth2/v1/keys") + jwks = res.json() + + public_keys = {} + for jwk in jwks["keys"]: + kid = jwk["kid"] + public_keys[kid] = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(jwk)) + + kid = jwt.get_unverified_header(id_token)["kid"] + key = public_keys[kid] + + id_payload = jwt.decode( + id_token, + key=key, + algorithms=["RS256"], + audience=client_id, + ) + id_token_payload = OktaIdTokenPayload(**id_payload) + assert id_token_payload.iss == iss + assert id_token_payload.aud == client_id + + return id_token_payload + + +class OktaLoginMixin(StateMixin): + def _fetch_user_data( + self, + iss: str, + code: str, + state: str, + redirect_url: str, + auth: HTTPBasicAuth, + ) -> OktaTokenResponse | None: + res = requests.post( + f"{iss}/oauth2/v1/token", + auth=auth, + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_url, + "state": state, + }, + ) + + if not self.verify_state(state): + log.warning("Invalid state during Okta OAuth") + return None + + if res.status_code >= 400: + return None + + return OktaTokenResponse(**res.json()) + + def _redirect_to_consent( + self, iss: str, client_id: str, oauth_redirect_url: str + ) -> HttpResponse: + state = self.generate_state() + qs = urlencode( + dict( + response_type="code", + client_id=client_id, + scope="openid email profile", + redirect_uri=oauth_redirect_url, + state=state, + ) + ) + redirect_url = f"{iss}/oauth2/v1/authorize?{qs}" + response = redirect(redirect_url) + return response diff --git a/compare/commands/compare/__init__.py b/compare/commands/compare/__init__.py index f32d111513..efe8c26c2f 100644 --- a/compare/commands/compare/__init__.py +++ b/compare/commands/compare/__init__.py @@ -1 +1,3 @@ from .compare import 
CompareCommands + +__all__ = ["CompareCommands"] diff --git a/compare/commands/compare/interactors/fetch_impacted_files.py b/compare/commands/compare/interactors/fetch_impacted_files.py index 8639b548f1..678e0c06ea 100644 --- a/compare/commands/compare/interactors/fetch_impacted_files.py +++ b/compare/commands/compare/interactors/fetch_impacted_files.py @@ -6,7 +6,7 @@ import services.components as components from codecov.commands.base import BaseInteractor from services.comparison import Comparison, ComparisonReport, ImpactedFile -from services.report import files_belonging_to_flags, files_in_sessions +from services.report import files_belonging_to_flags class ImpactedFileParameter(enum.Enum): diff --git a/compare/commands/compare/interactors/tests/test_fetch_impacted_files.py b/compare/commands/compare/interactors/tests/test_fetch_impacted_files.py index 8c30f419a3..0e1c63a324 100644 --- a/compare/commands/compare/interactors/tests/test_fetch_impacted_files.py +++ b/compare/commands/compare/interactors/tests/test_fetch_impacted_files.py @@ -1,7 +1,6 @@ import enum from unittest.mock import PropertyMock, patch -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from shared.reports.resources import Report, ReportFile, ReportLine from shared.utils.sessions import Session diff --git a/conftest.py b/conftest.py index 4b46b276a0..a9c3902a5e 100644 --- a/conftest.py +++ b/conftest.py @@ -1,4 +1,3 @@ -import os from pathlib import Path import fakeredis diff --git a/core/admin.py b/core/admin.py index 0397d9e892..3cf8edd86d 100644 --- a/core/admin.py +++ b/core/admin.py @@ -1,3 +1,4 @@ +from django import forms from django.contrib import admin from django.core.paginator import Paginator from django.db import connections @@ -42,6 +43,19 @@ def count(self): return int(result[0]) +class RepositoryAdminForm(forms.ModelForm): + # the model field has null=True but not blank=True, so we have to add a workaround + # to be able to 
clear out this field through the django admin + webhook_secret = forms.CharField(required=False, empty_value=None) + yaml = forms.JSONField(required=False) + using_integration = forms.BooleanField(required=False) + hookid = forms.CharField(required=False, empty_value=None) + + class Meta: + model = Repository + fields = "__all__" + + @admin.register(Repository) class RepositoryAdmin(AdminMixin, admin.ModelAdmin): inlines = [RepositoryTokenInline] @@ -49,6 +63,7 @@ class RepositoryAdmin(AdminMixin, admin.ModelAdmin): search_fields = ("author__username__exact",) show_full_result_count = False autocomplete_fields = ("bot",) + form = RepositoryAdminForm paginator = EstimatedCountPaginator @@ -67,7 +82,13 @@ class RepositoryAdmin(AdminMixin, admin.ModelAdmin): "activated", "deleted", ) - fields = readonly_fields + ("bot", "using_integration", "branch", "private") + fields = readonly_fields + ( + "bot", + "using_integration", + "branch", + "private", + "webhook_secret", + ) def has_delete_permission(self, request, obj=None): return False diff --git a/core/apps.py b/core/apps.py index f1cc34ff42..db32d6e931 100644 --- a/core/apps.py +++ b/core/apps.py @@ -1,6 +1,4 @@ from django.apps import AppConfig -from redis import Redis -from shared.config import get_config from shared.helpers.cache import RedisBackend from services.redis_configuration import get_redis_connection @@ -12,7 +10,7 @@ class CoreConfig(AppConfig): name = "core" def ready(self): - import core.signals + import core.signals # noqa: F401 if RUN_ENV not in ["DEV", "TESTING"]: cache_backend = RedisBackend(get_redis_connection()) diff --git a/core/commands/branch/__init__.py b/core/commands/branch/__init__.py index 4cbb5bb371..62e21450e3 100644 --- a/core/commands/branch/__init__.py +++ b/core/commands/branch/__init__.py @@ -1 +1,3 @@ from .branch import BranchCommands + +__all__ = ["BranchCommands"] diff --git a/core/commands/branch/interactors/tests/test_fetch_branch.py 
b/core/commands/branch/interactors/tests/test_fetch_branch.py index bbdcb741c9..f30706e043 100644 --- a/core/commands/branch/interactors/tests/test_fetch_branch.py +++ b/core/commands/branch/interactors/tests/test_fetch_branch.py @@ -1,5 +1,3 @@ -import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov_auth.tests.factories import OwnerFactory diff --git a/core/commands/branch/interactors/tests/test_fetch_branches.py b/core/commands/branch/interactors/tests/test_fetch_branches.py index 7c1a627039..e791b7bab0 100644 --- a/core/commands/branch/interactors/tests/test_fetch_branches.py +++ b/core/commands/branch/interactors/tests/test_fetch_branches.py @@ -1,6 +1,4 @@ -import pytest from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov_auth.tests.factories import OwnerFactory diff --git a/core/commands/commit/__init__.py b/core/commands/commit/__init__.py index 2b7f09c9d8..38bf294c84 100644 --- a/core/commands/commit/__init__.py +++ b/core/commands/commit/__init__.py @@ -1 +1,3 @@ from .commit import CommitCommands + +__all__ = ["CommitCommands"] diff --git a/core/commands/commit/interactors/fetch_totals.py b/core/commands/commit/interactors/fetch_totals.py index 94db6f0188..0abf1d66c7 100644 --- a/core/commands/commit/interactors/fetch_totals.py +++ b/core/commands/commit/interactors/fetch_totals.py @@ -1,8 +1,5 @@ -from django.db.models import Prefetch - from codecov.commands.base import BaseInteractor from codecov.db import sync_to_async -from reports.models import CommitReport class FetchTotalsInteractor(BaseInteractor): diff --git a/core/commands/commit/interactors/tests/test_get_commits_errors.py b/core/commands/commit/interactors/tests/test_get_commits_errors.py index bb3c7ccfa6..d5949d7501 100644 --- a/core/commands/commit/interactors/tests/test_get_commits_errors.py +++ 
b/core/commands/commit/interactors/tests/test_get_commits_errors.py @@ -1,5 +1,4 @@ from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from core.tests.factories import CommitErrorFactory, CommitFactory, OwnerFactory diff --git a/core/commands/commit/interactors/tests/test_get_file_content.py b/core/commands/commit/interactors/tests/test_get_file_content.py index b8e6e41ca9..053dfacb78 100644 --- a/core/commands/commit/interactors/tests/test_get_file_content.py +++ b/core/commands/commit/interactors/tests/test_get_file_content.py @@ -1,9 +1,6 @@ -import asyncio from unittest.mock import patch import pytest -from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from shared.torngit.exceptions import TorngitObjectNotFoundError @@ -65,7 +62,7 @@ async def test_when_path_has_no_file(self, mock_provider_adapter): response_data=404, message="not found" ) file_content = await self.execute(None, self.commit, "path") - assert file_content == None + assert file_content is None @patch("services.repo_providers.RepoProviderService.async_get_adapter") @pytest.mark.asyncio diff --git a/core/commands/commit/interactors/tests/test_get_final_yaml.py b/core/commands/commit/interactors/tests/test_get_final_yaml.py index 002570dab8..96a04f9584 100644 --- a/core/commands/commit/interactors/tests/test_get_final_yaml.py +++ b/core/commands/commit/interactors/tests/test_get_final_yaml.py @@ -1,9 +1,7 @@ import asyncio from unittest.mock import patch -import pytest from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from shared.torngit.exceptions import TorngitObjectNotFoundError diff --git a/core/commands/commit/interactors/tests/test_get_uploads_number.py b/core/commands/commit/interactors/tests/test_get_uploads_number.py index 
171a0259aa..77b4cc4d3d 100644 --- a/core/commands/commit/interactors/tests/test_get_uploads_number.py +++ b/core/commands/commit/interactors/tests/test_get_uploads_number.py @@ -1,5 +1,4 @@ from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from core.tests.factories import CommitFactory diff --git a/core/commands/component/__init__.py b/core/commands/component/__init__.py index bad127d944..f26cd80a77 100644 --- a/core/commands/component/__init__.py +++ b/core/commands/component/__init__.py @@ -1 +1,3 @@ from .component import ComponentCommands + +__all__ = ["ComponentCommands"] diff --git a/core/commands/component/interactors/delete_component_measurements.py b/core/commands/component/interactors/delete_component_measurements.py index e557e5a906..c000d87d99 100644 --- a/core/commands/component/interactors/delete_component_measurements.py +++ b/core/commands/component/interactors/delete_component_measurements.py @@ -3,7 +3,6 @@ import services.self_hosted as self_hosted from codecov.commands.base import BaseInteractor from codecov.commands.exceptions import ( - NotFound, Unauthenticated, Unauthorized, ValidationError, diff --git a/core/commands/flag/__init__.py b/core/commands/flag/__init__.py index a7513e9951..56918fd20f 100644 --- a/core/commands/flag/__init__.py +++ b/core/commands/flag/__init__.py @@ -1 +1,3 @@ from .flag import FlagCommands + +__all__ = ["FlagCommands"] diff --git a/core/commands/flag/tests/test_flag.py b/core/commands/flag/tests/test_flag.py index 56c4031ab8..cc624f0b4b 100644 --- a/core/commands/flag/tests/test_flag.py +++ b/core/commands/flag/tests/test_flag.py @@ -1,6 +1,5 @@ from unittest.mock import patch -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase, override_settings from codecov.commands.exceptions import ( diff --git a/core/commands/pull/__init__.py b/core/commands/pull/__init__.py index 
70e10e7bb1..f4c4878ad0 100644 --- a/core/commands/pull/__init__.py +++ b/core/commands/pull/__init__.py @@ -1 +1,3 @@ from .pull import PullCommands + +__all__ = ["PullCommands"] diff --git a/core/commands/pull/interactors/fetch_pull_request.py b/core/commands/pull/interactors/fetch_pull_request.py index 6c13b8b6c1..deb6824e2c 100644 --- a/core/commands/pull/interactors/fetch_pull_request.py +++ b/core/commands/pull/interactors/fetch_pull_request.py @@ -1,9 +1,25 @@ +from datetime import datetime, timedelta + +from shared.django_apps.core.models import Pull + from codecov.commands.base import BaseInteractor from codecov.db import sync_to_async -from core.models import Pull +from services.task.task import TaskService class FetchPullRequestInteractor(BaseInteractor): + def _should_sync_pull(self, pull: Pull | None) -> bool: + return ( + pull is not None + and pull.state == "open" + and pull.updatestamp is not None + and (datetime.now(tz=None) - pull.updatestamp) > timedelta(hours=1) + ) + @sync_to_async def execute(self, repository, id): - return repository.pull_requests.filter(pullid=id).first() + pull = repository.pull_requests.filter(pullid=id).first() + if self._should_sync_pull(pull): + TaskService().pulls_sync(repository.repoid, id) + + return pull diff --git a/core/commands/pull/interactors/tests/test_fetch_pull_request.py b/core/commands/pull/interactors/tests/test_fetch_pull_request.py index a651093a36..c8393ee809 100644 --- a/core/commands/pull/interactors/tests/test_fetch_pull_request.py +++ b/core/commands/pull/interactors/tests/test_fetch_pull_request.py @@ -1,11 +1,10 @@ +from datetime import datetime + import pytest -from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase +from freezegun import freeze_time -from api.internal import pull from core.tests.factories import OwnerFactory, PullFactory, RepositoryFactory -from reports.tests.factories import UploadFactory from 
..fetch_pull_request import FetchPullRequestInteractor @@ -28,3 +27,46 @@ async def test_fetch_when_pull_request_doesnt_exist(self): async def test_fetch_pull_request(self): pr = await self.execute(None, self.repo, self.pr.pullid) assert pr == self.pr + + +# Not part of the class because TransactionTestCase cannot be parametrized +@freeze_time("2024-07-01 12:00:00") +@pytest.mark.parametrize( + "pr_state, updatestamp, expected", + [ + pytest.param( + "open", "2024-07-01 11:50:00", False, id="pr_open_recently_updated" + ), + pytest.param("merged", "2024-07-01 01:00:00", False, id="pr_merged"), + pytest.param( + "closed", "2024-07-01 11:50:00", False, id="pr_closed_recently_updated" + ), + pytest.param( + "open", "2024-07-01 01:00:00", True, id="pr_open_not_recently_updated" + ), + ], +) +def test_fetch_pull_should_sync(pr_state, updatestamp, expected, db): + repo = RepositoryFactory(private=False) + pr = PullFactory(repository_id=repo.repoid, state=pr_state) + repo.save() + pr.save() # This will change the updatestamp, so we need to set it again + pr.updatestamp = datetime.fromisoformat(updatestamp).replace(tzinfo=None) + should_sync = FetchPullRequestInteractor( + repo.author, repo.service + )._should_sync_pull(pr) + assert pr.updatestamp == datetime.fromisoformat(updatestamp).replace(tzinfo=None) + assert should_sync == expected + + +def test_fetch_pull_updatestamp_is_none(db): + repo = RepositoryFactory(private=False) + pr = PullFactory(repository_id=repo.repoid, state="open") + repo.save() + pr.save() + pr.updatestamp = None + should_sync = FetchPullRequestInteractor( + repo.author, repo.service + )._should_sync_pull(pr) + assert pr.updatestamp is None + assert should_sync == False diff --git a/core/commands/pull/interactors/tests/test_fetch_pull_requests.py b/core/commands/pull/interactors/tests/test_fetch_pull_requests.py index 6e58ab46a8..9022654350 100644 --- a/core/commands/pull/interactors/tests/test_fetch_pull_requests.py +++ 
b/core/commands/pull/interactors/tests/test_fetch_pull_requests.py @@ -1,11 +1,8 @@ -import pytest from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from core.models import PullStates from core.tests.factories import OwnerFactory, PullFactory, RepositoryFactory -from reports.tests.factories import UploadFactory from ..fetch_pull_requests import FetchPullRequestsInteractor diff --git a/core/commands/pull/tests/test_pull.py b/core/commands/pull/tests/test_pull.py index c59f0ef3f3..21bbd10f7d 100644 --- a/core/commands/pull/tests/test_pull.py +++ b/core/commands/pull/tests/test_pull.py @@ -3,8 +3,7 @@ from django.test import TransactionTestCase from codecov_auth.tests.factories import OwnerFactory -from core.models import Pull -from core.tests.factories import PullFactory, RepositoryFactory +from core.tests.factories import RepositoryFactory from ..pull import PullCommands diff --git a/core/commands/repository/__init__.py b/core/commands/repository/__init__.py index 71688e39fa..099c43bceb 100644 --- a/core/commands/repository/__init__.py +++ b/core/commands/repository/__init__.py @@ -1 +1,3 @@ from .repository import RepositoryCommands + +__all__ = ["RepositoryCommands"] diff --git a/core/commands/repository/interactors/encode_secret_string.py b/core/commands/repository/interactors/encode_secret_string.py index d7f1b2bde4..25e6bc2fbb 100644 --- a/core/commands/repository/interactors/encode_secret_string.py +++ b/core/commands/repository/interactors/encode_secret_string.py @@ -1,5 +1,3 @@ -from dataclasses import dataclass - from codecov.commands.base import BaseInteractor from codecov.commands.exceptions import Unauthenticated, ValidationError from codecov.db import sync_to_async @@ -10,12 +8,13 @@ class EncodeSecretStringInteractor(BaseInteractor): @sync_to_async - def execute(self, owner: Owner, repo: Repository, value: str) -> str: + def execute(self, owner: Owner, repo_name: 
str, value: str) -> str: if not self.current_user.is_authenticated: raise Unauthenticated() + + repo = Repository.objects.viewable_repos(owner).filter(name=repo_name).first() if not repo: raise ValidationError("Repo not found") - to_encode = "/".join( ( owner.service, diff --git a/core/commands/repository/interactors/fetch_repository.py b/core/commands/repository/interactors/fetch_repository.py index 76ebb932ec..0993e959c5 100644 --- a/core/commands/repository/interactors/fetch_repository.py +++ b/core/commands/repository/interactors/fetch_repository.py @@ -1,14 +1,27 @@ +from shared.django_apps.codecov_auth.models import Owner +from shared.django_apps.core.models import Repository + from codecov.commands.base import BaseInteractor from codecov.db import sync_to_async -from core.models import Repository class FetchRepositoryInteractor(BaseInteractor): @sync_to_async - def execute(self, owner, name): + def execute( + self, + owner: Owner, + name: str, + okta_authenticated_accounts: list[int], + exclude_okta_enforced_repos: bool = True, + ): + queryset = Repository.objects.viewable_repos(self.current_owner) + if exclude_okta_enforced_repos: + queryset = queryset.exclude_accounts_enforced_okta( + okta_authenticated_accounts + ) + return ( - Repository.objects.viewable_repos(self.current_owner) - .filter(author=owner, name=name) + queryset.filter(author=owner, name=name) .with_recent_coverage() .with_oldest_commit_at() .select_related("author") diff --git a/core/commands/repository/interactors/get_repository_token.py b/core/commands/repository/interactors/get_repository_token.py index 3d68973e83..55cd13b952 100644 --- a/core/commands/repository/interactors/get_repository_token.py +++ b/core/commands/repository/interactors/get_repository_token.py @@ -1,5 +1,5 @@ from codecov.commands.base import BaseInteractor -from codecov.commands.exceptions import Unauthenticated, ValidationError +from codecov.commands.exceptions import Unauthenticated from codecov.db import 
sync_to_async from codecov_auth.helpers import current_user_part_of_org from codecov_auth.models import RepositoryToken @@ -9,12 +9,13 @@ class GetRepositoryTokenInteractor(BaseInteractor): def validate(self, repository): if not self.current_user.is_authenticated: raise Unauthenticated() - if not repository.active: - raise ValidationError("Repo is not active") @sync_to_async def execute(self, repository, token_type): self.validate(repository) + if not repository.active: + return None + if current_user_part_of_org(self.current_owner, repository.author): token = RepositoryToken.objects.filter( repository_id=repository.repoid, token_type=token_type diff --git a/core/commands/repository/interactors/get_upload_token.py b/core/commands/repository/interactors/get_upload_token.py index b63bbf7e46..80ef954d4a 100644 --- a/core/commands/repository/interactors/get_upload_token.py +++ b/core/commands/repository/interactors/get_upload_token.py @@ -1,6 +1,5 @@ from codecov.commands.base import BaseInteractor from codecov_auth.helpers import current_user_part_of_org -from core.models import Repository class GetUploadTokenInteractor(BaseInteractor): diff --git a/core/commands/repository/interactors/regenerate_repository_token.py b/core/commands/repository/interactors/regenerate_repository_token.py index e294ef4c50..fa25e84498 100644 --- a/core/commands/repository/interactors/regenerate_repository_token.py +++ b/core/commands/repository/interactors/regenerate_repository_token.py @@ -1,5 +1,5 @@ from codecov.commands.base import BaseInteractor -from codecov.commands.exceptions import Unauthenticated, ValidationError +from codecov.commands.exceptions import ValidationError from codecov.db import sync_to_async from codecov_auth.models import Owner, RepositoryToken from core.models import Repository diff --git a/core/commands/repository/interactors/regenerate_repository_upload_token.py b/core/commands/repository/interactors/regenerate_repository_upload_token.py index 
c51b40ba0b..3642b9faff 100644 --- a/core/commands/repository/interactors/regenerate_repository_upload_token.py +++ b/core/commands/repository/interactors/regenerate_repository_upload_token.py @@ -15,7 +15,7 @@ def execute(self, repo_name: str, owner_username: str) -> uuid.UUID: ).first() repo = ( Repository.objects.viewable_repos(self.current_owner) - .filter(author=author, name=repo_name, active=True) + .filter(author=author, name=repo_name) .first() ) if not repo: diff --git a/core/commands/repository/interactors/tests/test_activate_measurements.py b/core/commands/repository/interactors/tests/test_activate_measurements.py index b2bd3e7a5a..4699b4a59f 100644 --- a/core/commands/repository/interactors/tests/test_activate_measurements.py +++ b/core/commands/repository/interactors/tests/test_activate_measurements.py @@ -4,7 +4,6 @@ import pytest from asgiref.sync import async_to_sync from django.conf import settings -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase, override_settings from django.utils import timezone from freezegun import freeze_time diff --git a/core/commands/repository/interactors/tests/test_encode_secret_string.py b/core/commands/repository/interactors/tests/test_encode_secret_string.py index 97c02c492b..8e3230b025 100644 --- a/core/commands/repository/interactors/tests/test_encode_secret_string.py +++ b/core/commands/repository/interactors/tests/test_encode_secret_string.py @@ -1,8 +1,5 @@ -from unittest.mock import patch - import pytest from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from shared.encryption.yaml_secret import yaml_secret_encryptor @@ -14,21 +11,23 @@ class EncodeSecretStringInteractorTest(TransactionTestCase): @async_to_sync - def execute(self, owner, repo, value): - return EncodeSecretStringInteractor(owner, "github").execute(owner, repo, value) + def execute(self, owner, repo_name, value): + 
return EncodeSecretStringInteractor(owner, "github").execute( + owner, repo_name, value + ) def test_encode_secret_string(self): owner = OwnerFactory() - repo = RepositoryFactory(author=owner, name="repo-1") - res = self.execute(owner, repo=repo, value="token-1") + RepositoryFactory(author=owner, name="repo-1") + res = self.execute(owner, repo_name="repo-1", value="token-1") check_encryptor = yaml_secret_encryptor assert "token-1" in check_encryptor.decode(res[7:]) def test_validation_error_when_repo_not_found(self): owner = OwnerFactory() with pytest.raises(ValidationError): - self.execute(owner, repo=None, value="token-1") + self.execute(owner, repo_name=None, value="token-1") def test_user_is_not_authenticated(self): - with pytest.raises(Unauthenticated) as e: - self.execute(None, repo=None, value="test") + with pytest.raises(Unauthenticated): + self.execute(None, repo_name=None, value="test") diff --git a/core/commands/repository/interactors/tests/test_erase_repository.py b/core/commands/repository/interactors/tests/test_erase_repository.py index f866c0f13f..649d6772f9 100644 --- a/core/commands/repository/interactors/tests/test_erase_repository.py +++ b/core/commands/repository/interactors/tests/test_erase_repository.py @@ -1,5 +1,4 @@ import pytest -from asgiref.sync import async_to_sync from django.test import TransactionTestCase from codecov.commands.exceptions import Unauthorized diff --git a/core/commands/repository/interactors/tests/test_fetch_repository.py b/core/commands/repository/interactors/tests/test_fetch_repository.py index 77e48174e2..c76e632d13 100644 --- a/core/commands/repository/interactors/tests/test_fetch_repository.py +++ b/core/commands/repository/interactors/tests/test_fetch_repository.py @@ -1,13 +1,9 @@ -import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase - -from codecov.commands.exceptions import ( - NotFound, - Unauthenticated, - Unauthorized, - ValidationError, +from 
shared.django_apps.codecov_auth.tests.factories import ( + AccountFactory, + OktaSettingsFactory, ) + from codecov_auth.tests.factories import OwnerFactory from core.tests.factories import RepositoryFactory @@ -17,36 +13,77 @@ class FetchRepositoryInteractorTest(TransactionTestCase): def setUp(self): self.org = OwnerFactory() + + self.okta_account = AccountFactory() + self.okta_settings = OktaSettingsFactory( + account=self.okta_account, enforced=True + ) + self.okta_org = OwnerFactory(account=self.okta_account) + self.public_repo = RepositoryFactory(author=self.org, private=False) self.hidden_private_repo = RepositoryFactory(author=self.org, private=True) self.private_repo = RepositoryFactory(author=self.org, private=True) + self.okta_private_repo = RepositoryFactory(author=self.okta_org, private=True) self.current_user = OwnerFactory( - permission=[self.private_repo.repoid], organizations=[self.org.ownerid] + permission=[self.private_repo.repoid, self.okta_private_repo.repoid], + organizations=[self.org.ownerid, self.okta_org.ownerid], ) # helper to execute the interactor - def execute(self, owner, *args): + def execute(self, owner, *args, **kwargs): service = owner.service if owner else "github" - return FetchRepositoryInteractor(owner, service).execute(*args) + return FetchRepositoryInteractor(owner, service).execute(*args, **kwargs) async def test_fetch_public_repo_unauthenticated(self): - repo = await self.execute(None, self.org, self.public_repo.name) + repo = await self.execute(None, self.org, self.public_repo.name, []) assert repo == self.public_repo async def test_fetch_public_repo_authenticated(self): - repo = await self.execute(self.current_user, self.org, self.public_repo.name) + repo = await self.execute( + self.current_user, self.org, self.public_repo.name, [] + ) assert repo == self.public_repo async def test_fetch_private_repo_unauthenticated(self): - repo = await self.execute(None, self.org, self.private_repo.name) + repo = await 
self.execute(None, self.org, self.private_repo.name, []) assert repo is None async def test_fetch_private_repo_authenticated_but_no_permissions(self): repo = await self.execute( - self.current_user, self.org, self.hidden_private_repo.name + self.current_user, self.org, self.hidden_private_repo.name, [] ) assert repo is None async def test_fetch_private_repo_authenticated_with_permissions(self): - repo = await self.execute(self.current_user, self.org, self.private_repo.name) + repo = await self.execute( + self.current_user, self.org, self.private_repo.name, [] + ) assert repo == self.private_repo + + async def test_fetch_okta_private_repo_authenticated(self): + repo = await self.execute( + self.current_user, + self.okta_org, + self.okta_private_repo.name, + [self.okta_account.id], + ) + assert repo == self.okta_private_repo + + async def test_fetch_okta_private_repo_unauthenticated(self): + repo = await self.execute( + self.current_user, + self.okta_org, + self.okta_private_repo.name, + [], + ) + assert repo is None + + async def test_fetch_okta_private_repo_do_not_exclude_unauthenticated(self): + repo = await self.execute( + self.current_user, + self.okta_org, + self.okta_private_repo.name, + [], + exclude_okta_enforced_repos=False, + ) + assert repo == self.okta_private_repo diff --git a/core/commands/repository/interactors/tests/test_get_repository_token.py b/core/commands/repository/interactors/tests/test_get_repository_token.py index f570f9c970..16e15ac384 100644 --- a/core/commands/repository/interactors/tests/test_get_repository_token.py +++ b/core/commands/repository/interactors/tests/test_get_repository_token.py @@ -1,10 +1,7 @@ -from xml.dom import ValidationErr - import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase -from codecov.commands.exceptions import Unauthenticated, ValidationError +from codecov.commands.exceptions import Unauthenticated from codecov_auth.tests.factories import OwnerFactory 
from core.tests.factories import RepositoryFactory, RepositoryTokenFactory @@ -36,8 +33,8 @@ async def test_when_unauthenticated_raise(self): await self.execute(owner="", repo=self.active_repo) async def test_when_repo_inactive(self): - with pytest.raises(ValidationError): - await self.execute(owner=self.user, repo=self.inactive_repo) + token = await self.execute(owner=self.user, repo=self.inactive_repo) + assert token is None async def test_when_repo_has_no_token(self): token = await self.execute(owner=self.user, repo=self.repo_with_no_token) diff --git a/core/commands/repository/interactors/tests/test_get_upload_token.py b/core/commands/repository/interactors/tests/test_get_upload_token.py index 52b9f62687..940264d9dc 100644 --- a/core/commands/repository/interactors/tests/test_get_upload_token.py +++ b/core/commands/repository/interactors/tests/test_get_upload_token.py @@ -1,5 +1,3 @@ -import pytest -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov_auth.tests.factories import OwnerFactory diff --git a/core/commands/repository/interactors/tests/test_regenerate_repository_token.py b/core/commands/repository/interactors/tests/test_regenerate_repository_token.py index 3870f3ae46..58452bced0 100644 --- a/core/commands/repository/interactors/tests/test_regenerate_repository_token.py +++ b/core/commands/repository/interactors/tests/test_regenerate_repository_token.py @@ -1,9 +1,7 @@ import pytest -from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase -from codecov.commands.exceptions import Unauthenticated, ValidationError +from codecov.commands.exceptions import ValidationError from codecov_auth.tests.factories import OwnerFactory from core.tests.factories import RepositoryFactory, RepositoryTokenFactory diff --git a/core/commands/repository/interactors/tests/test_update_repository.py 
b/core/commands/repository/interactors/tests/test_update_repository.py index bd702509f2..bbd366be7c 100644 --- a/core/commands/repository/interactors/tests/test_update_repository.py +++ b/core/commands/repository/interactors/tests/test_update_repository.py @@ -1,11 +1,8 @@ import pytest -from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov.commands.exceptions import Unauthorized from codecov_auth.tests.factories import OwnerFactory -from core.tests.factories import RepositoryFactory, RepositoryTokenFactory from ..update_repository import UpdateRepositoryInteractor diff --git a/core/commands/repository/interactors/update_repository.py b/core/commands/repository/interactors/update_repository.py index b4737438bd..d455da85c5 100644 --- a/core/commands/repository/interactors/update_repository.py +++ b/core/commands/repository/interactors/update_repository.py @@ -1,8 +1,5 @@ from typing import Optional -from django.conf import settings - -import services.self_hosted as self_hosted from codecov.commands.base import BaseInteractor from codecov.commands.exceptions import Unauthenticated, Unauthorized, ValidationError from codecov.db import sync_to_async diff --git a/core/commands/repository/repository.py b/core/commands/repository/repository.py index 0d8d76691b..a40c484cd7 100644 --- a/core/commands/repository/repository.py +++ b/core/commands/repository/repository.py @@ -20,8 +20,19 @@ class RepositoryCommands(BaseCommand): - def fetch_repository(self, owner, name): - return self.get_interactor(FetchRepositoryInteractor).execute(owner, name) + def fetch_repository( + self, + owner, + name, + okta_authenticated_accounts: list[int], + exclude_okta_enforced_repos: bool = True, + ) -> Repository: + return self.get_interactor(FetchRepositoryInteractor).execute( + owner, + name, + okta_authenticated_accounts, + exclude_okta_enforced_repos=exclude_okta_enforced_repos, + ) def 
regenerate_repository_upload_token( self, @@ -68,7 +79,7 @@ def activate_measurements( def erase_repository(self, repo_name: str, owner: Owner): return self.get_interactor(EraseRepositoryInteractor).execute(repo_name, owner) - def encode_secret_string(self, owner: Owner, repo: Repository, value: str): + def encode_secret_string(self, owner: Owner, repo_name: str, value: str): return self.get_interactor(EncodeSecretStringInteractor).execute( - owner, repo, value + owner, repo_name, value ) diff --git a/core/commands/repository/tests/test_repository.py b/core/commands/repository/tests/test_repository.py index c750b7a1ce..693efc1187 100644 --- a/core/commands/repository/tests/test_repository.py +++ b/core/commands/repository/tests/test_repository.py @@ -18,8 +18,19 @@ def setUp(self): @patch("core.commands.repository.repository.FetchRepositoryInteractor.execute") def test_fetch_repository_to_interactor(self, interactor_mock): - self.command.fetch_repository(self.org, self.repo.name) - interactor_mock.assert_called_once_with(self.org, self.repo.name) + self.command.fetch_repository(self.org, self.repo.name, []) + interactor_mock.assert_called_once_with( + self.org, self.repo.name, [], exclude_okta_enforced_repos=True + ) + + @patch("core.commands.repository.repository.FetchRepositoryInteractor.execute") + def test_fetch_repository_to_interactor_with_enforcing_okta(self, interactor_mock): + self.command.fetch_repository( + self.org, self.repo.name, [], exclude_okta_enforced_repos=False + ) + interactor_mock.assert_called_once_with( + self.org, self.repo.name, [], exclude_okta_enforced_repos=False + ) @patch("core.commands.repository.repository.GetUploadTokenInteractor.execute") def test_get_upload_token_to_interactor(self, interactor_mock): diff --git a/core/commands/upload/__init__.py b/core/commands/upload/__init__.py index 0942ec31d9..e374f41f81 100644 --- a/core/commands/upload/__init__.py +++ b/core/commands/upload/__init__.py @@ -1 +1,3 @@ from .upload import 
UploadCommands + +__all__ = ["UploadCommands"] diff --git a/core/commands/upload/interactors/tests/test_upload_error.py b/core/commands/upload/interactors/tests/test_upload_error.py index 02e942996a..26060ff899 100644 --- a/core/commands/upload/interactors/tests/test_upload_error.py +++ b/core/commands/upload/interactors/tests/test_upload_error.py @@ -1,10 +1,7 @@ -import pytest from asgiref.sync import async_to_sync -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from codecov_auth.tests.factories import OwnerFactory -from core import models from core.tests.factories import CommitFactory, RepositoryFactory from graphql_api.types.enums import UploadErrorEnum, UploadState from reports.tests.factories import ( @@ -57,15 +54,15 @@ def test_get_upload_errors_no_error(self): other_upload = UploadFactory( report=commit_report, state=UploadState.ERROR.value ) - other_upload_error_1 = UploadErrorFactory(report_session=other_upload) - other_upload_error_2 = UploadErrorFactory(report_session=other_upload) + UploadErrorFactory(report_session=other_upload) + UploadErrorFactory(report_session=other_upload) another_upload = UploadFactory( report=commit_report, state=UploadState.ERROR.value ) - another_upload_error_1 = UploadErrorFactory(report_session=another_upload) - another_upload_error_2 = UploadErrorFactory(report_session=another_upload) - another_upload_error_3 = UploadErrorFactory(report_session=another_upload) + UploadErrorFactory(report_session=another_upload) + UploadErrorFactory(report_session=another_upload) + UploadErrorFactory(report_session=another_upload) upload = UploadFactory(report=commit_report) diff --git a/core/management/commands/backfill_commits.py b/core/management/commands/backfill_commits.py index 97c163420c..b72f55252c 100644 --- a/core/management/commands/backfill_commits.py +++ b/core/management/commands/backfill_commits.py @@ -1,6 +1,6 @@ import logging -from django.core.management.base import 
BaseCommand, CommandError, CommandParser +from django.core.management.base import BaseCommand, CommandParser from redis import Redis from shared.config import get_config diff --git a/core/management/commands/check_for_migration_conflicts.py b/core/management/commands/check_for_migration_conflicts.py index 93e7983a2e..85b965c678 100644 --- a/core/management/commands/check_for_migration_conflicts.py +++ b/core/management/commands/check_for_migration_conflicts.py @@ -1,7 +1,7 @@ import os from django.apps import apps -from django.core.management.base import BaseCommand, CommandError, no_translations +from django.core.management.base import BaseCommand class Command(BaseCommand): diff --git a/core/management/commands/update_gitlab_webhooks.py b/core/management/commands/update_gitlab_webhooks.py index 62d5d897b5..74697aaf1b 100644 --- a/core/management/commands/update_gitlab_webhooks.py +++ b/core/management/commands/update_gitlab_webhooks.py @@ -7,7 +7,6 @@ from shared.torngit.exceptions import TorngitClientError, TorngitRefreshTokenFailedError from shared.torngit.gitlab import Gitlab -from codecov_auth.models import Owner from core.models import Repository from services.repo_providers import RepoProviderService from utils.repos import get_bot_user diff --git a/core/managers.py b/core/managers.py index 1b86036c03..6a8d5590cb 100644 --- a/core/managers.py +++ b/core/managers.py @@ -4,7 +4,6 @@ from django.db.models import ( Avg, Count, - DateTimeField, F, FloatField, IntegerField, @@ -132,7 +131,7 @@ def with_latest_coverage_change(self): branch) of each repository. Depends on having called "with_latest_commit_totals_before" with "include_previous_totals=True". 
""" - from core.models import Commit + from core.models import Commit # noqa: F401 return self.annotate( latest_coverage=Cast( @@ -276,7 +275,7 @@ def get_or_create_from_git_repo(self, git_repo, owner): author=owner, service_id=git_repo.get("service_id") or git_repo.get("id"), private=git_repo["private"], - branch=git_repo.get("branch") or git_repo.get("default_branch") or "master", + branch=git_repo.get("branch") or git_repo.get("default_branch") or "main", name=git_repo["name"], ) diff --git a/core/migrations/0034_remove_repository_cache.py b/core/migrations/0034_remove_repository_cache.py index adea85487b..d0a454008f 100644 --- a/core/migrations/0034_remove_repository_cache.py +++ b/core/migrations/0034_remove_repository_cache.py @@ -1,7 +1,6 @@ # Generated by Django 4.2.2 on 2023-08-14 13:23 from django.db import migrations -from shared.django_apps.migration_utils import RiskyRemoveField class Migration(migrations.Migration): diff --git a/core/models.py b/core/models.py index dbfa2b90f7..7c63f11ab9 100644 --- a/core/models.py +++ b/core/models.py @@ -1,2 +1,2 @@ from shared.django_apps.core.models import * -from shared.django_apps.core.models import _gen_image_token +from shared.django_apps.core.models import _gen_image_token # noqa: F401 diff --git a/core/signals.py b/core/signals.py index 229ddc6917..55e232739f 100644 --- a/core/signals.py +++ b/core/signals.py @@ -1,13 +1,16 @@ import json +import logging from django.conf import settings from django.db.models.signals import post_save from django.dispatch import receiver from google.cloud import pubsub_v1 +from shared.django_apps.core.models import Commit from core.models import Repository _pubsub_publisher = None +log = logging.getLogger(__name__) def _get_pubsub_publisher(): @@ -19,9 +22,35 @@ def _get_pubsub_publisher(): @receiver(post_save, sender=Repository, dispatch_uid="shelter_sync_repo") def update_repository(sender, instance: Repository, **kwargs): + log.info(f"Signal triggered for repository 
{instance.repoid}") created = kwargs["created"] changes = instance.tracker.changed() if created or any([field in changes for field in ["name", "upload_token"]]): + try: + pubsub_project_id = settings.SHELTER_PUBSUB_PROJECT_ID + topic_id = settings.SHELTER_PUBSUB_SYNC_REPO_TOPIC_ID + if pubsub_project_id and topic_id: + publisher = _get_pubsub_publisher() + topic_path = publisher.topic_path(pubsub_project_id, topic_id) + publisher.publish( + topic_path, + json.dumps( + { + "type": "repo", + "sync": "one", + "id": instance.repoid, + } + ).encode("utf-8"), + ) + log.info(f"Message published for repository {instance.repoid}") + except Exception as e: + log.warning(f"Failed to publish message for repo {instance.repoid}: {e}") + + +@receiver(post_save, sender=Commit, dispatch_uid="shelter_sync_commit") +def update_commit(sender, instance: Commit, **kwargs): + branch = instance.branch + if branch and ":" in branch: pubsub_project_id = settings.SHELTER_PUBSUB_PROJECT_ID topic_id = settings.SHELTER_PUBSUB_SYNC_REPO_TOPIC_ID if pubsub_project_id and topic_id: @@ -31,9 +60,9 @@ def update_repository(sender, instance: Repository, **kwargs): topic_path, json.dumps( { - "type": "repo", + "type": "commit", "sync": "one", - "id": instance.repoid, + "id": instance.id, } ).encode("utf-8"), ) diff --git a/core/tests/factories.py b/core/tests/factories.py index d93577607f..fb2a5bfc58 100644 --- a/core/tests/factories.py +++ b/core/tests/factories.py @@ -89,25 +89,18 @@ def _create(cls, model_class, *args, **kwargs): "filename": "tests/__init__.py", "file_index": 0, "file_totals": [0, 3, 2, 1, 0, "66.66667", 0, 0, 0, 0, 0, 0, 0], - "session_totals": [ - [0, 3, 2, 1, 0, "66.66667", 0, 0, 0, 0, 0, 0, 0] - ], "diff_totals": None, }, { "filename": "tests/test_sample.py", "file_index": 1, "file_totals": [0, 7, 7, 0, 0, "100", 0, 0, 0, 0, 0, 0, 0], - "session_totals": [[0, 7, 7, 0, 0, "100", 0, 0, 0, 0, 0, 0, 0]], "diff_totals": None, }, { "filename": "awesome/__init__.py", "file_index": 2, 
"file_totals": [0, 10, 8, 2, 0, "80.00000", 0, 0, 0, 0, 0, 0, 0], - "session_totals": [ - [0, 10, 8, 2, 0, "80.00000", 0, 0, 0, 0, 0, 0, 0] - ], "diff_totals": [0, 2, 1, 1, 0, "50.00000", 0, 0, 0, 0, 0, 0, 0], }, ], diff --git a/core/tests/test_admin.py b/core/tests/test_admin.py index 6ecb682527..dac0c6efc0 100644 --- a/core/tests/test_admin.py +++ b/core/tests/test_admin.py @@ -1,10 +1,11 @@ +import uuid from unittest.mock import MagicMock, patch from django.contrib.admin.sites import AdminSite from django.test import TestCase from codecov_auth.tests.factories import UserFactory -from core.admin import RepositoryAdmin +from core.admin import RepositoryAdmin, RepositoryAdminForm from core.models import Repository from core.tests.factories import RepositoryFactory from utils.test_utils import Client @@ -52,3 +53,29 @@ def test_prev_and_new_values_in_log_entry(self, mocked_super_log_change): {"changed": {"fields": ["using_integration"]}}, {"using_integration": "prev value: True, new value: False"}, ] + + +class RepositoryAdminTests(AdminTest): + def test_webhook_secret_nullable(self): + repo = RepositoryFactory( + webhook_secret=str(uuid.uuid4()), + ) + self.assertIsNotNone(repo.webhook_secret) + data = { + "webhook_secret": "", + # all the required fields have to be filled out in the form even though they aren't changed + "name": repo.name, + "author": repo.author, + "service_id": repo.service_id, + "upload_token": repo.upload_token, + "image_token": repo.image_token, + "branch": repo.branch, + } + + form = RepositoryAdminForm(data=data, instance=repo) + self.assertTrue(form.is_valid()) + updated_instance = form.save() + self.assertIsNone(updated_instance.webhook_secret) + + repo.refresh_from_db() + self.assertIsNone(repo.webhook_secret) diff --git a/core/tests/test_management_commands.py b/core/tests/test_management_commands.py index 025dbe33d6..cca4ab0abe 100644 --- a/core/tests/test_management_commands.py +++ b/core/tests/test_management_commands.py @@ -1,5 +1,4 
@@ import unittest.mock as mock -import uuid from io import StringIO import fakeredis diff --git a/core/tests/test_managers.py b/core/tests/test_managers.py index b5982890b4..36f9dfdd00 100644 --- a/core/tests/test_managers.py +++ b/core/tests/test_managers.py @@ -74,7 +74,7 @@ def test_get_or_create_from_github_repo_data(self): with self.subTest("doesnt crash when fork but no parent"): repo_data = { "id": 45, - "default_branch": "master", + "default_branch": "main", "private": True, "name": "test", "fork": True, @@ -85,7 +85,7 @@ def test_get_or_create_from_github_repo_data(self): ) assert created assert repo.service_id == 45 - assert repo.branch == "master" + assert repo.branch == "main" assert repo.private assert repo.name == "test" diff --git a/core/tests/test_middleware.py b/core/tests/test_middleware.py index ce72029281..c8d43d865a 100644 --- a/core/tests/test_middleware.py +++ b/core/tests/test_middleware.py @@ -1,8 +1,5 @@ -from django.test import TestCase from prometheus_client import REGISTRY -from core.middleware import USER_AGENT_METRICS - # TODO: consolidate with worker/helpers/tests/unit/test_checkpoint_logger.py into shared repo class CounterAssertion: diff --git a/core/tests/test_signals.py b/core/tests/test_signals.py index e6f4ec0ec4..17c5d8bfc6 100644 --- a/core/tests/test_signals.py +++ b/core/tests/test_signals.py @@ -3,6 +3,7 @@ import pytest from django.test import override_settings +from shared.django_apps.core.tests.factories import CommitFactory from core.tests.factories import RepositoryFactory @@ -45,3 +46,35 @@ def test_shelter_repo_sync(mocker): publish_calls = publish.call_args_list # does not trigger another publish assert len(publish_calls) == 2 + + +@override_settings( + SHELTER_PUBSUB_PROJECT_ID="test-project-id", + SHELTER_PUBSUB_SYNC_REPO_TOPIC_ID="test-topic-id", +) +@pytest.mark.django_db +def test_shelter_commit_sync(mocker): + # this prevents the pubsub SDK from trying to load credentials + os.environ["PUBSUB_EMULATOR_HOST"] 
= "localhost" + publish = mocker.patch("google.cloud.pubsub_v1.PublisherClient.publish") + + # this triggers the publish via Django signals - has to have this format + commit = CommitFactory(id=167829367, branch="random:branch") + + publish_calls = publish.call_args_list + # Twice cause there's a signal triggered when the commit factory creates a Repository + # which can't be null + assert len(publish_calls) == 2 + + # triggers publish on update + assert publish_calls[1] == call( + "projects/test-project-id/topics/test-topic-id", + b'{"type": "commit", "sync": "one", "id": 167829367}', + ) + + commit.branch = "normal-incompatible-branch" + commit.save() + + publish_calls = publish.call_args_list + # does not trigger another publish since unchanged length + assert len(publish_calls) == 2 diff --git a/docker/Dockerfile b/docker/Dockerfile index 51df17b01e..2b79c1e30a 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -8,6 +8,7 @@ FROM us-docker.pkg.dev/berglas/berglas/berglas:$BERGLAS_VERSION as berglas FROM $REQUIREMENTS_IMAGE as app COPY . /app WORKDIR /app +RUN pip install setuptools==71.1.0 RUN python manage.py collectstatic --no-input FROM app as local diff --git a/docker/Dockerfile.requirements b/docker/Dockerfile.requirements index f3997575b2..71172f86a1 100644 --- a/docker/Dockerfile.requirements +++ b/docker/Dockerfile.requirements @@ -4,13 +4,6 @@ ARG PYTHON_IMAGE=python:3.12-slim-bookworm # BUILD STAGE - Download dependencies from GitHub that require SSH access FROM $PYTHON_IMAGE as build - -# Pinning a specific nightly version so that builds don't suddenly break if a -# "this feature is now stabilized" warning is promoted to an error or something. -# We would like to keep up with nightly if we can. 
-ARG RUST_VERSION=nightly-2024-02-22 -ENV RUST_VERSION=${RUST_VERSION} - RUN apt-get update RUN apt-get install -y \ build-essential \ @@ -18,11 +11,6 @@ RUN apt-get install -y \ libpq-dev \ curl -# Install Rust -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs \ - | bash -s -- -y --default-toolchain $RUST_VERSION -ENV PATH="/root/.cargo/bin:$PATH" - COPY requirements.txt / WORKDIR /pip-packages/ RUN pip wheel -r /requirements.txt diff --git a/graphql_api/actions/commits.py b/graphql_api/actions/commits.py index 0b828b8375..8378014420 100644 --- a/graphql_api/actions/commits.py +++ b/graphql_api/actions/commits.py @@ -122,14 +122,14 @@ def repo_commits( queryset = queryset.filter(state__in=states) coverage_status = filters.get("coverage_status") + if coverage_status: - to_be_included = [] - for commit in queryset: - if ( - commit_status(commit, CommitReport.ReportType.COVERAGE) - in coverage_status - ): - to_be_included.append(commit.id) + to_be_included = [ + commit.id + for commit in queryset + if commit_status(commit, CommitReport.ReportType.COVERAGE) + in coverage_status + ] queryset = queryset.filter(id__in=to_be_included) # We need `deleted is not true` in order for the query to use the right index. 
diff --git a/graphql_api/actions/comparison.py b/graphql_api/actions/comparison.py index 206dc06dab..4ad267a7b4 100644 --- a/graphql_api/actions/comparison.py +++ b/graphql_api/actions/comparison.py @@ -1,13 +1,11 @@ from typing import Optional, Union -from codecov.db import sync_to_async from compare.models import CommitComparison from graphql_api.types.comparison.comparison import ( MissingBaseReport, MissingComparison, MissingHeadReport, ) -from services.comparison import Comparison, PullRequestComparison def validate_commit_comparison( diff --git a/graphql_api/actions/components.py b/graphql_api/actions/components.py index 2ab07b1d00..096689dba9 100644 --- a/graphql_api/actions/components.py +++ b/graphql_api/actions/components.py @@ -3,7 +3,6 @@ from django.db.models import QuerySet -from codecov_auth.models import Owner from core.models import Repository from graphql_api.actions.measurements import ( measurements_by_ids, diff --git a/graphql_api/actions/measurements.py b/graphql_api/actions/measurements.py index 465bd62509..77aaa8c21b 100644 --- a/graphql_api/actions/measurements.py +++ b/graphql_api/actions/measurements.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Iterable, Mapping, Optional +from typing import Any, Dict, Iterable, List, Optional from django.db.models import Max, QuerySet @@ -13,19 +13,23 @@ def measurements_by_ids( measurable_name: str, measurable_ids: Iterable[str], interval: Interval, - after: datetime, before: datetime, + after: Optional[datetime] = None, branch: Optional[str] = None, -) -> Mapping[int, Iterable[dict]]: +) -> Dict[int, List[Dict[str, Any]]]: queryset = MeasurementSummary.agg_by(interval).filter( name=measurable_name, owner_id=repository.author_id, repo_id=repository.pk, measurable_id__in=measurable_ids, - timestamp_bin__gte=aligned_start_date(interval, after), timestamp_bin__lte=before, ) + if after is not None: + queryset = queryset.filter( + timestamp_bin__gte=aligned_start_date(interval, after) 
+ ) + if branch: queryset = queryset.filter(branch=branch) @@ -34,7 +38,7 @@ def measurements_by_ids( ) # group by measurable_id - measurements = {} + measurements: Dict[int, List[Dict[str, Any]]] = {} for measurement in queryset: measurable_id = measurement["measurable_id"] if measurable_id not in measurements: @@ -44,6 +48,30 @@ def measurements_by_ids( return measurements +def measurements_last_uploaded_before_start_date( + owner_id: int, + repo_id: int, + measurable_name: str, + measurable_id: int, + start_date: datetime, + branch: Optional[str] = None, +) -> QuerySet: + queryset = Measurement.objects.filter( + owner_id=owner_id, + repo_id=repo_id, + measurable_id=measurable_id, + name=measurable_name, + timestamp__lt=start_date, + ) + + if branch: + queryset = queryset.filter(branch=branch) + + return queryset.values("measurable_id", "value").annotate( + last_uploaded=Max("timestamp") + ) + + def measurements_last_uploaded_by_ids( owner_id: int, repo_id: int, diff --git a/graphql_api/actions/owner.py b/graphql_api/actions/owner.py index aced5cfd0d..5514393162 100644 --- a/graphql_api/actions/owner.py +++ b/graphql_api/actions/owner.py @@ -19,7 +19,11 @@ def get_owner(service, username): raise MissingService() long_service = get_long_service_name(service) - return Owner.objects.filter(username=username, service=long_service).first() + return ( + Owner.objects.filter(username=username, service=long_service) + .prefetch_related("account") + .first() + ) def get_owner_login_sessions(current_user): diff --git a/graphql_api/actions/path_contents.py b/graphql_api/actions/path_contents.py index f0501a590b..63338eeb9c 100644 --- a/graphql_api/actions/path_contents.py +++ b/graphql_api/actions/path_contents.py @@ -14,7 +14,7 @@ def partition_list_into_files_and_directories( # Separate files and directories for item in items: - if type(item) == Dir: + if isinstance(item, Dir): directories.append(item) else: files.append(item) diff --git 
a/graphql_api/actions/repository.py b/graphql_api/actions/repository.py index 464680908d..a81d94afd6 100644 --- a/graphql_api/actions/repository.py +++ b/graphql_api/actions/repository.py @@ -1,8 +1,14 @@ -from codecov_auth.models import Owner -from core.models import Repository +import logging +from typing import Any +from django.db.models import QuerySet +from shared.django_apps.codecov_auth.models import Owner +from shared.django_apps.core.models import Repository -def apply_filters_to_queryset(queryset, filters): +log = logging.getLogger(__name__) + + +def apply_filters_to_queryset(queryset, filters: dict[str, Any]) -> QuerySet: filters = filters or {} term = filters.get("term") active = filters.get("active") @@ -24,22 +30,40 @@ def apply_filters_to_queryset(queryset, filters): return queryset -def list_repository_for_owner(current_owner: Owner, owner: Owner, filters): +def list_repository_for_owner( + current_owner: Owner, + owner: Owner, + filters: dict[str, Any] | None, + okta_account_auths: list[int], + exclude_okta_enforced_repos: bool = True, +) -> QuerySet: + queryset = Repository.objects.viewable_repos(current_owner) + + if exclude_okta_enforced_repos: + queryset = queryset.exclude_accounts_enforced_okta(okta_account_auths) + queryset = ( - Repository.objects.viewable_repos(current_owner) - .with_recent_coverage() - .with_latest_commit_at() - .filter(author=owner) + queryset.with_recent_coverage().with_latest_commit_at().filter(author=owner) ) + queryset = apply_filters_to_queryset(queryset, filters) return queryset -def search_repos(current_owner, filters): +def search_repos( + current_owner: Owner, + filters: dict[str, Any] | None, + okta_account_auths: list[int], + exclude_okta_enforced_repos: bool = True, +) -> QuerySet: authors_from = [current_owner.ownerid] + (current_owner.organizations or []) + queryset = Repository.objects.viewable_repos(current_owner) + + if exclude_okta_enforced_repos: + queryset = 
queryset.exclude_accounts_enforced_okta(okta_account_auths) + queryset = ( - Repository.objects.viewable_repos(current_owner) - .with_recent_coverage() + queryset.with_recent_coverage() .with_latest_commit_at() .filter(author__ownerid__in=authors_from) ) diff --git a/graphql_api/dataloader/bundle_analysis.py b/graphql_api/dataloader/bundle_analysis.py index 5204234aa9..5349dd7850 100644 --- a/graphql_api/dataloader/bundle_analysis.py +++ b/graphql_api/dataloader/bundle_analysis.py @@ -1,3 +1,5 @@ +from typing import Union + from shared.bundle_analysis import ( BundleAnalysisReportLoader, MissingBaseReportError, @@ -14,7 +16,7 @@ def load_bundle_analysis_comparison( base_commit: Commit, head_commit: Commit -) -> BundleAnalysisComparison: +) -> Union[BundleAnalysisComparison, MissingHeadReport, MissingBaseReport]: head_report = CommitReport.objects.filter( report_type=CommitReport.ReportType.BUNDLE_ANALYSIS, commit=head_commit ).first() @@ -37,6 +39,7 @@ def load_bundle_analysis_comparison( loader=loader, base_report_key=base_report.external_id, head_report_key=head_report.external_id, + repository=head_commit.repository, ) except MissingBaseReportError: return MissingBaseReport() @@ -44,7 +47,9 @@ def load_bundle_analysis_comparison( return MissingHeadReport() -def load_bundle_analysis_report(commit: Commit) -> BundleAnalysisReport: +def load_bundle_analysis_report( + commit: Commit, +) -> Union[BundleAnalysisReport, MissingHeadReport, MissingBaseReport]: report = CommitReport.objects.filter( report_type=CommitReport.ReportType.BUNDLE_ANALYSIS, commit=commit ).first() diff --git a/graphql_api/dataloader/commit.py b/graphql_api/dataloader/commit.py index 335a1b1f31..3f957a1800 100644 --- a/graphql_api/dataloader/commit.py +++ b/graphql_api/dataloader/commit.py @@ -1,4 +1,4 @@ -from django.db.models import Prefetch, Q +from django.db.models import Prefetch from core.models import Commit from reports.models import CommitReport diff --git 
a/graphql_api/dataloader/tests/test_commit.py b/graphql_api/dataloader/tests/test_commit.py index 1951f2ee92..d474d75516 100644 --- a/graphql_api/dataloader/tests/test_commit.py +++ b/graphql_api/dataloader/tests/test_commit.py @@ -1,6 +1,4 @@ -import asyncio - -from django.test import TestCase, TransactionTestCase +from django.test import TransactionTestCase from core.tests.factories import CommitFactory, PullFactory, RepositoryFactory from graphql_api.dataloader.commit import CommitLoader diff --git a/graphql_api/dataloader/tests/test_loader.py b/graphql_api/dataloader/tests/test_loader.py index b558754ee1..13b9a741f6 100644 --- a/graphql_api/dataloader/tests/test_loader.py +++ b/graphql_api/dataloader/tests/test_loader.py @@ -19,7 +19,7 @@ def setUp(self): async def test_unimplemented_load(self): loader = BaseLoader.loader(self.info) - with pytest.raises(NotImplementedError) as err_info: + with pytest.raises(NotImplementedError): await loader.load(self.record.id) async def test_default_key(self): diff --git a/graphql_api/helpers/connection.py b/graphql_api/helpers/connection.py index cca15a5039..5b1aa98808 100644 --- a/graphql_api/helpers/connection.py +++ b/graphql_api/helpers/connection.py @@ -68,6 +68,49 @@ def page_info(self, *args, **kwargs): } +class DictCursorPaginator(CursorPaginator): + """ + WARNING: DictCursorPaginator does not work for dict objects where a key contains the following string: "__" + TODO: if instance is a dictionary and not an object, don't split the ordering + + ordering = "test__name" + Django object: + -> obj.test.name + + Dict: + -> obj["test"]["name"] X wrong + we want obj["test__name"] + + overrides CursorPaginator's position_from_instance method + because it assumes that instance's fields are attributes on the + instance. This doesn't work with the aggregate_test_results query + because since it uses annotate() and values() the instance is actually + a dict and the fields are keys in that dict. 
+ + So if getattr fails to find the attribute on the instance then we try getting the "attr" + via a dict access + + if the dict access fails then it throws an exception, although it would be a different + """ + + def position_from_instance(self, instance): + position = [] + for order in self.ordering: + parts = order.lstrip("-").split("__") + attr = instance + while parts: + try: + attr = getattr(attr, parts[0]) + except AttributeError as attr_err: + try: + attr = attr[parts[0]] + except (KeyError, TypeError): + raise attr_err from None + parts.pop(0) + position.append(str(attr)) + return position + + def queryset_to_connection_sync( queryset, *, @@ -85,8 +128,9 @@ def queryset_to_connection_sync( first = 25 ordering = tuple(field_order(field, ordering_direction) for field in ordering) - paginator = CursorPaginator(queryset, ordering=ordering) + paginator = DictCursorPaginator(queryset, ordering=ordering) page = paginator.page(first=first, after=after, last=last, before=before) + return Connection(queryset, paginator, page) diff --git a/graphql_api/helpers/lookahead.py b/graphql_api/helpers/lookahead.py index 2e5af27002..6de145bab9 100644 --- a/graphql_api/helpers/lookahead.py +++ b/graphql_api/helpers/lookahead.py @@ -6,9 +6,7 @@ SelectionSetNode, VariableNode, ) -from graphql.type import GraphQLInputType from graphql.type.definition import GraphQLResolveInfo -from graphql.utilities.value_from_ast import value_from_ast class LookaheadNode: @@ -56,7 +54,7 @@ def _flatten_selections(self, selection_set: SelectionSetNode) -> Iterable[Node] if isinstance(selection, FragmentSpreadNode): fragment = self.info.fragments[selection.name.value] for selection in fragment.selection_set.selections: - selections.append(selection) + selections.append(selection) # noqa: PERF402 else: selections.append(selection) return selections diff --git a/graphql_api/helpers/tests/test_mutation.py b/graphql_api/helpers/tests/test_mutation.py index 690e80e319..2b118f8469 100644 --- 
a/graphql_api/helpers/tests/test_mutation.py +++ b/graphql_api/helpers/tests/test_mutation.py @@ -71,4 +71,4 @@ def resolver(): raise AttributeError() with self.assertRaises(AttributeError): - resolved_value = await resolver() + await resolver() diff --git a/graphql_api/schema.py b/graphql_api/schema.py index 2fb7349c48..e37b22cb04 100644 --- a/graphql_api/schema.py +++ b/graphql_api/schema.py @@ -1,8 +1,7 @@ -from ariadne import make_executable_schema, snake_case_fallback_resolvers +from ariadne import make_executable_schema from .types import bindables, types -# snake_case_fallbck_resolvers gives use a default resolver which convert automatically -# the field name from camelCase to snake_case and try to get it from the object -# see https://ariadnegraphql.org/docs/resolvers#fallback-resolvers -schema = make_executable_schema(types, *bindables, snake_case_fallback_resolvers) +# convert_names_case automatically converts the field name from camelCase +# to snake_case. See: https://ariadnegraphql.org/docs/api-reference#optional-arguments-10 +schema = make_executable_schema(types, *bindables, convert_names_case=True) diff --git a/graphql_api/tests/actions/test_commits.py b/graphql_api/tests/actions/test_commits.py index 3bdac47e0d..81839b84bc 100644 --- a/graphql_api/tests/actions/test_commits.py +++ b/graphql_api/tests/actions/test_commits.py @@ -1,6 +1,5 @@ from collections import Counter -from asgiref.sync import async_to_sync from django.test import TransactionTestCase from codecov_auth.tests.factories import OwnerFactory diff --git a/graphql_api/tests/helper.py b/graphql_api/tests/helper.py index 1da177049e..0a7f9da271 100644 --- a/graphql_api/tests/helper.py +++ b/graphql_api/tests/helper.py @@ -1,5 +1,8 @@ -from unittest.mock import patch +from http.cookies import SimpleCookie +from shared.django_apps.codecov_auth.tests.factories import OwnerFactory, UserFactory + +from codecov_auth.views.okta_cloud import OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY from 
utils.test_utils import Client @@ -11,12 +14,27 @@ def gql_request( owner=None, variables=None, with_errors=False, + okta_signed_in_accounts=[], + impersonate_owner=False, ): url = f"/graphql/{provider}" if owner: self.client = Client() - self.client.force_login_owner(owner) + + if impersonate_owner: + staff_owner = OwnerFactory( + name="staff_user", service="github", user=UserFactory(is_staff=True) + ) + self.client.cookies = SimpleCookie({"staff_user": owner.pk}) + self.client.force_login_owner(staff_owner) + else: + self.client.force_login_owner(owner) + + if okta_signed_in_accounts: + session = self.client.session + session[OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY] = okta_signed_in_accounts + session.save() response = self.client.post( url, diff --git a/graphql_api/tests/mutation/test_encode_secret_string.py b/graphql_api/tests/mutation/test_encode_secret_string.py new file mode 100644 index 0000000000..c0812e13bf --- /dev/null +++ b/graphql_api/tests/mutation/test_encode_secret_string.py @@ -0,0 +1,44 @@ +from django.test import TransactionTestCase +from shared.encryption.yaml_secret import yaml_secret_encryptor + +from codecov_auth.tests.factories import OwnerFactory +from core.tests.factories import RepositoryFactory +from graphql_api.tests.helper import GraphQLTestHelper + +query = """ +mutation($input: EncodeSecretStringInput!) { + encodeSecretString(input: $input) { + value + error { + __typename + ... 
on ResolverError { + message + } + } + } +} +""" + + +class TestEncodeSecretString(TransactionTestCase, GraphQLTestHelper): + def _request(self): + data = self.gql_request( + query, + owner=self.org, + variables={"input": {"repoName": "test-repo", "value": "token-1"}}, + ) + return data["encodeSecretString"]["value"] + + def setUp(self): + self.org = OwnerFactory(username="test-org") + self.repo = RepositoryFactory( + name="test-repo", + author=self.org, + private=True, + ) + self.owner = OwnerFactory(permission=[self.repo.pk]) + + def test_encoded_secret_string(self): + res = self._request() + check_encryptor = yaml_secret_encryptor + assert "token-1" in check_encryptor.decode(res[7:]) diff --git a/graphql_api/tests/mutation/test_erase_repository.py b/graphql_api/tests/mutation/test_erase_repository.py index 078e87e33a..65889fcfe7 100644 --- a/graphql_api/tests/mutation/test_erase_repository.py +++ b/graphql_api/tests/mutation/test_erase_repository.py @@ -1,9 +1,9 @@ -from unittest.mock import PropertyMock, patch +from unittest.mock import patch from django.test import TransactionTestCase, override_settings from codecov_auth.tests.factories import OwnerFactory -from core.tests.factories import RepositoryFactory, RepositoryTokenFactory +from core.tests.factories import RepositoryFactory from graphql_api.tests.helper import GraphQLTestHelper query = """ diff --git a/graphql_api/tests/mutation/test_regenerate_repository_upload_token.py b/graphql_api/tests/mutation/test_regenerate_repository_upload_token.py index c5db324587..454cc4bb17 100644 --- a/graphql_api/tests/mutation/test_regenerate_repository_upload_token.py +++ b/graphql_api/tests/mutation/test_regenerate_repository_upload_token.py @@ -1,7 +1,7 @@ -from django.test import TransactionTestCase, override_settings +from django.test import TransactionTestCase from codecov_auth.tests.factories import OwnerFactory -from core.tests.factories import BranchFactory, RepositoryFactory +from core.tests.factories import 
RepositoryFactory from graphql_api.tests.helper import GraphQLTestHelper query = """ @@ -22,7 +22,7 @@ class RegenerateRepositoryUploadTokenTests(GraphQLTestHelper, TransactionTestCase): def setUp(self): self.org = OwnerFactory(username="codecov") - self.repo = RepositoryFactory(author=self.org, name="gazebo", active=True) + self.repo = RepositoryFactory(author=self.org, name="gazebo") self.old_repo_token = self.repo.upload_token def test_when_authenticated_updates_token(self): diff --git a/graphql_api/tests/mutation/test_revoke_user_token.py b/graphql_api/tests/mutation/test_revoke_user_token.py index 4f0eb5d2e4..3dce68565a 100644 --- a/graphql_api/tests/mutation/test_revoke_user_token.py +++ b/graphql_api/tests/mutation/test_revoke_user_token.py @@ -28,6 +28,6 @@ def test_authenticated(self): data = self.gql_request( query, owner=self.owner, variables={"input": {"tokenid": tokenid}} ) - assert data["revokeUserToken"] == None + assert data["revokeUserToken"] is None deleted_user_token = self.owner.user_tokens.filter(external_id=tokenid).first() assert deleted_user_token is None diff --git a/graphql_api/tests/mutation/test_save_okta_config.py b/graphql_api/tests/mutation/test_save_okta_config.py new file mode 100644 index 0000000000..97693991db --- /dev/null +++ b/graphql_api/tests/mutation/test_save_okta_config.py @@ -0,0 +1,59 @@ +from django.test import TransactionTestCase + +from codecov_auth.models import OktaSettings +from codecov_auth.tests.factories import AccountFactory, OwnerFactory +from graphql_api.tests.helper import GraphQLTestHelper + +query = """ +mutation($input: SaveOktaConfigInput!) 
{ + saveOktaConfig(input: $input) { + error { + __typename + } + } +} +""" + + +class SaveOktaConfigTestCase(GraphQLTestHelper, TransactionTestCase): + def setUp(self): + self.current_user = OwnerFactory(username="codecov-user") + self.owner = OwnerFactory( + username="codecov-owner", + admins=[self.current_user.ownerid], + account=AccountFactory(), + ) + + def test_when_unauthenticated(self): + data = self.gql_request( + query, + variables={ + "input": { + "clientId": "some-client-id", + "clientSecret": "some-client-secret", + "url": "https://okta.example.com", + "enabled": True, + "enforced": True, + "orgUsername": self.owner.username, + } + }, + ) + assert data["saveOktaConfig"]["error"]["__typename"] == "UnauthenticatedError" + + def test_when_authenticated(self): + data = self.gql_request( + query, + owner=self.owner, + variables={ + "input": { + "clientId": "some-client-id", + "clientSecret": "some-client-secret", + "url": "https://okta.example.com", + "enabled": True, + "enforced": True, + "orgUsername": self.owner.username, + } + }, + ) + assert OktaSettings.objects.filter(account=self.owner.account).exists() + assert data["saveOktaConfig"] is None diff --git a/graphql_api/tests/mutation/test_save_terms_agreement.py b/graphql_api/tests/mutation/test_save_terms_agreement.py index e1d3f753be..e237a1b3bc 100644 --- a/graphql_api/tests/mutation/test_save_terms_agreement.py +++ b/graphql_api/tests/mutation/test_save_terms_agreement.py @@ -1,7 +1,5 @@ -import pytest from django.test import TransactionTestCase -from codecov.commands.exceptions import ValidationError from codecov_auth.tests.factories import OwnerFactory from graphql_api.tests.helper import GraphQLTestHelper diff --git a/graphql_api/tests/mutation/test_set_yaml_on_owner.py b/graphql_api/tests/mutation/test_set_yaml_on_owner.py index facc2e4fcc..6ace86d5e5 100644 --- a/graphql_api/tests/mutation/test_set_yaml_on_owner.py +++ b/graphql_api/tests/mutation/test_set_yaml_on_owner.py @@ -1,10 +1,8 @@ 
import asyncio from unittest.mock import patch -import pytest from django.test import TransactionTestCase -from codecov.db import sync_to_async from codecov_auth.tests.factories import OwnerFactory from graphql_api.tests.helper import GraphQLTestHelper @@ -39,4 +37,4 @@ def test_mutation_dispatch_to_command(self, command_mock): } data = self.gql_request(query, owner=self.owner, variables={"input": input}) command_mock.assert_called_once_with(input["username"], input["yaml"]) - data["setYamlOnOwner"]["owner"]["username"] == self.owner.username + assert data["setYamlOnOwner"]["owner"]["username"] == self.owner.username diff --git a/graphql_api/tests/mutation/test_store_codecov_metrics.py b/graphql_api/tests/mutation/test_store_codecov_metrics.py new file mode 100644 index 0000000000..a49f41c4a6 --- /dev/null +++ b/graphql_api/tests/mutation/test_store_codecov_metrics.py @@ -0,0 +1,108 @@ +from django.test import TransactionTestCase +from shared.django_apps.codecov_metrics.models import UserOnboardingLifeCycleMetrics + +from codecov_auth.tests.factories import OwnerFactory +from graphql_api.tests.helper import GraphQLTestHelper + +query = """ + mutation($input: StoreEventMetricsInput!) { + storeEventMetric(input: $input) { + error { + __typename + ... 
on ResolverError { + message + } + } + } + } +""" + + +class StoreEventMetricMutationTest(GraphQLTestHelper, TransactionTestCase): + def _request(self, org_username: str, event: str, json_payload: str, owner=None): + return self.gql_request( + query, + variables={ + "input": { + "orgUsername": org_username, + "eventName": event, + "jsonPayload": json_payload, + } + }, + owner=owner, + ) + + def setUp(self): + self.owner = OwnerFactory(username="codecov-user") + + def test_unauthenticated(self): + response = self._request( + org_username="codecov-user", + event="VISITED_PAGE", + json_payload='{"key": "value"}', + ) + assert response == { + "storeEventMetric": { + "error": { + "__typename": "UnauthenticatedError", + "message": "You are not authenticated", + } + } + } + + def test_authenticated_inserts_into_db(self): + self._request( + org_username="codecov-user", + event="VISITED_PAGE", + json_payload='{"some-key": "some-value"}', + owner=self.owner, + ) + metric = UserOnboardingLifeCycleMetrics.objects.filter( + event="VISITED_PAGE" + ).first() + self.assertIsNotNone(metric) + self.assertEqual(metric.additional_data, {"some-key": "some-value"}) + + def test_invalid_org(self): + response = self._request( + org_username="invalid_org", + event="VISITED_PAGE", + json_payload='{"key": "value"}', + owner=self.owner, + ) + assert response == { + "storeEventMetric": { + "error": { + "__typename": "ValidationError", + "message": "Cannot find owner record in the database", + } + } + } + + def test_invalid_event(self): + self._request( + org_username="codecov-user", + event="INVALID_EVENT", + json_payload='{"key": "value"}', + owner=self.owner, + ) + metric = UserOnboardingLifeCycleMetrics.objects.filter( + event="INVALID_EVENT" + ).first() + self.assertIsNone(metric) + + def test_invalid_json_string(self): + response = self._request( + org_username="codecov-user", + event="VISITED_PAGE", + json_payload="invalid-json", + owner=self.owner, + ) + assert response == { + 
"storeEventMetric": { + "error": { + "__typename": "ValidationError", + "message": "Invalid JSON string", + } + } + } diff --git a/graphql_api/tests/mutation/test_update_profile.py b/graphql_api/tests/mutation/test_update_profile.py index 130751ce5b..ed2d942f1b 100644 --- a/graphql_api/tests/mutation/test_update_profile.py +++ b/graphql_api/tests/mutation/test_update_profile.py @@ -1,6 +1,5 @@ from django.test import TransactionTestCase -from codecov_auth.models import Session from codecov_auth.tests.factories import OwnerFactory from graphql_api.tests.helper import GraphQLTestHelper diff --git a/graphql_api/tests/mutation/test_update_repository.py b/graphql_api/tests/mutation/test_update_repository.py index 38766a9cee..81f97fae6a 100644 --- a/graphql_api/tests/mutation/test_update_repository.py +++ b/graphql_api/tests/mutation/test_update_repository.py @@ -1,4 +1,4 @@ -from django.test import TransactionTestCase, override_settings +from django.test import TransactionTestCase from codecov_auth.tests.factories import OwnerFactory from core.tests.factories import BranchFactory, RepositoryFactory @@ -53,9 +53,7 @@ def test_when_authenticated_update_activated(self): assert data == {"updateRepository": None} def test_when_authenticated_update_branch(self): - other_branch = BranchFactory.create( - name="some other branch", repository=self.repo - ) + BranchFactory.create(name="some other branch", repository=self.repo) data = self.gql_request( query, owner=self.org, diff --git a/graphql_api/tests/mutation/test_update_self_hosted_settings.py b/graphql_api/tests/mutation/test_update_self_hosted_settings.py index 065b2fe32a..34c7c59cb2 100644 --- a/graphql_api/tests/mutation/test_update_self_hosted_settings.py +++ b/graphql_api/tests/mutation/test_update_self_hosted_settings.py @@ -1,7 +1,5 @@ -import pytest from django.test import TransactionTestCase, override_settings -from codecov.commands.exceptions import ValidationError from codecov_auth.tests.factories import 
OwnerFactory from graphql_api.tests.helper import GraphQLTestHelper diff --git a/graphql_api/tests/test_account.py b/graphql_api/tests/test_account.py new file mode 100644 index 0000000000..27ad8ef16d --- /dev/null +++ b/graphql_api/tests/test_account.py @@ -0,0 +1,43 @@ +from django.test import TransactionTestCase + +from codecov_auth.tests.factories import ( + AccountFactory, + OktaSettingsFactory, + OwnerFactory, +) + +from .helper import GraphQLTestHelper + + +class AccountTestCase(GraphQLTestHelper, TransactionTestCase): + def setUp(self): + self.account = AccountFactory(name="Test Account") + self.owner = OwnerFactory( + username="randomOwner", service="github", account=self.account + ) + self.okta_settings = OktaSettingsFactory( + account=self.account, + client_id="test-client-id", + client_secret="test-client-secret", + ) + + def test_fetch_okta_config(self) -> None: + query = """ + query { + owner(username: "%s"){ + account { + oktaConfig { + clientId + clientSecret + } + } + } + } + """ % (self.owner.username) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + data = result["owner"]["account"] + assert data["oktaConfig"]["clientId"] == "test-client-id" + assert data["oktaConfig"]["clientSecret"] == "test-client-secret" diff --git a/graphql_api/tests/test_bundle_analysis_measurements.py b/graphql_api/tests/test_bundle_analysis_measurements.py index 8488635433..e12e0a63f1 100644 --- a/graphql_api/tests/test_bundle_analysis_measurements.py +++ b/graphql_api/tests/test_bundle_analysis_measurements.py @@ -840,3 +840,1585 @@ def test_bundle_asset_measurements(self, get_storage_service): }, }, } + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_report_measurements_carryovers(self, get_storage_service): + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + with open("./services/tests/samples/bundle_with_uuid.sqlite", "rb") as f: + 
storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=self.head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchMeasurements( + $org: String!, + $repo: String!, + $commit: String! + $filters: BundleAnalysisMeasurementsSetFilters + $orderingDirection: OrderingDirection! + $interval: MeasurementInterval! + $before: DateTime! + $after: DateTime! + ) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... on BundleAnalysisReport { + bundle(name: "super") { + name + measurements( + filters: $filters + orderingDirection: $orderingDirection + after: $after + interval: $interval + before: $before + ){ + assetType + name + size { + loadTime { + threeG + highSpeed + } + size { + gzip + uncompress + } + } + change { + loadTime { + threeG + highSpeed + } + size { + gzip + uncompress + } + } + measurements { + avg + min + max + timestamp + } + } + } + } + } + } + } + } + } + } + """ + + # Test without using asset type filters + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + "orderingDirection": "ASC", + "interval": "INTERVAL_1_DAY", + "after": "2024-06-07", + "before": "2024-06-10", + "filters": {}, + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": { + "measurements": [ + { + "assetType": "ASSET_SIZE", + "change": { + "loadTime": { + "highSpeed": 2, + "threeG": 106, + }, + "size": { + "gzip": 10, + "uncompress": 10000, + }, + }, + "measurements": [ + { + "avg": 4126.0, + "max": 4126.0, + "min": 4126.0, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + 
"avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 14126.0, + "max": 14126.0, + "min": 14126.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": "asset-*.js", + "size": { + "loadTime": { + "highSpeed": 3, + "threeG": 150, + }, + "size": { + "gzip": 14, + "uncompress": 14126, + }, + }, + }, + { + "assetType": "ASSET_SIZE", + "change": { + "loadTime": { + "highSpeed": 2, + "threeG": 106, + }, + "size": { + "gzip": 10, + "uncompress": 10000, + }, + }, + "measurements": [ + { + "avg": 1421.0, + "max": 1421.0, + "min": 1421.0, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 11421.0, + "max": 11421.0, + "min": 11421.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": "asset-*.js", + "size": { + "loadTime": { + "highSpeed": 3, + "threeG": 121, + }, + "size": { + "gzip": 11, + "uncompress": 11421, + }, + }, + }, + { + "assetType": "ASSET_SIZE", + "change": None, + "measurements": [], + "name": "asset-*.js", + "size": None, + }, + { + "assetType": "FONT_SIZE", + "change": { + "loadTime": { + "highSpeed": 0, + "threeG": 2, + }, + "size": { + "gzip": 0, + "uncompress": 240, + }, + }, + "measurements": [ + { + "avg": 50.0, + "max": 50.0, + "min": 50.0, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 290.0, + "max": 290.0, + "min": 290.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 0, + "threeG": 3, + }, + "size": { + "gzip": 0, + "uncompress": 290, + }, + }, + }, + { + "assetType": "IMAGE_SIZE", + "change": { + 
"loadTime": { + "highSpeed": 0, + "threeG": 25, + }, + "size": { + "gzip": 2, + "uncompress": 2400, + }, + }, + "measurements": [ + { + "avg": 500.0, + "max": 500.0, + "min": 500.0, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 2900.0, + "max": 2900.0, + "min": 2900.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 0, + "threeG": 30, + }, + "size": { + "gzip": 2, + "uncompress": 2900, + }, + }, + }, + { + "assetType": "JAVASCRIPT_SIZE", + "change": { + "loadTime": { + "highSpeed": 5, + "threeG": 224, + }, + "size": { + "gzip": 21, + "uncompress": 21000, + }, + }, + "measurements": [ + { + "avg": 5708.0, + "max": 5708.0, + "min": 5708.0, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 26708.0, + "max": 26708.0, + "min": 26708.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 7, + "threeG": 284, + }, + "size": { + "gzip": 26, + "uncompress": 26708, + }, + }, + }, + { + "assetType": "REPORT_SIZE", + "change": { + "loadTime": { + "highSpeed": 6, + "threeG": 252, + }, + "size": { + "gzip": 23, + "uncompress": 23664, + }, + }, + "measurements": [ + { + "avg": 6263.0, + "max": 6263.0, + "min": 6263.0, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 29927.0, + "max": 29927.0, + "min": 29927.0, + "timestamp": 
"2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 7, + "threeG": 319, + }, + "size": { + "gzip": 29, + "uncompress": 29927, + }, + }, + }, + { + "assetType": "STYLESHEET_SIZE", + "change": { + "loadTime": { + "highSpeed": 0, + "threeG": 0, + }, + "size": { + "gzip": 0, + "uncompress": 24, + }, + }, + "measurements": [ + { + "avg": 5.0, + "max": 5.0, + "min": 5.0, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 29.0, + "max": 29.0, + "min": 29.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 0, + "threeG": 0, + }, + "size": { + "gzip": 0, + "uncompress": 29, + }, + }, + }, + ], + "name": "super", + }, + } + + # Test with using asset type filters + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + "orderingDirection": "ASC", + "interval": "INTERVAL_1_DAY", + "after": "2024-06-07", + "before": "2024-06-10", + "filters": {"assetTypes": "JAVASCRIPT_SIZE"}, + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": { + "measurements": [ + { + "assetType": "JAVASCRIPT_SIZE", + "change": { + "loadTime": { + "highSpeed": 5, + "threeG": 224, + }, + "size": { + "gzip": 21, + "uncompress": 21000, + }, + }, + "measurements": [ + { + "avg": 5708.0, + "max": 5708.0, + "min": 5708.0, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 26708.0, + "max": 26708.0, + 
"min": 26708.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 7, + "threeG": 284, + }, + "size": { + "gzip": 26, + "uncompress": 26708, + }, + }, + }, + ], + "name": "super", + }, + } + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_report_no_carryovers(self, get_storage_service): + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + with open("./services/tests/samples/bundle_with_uuid.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=self.head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchMeasurements( + $org: String!, + $repo: String!, + $commit: String! + $filters: BundleAnalysisMeasurementsSetFilters + $orderingDirection: OrderingDirection! + $interval: MeasurementInterval! + $before: DateTime! + $after: DateTime! + ) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... 
on BundleAnalysisReport { + bundle(name: "super") { + name + measurements( + filters: $filters + orderingDirection: $orderingDirection + after: $after + interval: $interval + before: $before + ){ + assetType + name + size { + loadTime { + threeG + highSpeed + } + size { + gzip + uncompress + } + } + change { + loadTime { + threeG + highSpeed + } + size { + gzip + uncompress + } + } + measurements { + avg + min + max + timestamp + } + } + } + } + } + } + } + } + } + } + """ + + # Test without using asset type filters + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + "orderingDirection": "ASC", + "interval": "INTERVAL_1_DAY", + "after": "2024-06-05", + "before": "2024-06-10", + "filters": {}, + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": { + "measurements": [ + { + "assetType": "ASSET_SIZE", + "change": { + "loadTime": { + "highSpeed": 2, + "threeG": 106, + }, + "size": { + "gzip": 10, + "uncompress": 10000, + }, + }, + "measurements": [ + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-05T00:00:00+00:00", + }, + { + "avg": 4126.0, + "max": 4126.0, + "min": 4126.0, + "timestamp": "2024-06-06T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 14126.0, + "max": 14126.0, + "min": 14126.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": "asset-*.js", + "size": { + "loadTime": { + "highSpeed": 3, + "threeG": 150, + }, + "size": { + "gzip": 14, + "uncompress": 14126, + }, + }, + }, + { + "assetType": "ASSET_SIZE", + "change": { + "loadTime": { + 
"highSpeed": 2, + "threeG": 106, + }, + "size": { + "gzip": 10, + "uncompress": 10000, + }, + }, + "measurements": [ + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-05T00:00:00+00:00", + }, + { + "avg": 1421.0, + "max": 1421.0, + "min": 1421.0, + "timestamp": "2024-06-06T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 11421.0, + "max": 11421.0, + "min": 11421.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": "asset-*.js", + "size": { + "loadTime": { + "highSpeed": 3, + "threeG": 121, + }, + "size": { + "gzip": 11, + "uncompress": 11421, + }, + }, + }, + { + "assetType": "ASSET_SIZE", + "change": None, + "measurements": [], + "name": "asset-*.js", + "size": None, + }, + { + "assetType": "FONT_SIZE", + "change": { + "loadTime": { + "highSpeed": 0, + "threeG": 2, + }, + "size": { + "gzip": 0, + "uncompress": 240, + }, + }, + "measurements": [ + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-05T00:00:00+00:00", + }, + { + "avg": 50.0, + "max": 50.0, + "min": 50.0, + "timestamp": "2024-06-06T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 290.0, + "max": 290.0, + "min": 290.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 0, + "threeG": 3, + }, + "size": { + "gzip": 0, + "uncompress": 290, + }, + }, + }, + { + "assetType": "IMAGE_SIZE", + "change": { + "loadTime": { + "highSpeed": 0, + 
"threeG": 25, + }, + "size": { + "gzip": 2, + "uncompress": 2400, + }, + }, + "measurements": [ + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-05T00:00:00+00:00", + }, + { + "avg": 500.0, + "max": 500.0, + "min": 500.0, + "timestamp": "2024-06-06T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 2900.0, + "max": 2900.0, + "min": 2900.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 0, + "threeG": 30, + }, + "size": { + "gzip": 2, + "uncompress": 2900, + }, + }, + }, + { + "assetType": "JAVASCRIPT_SIZE", + "change": { + "loadTime": { + "highSpeed": 5, + "threeG": 224, + }, + "size": { + "gzip": 21, + "uncompress": 21000, + }, + }, + "measurements": [ + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-05T00:00:00+00:00", + }, + { + "avg": 5708.0, + "max": 5708.0, + "min": 5708.0, + "timestamp": "2024-06-06T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 26708.0, + "max": 26708.0, + "min": 26708.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 7, + "threeG": 284, + }, + "size": { + "gzip": 26, + "uncompress": 26708, + }, + }, + }, + { + "assetType": "REPORT_SIZE", + "change": { + "loadTime": { + "highSpeed": 6, + "threeG": 252, + }, + "size": { + "gzip": 23, + "uncompress": 23664, + }, + }, + "measurements": [ + { + "avg": None, + "max": 
None, + "min": None, + "timestamp": "2024-06-05T00:00:00+00:00", + }, + { + "avg": 6263.0, + "max": 6263.0, + "min": 6263.0, + "timestamp": "2024-06-06T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 29927.0, + "max": 29927.0, + "min": 29927.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 7, + "threeG": 319, + }, + "size": { + "gzip": 29, + "uncompress": 29927, + }, + }, + }, + { + "assetType": "STYLESHEET_SIZE", + "change": { + "loadTime": { + "highSpeed": 0, + "threeG": 0, + }, + "size": { + "gzip": 0, + "uncompress": 24, + }, + }, + "measurements": [ + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-05T00:00:00+00:00", + }, + { + "avg": 5.0, + "max": 5.0, + "min": 5.0, + "timestamp": "2024-06-06T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 29.0, + "max": 29.0, + "min": 29.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 0, + "threeG": 0, + }, + "size": { + "gzip": 0, + "uncompress": 29, + }, + }, + }, + ], + "name": "super", + }, + } + + # Test with using asset type filters + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + "orderingDirection": "ASC", + "interval": "INTERVAL_1_DAY", + "after": "2024-06-05", + "before": "2024-06-10", + "filters": {"assetTypes": "JAVASCRIPT_SIZE"}, + } + data = 
self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": { + "measurements": [ + { + "assetType": "JAVASCRIPT_SIZE", + "change": { + "loadTime": { + "highSpeed": 5, + "threeG": 224, + }, + "size": { + "gzip": 21, + "uncompress": 21000, + }, + }, + "measurements": [ + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-05T00:00:00+00:00", + }, + { + "avg": 5708.0, + "max": 5708.0, + "min": 5708.0, + "timestamp": "2024-06-06T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "max": None, + "min": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 26708.0, + "max": 26708.0, + "min": 26708.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + "name": None, + "size": { + "loadTime": { + "highSpeed": 7, + "threeG": 284, + }, + "size": { + "gzip": 26, + "uncompress": 26708, + }, + }, + }, + ], + "name": "super", + }, + } + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_report_branch(self, get_storage_service): + measurements_data = [ + # 2024-06-10 + ["bundle_analysis_report_size", "super", "2024-06-10T19:07:23", 123], + # 2024-06-06 + ["bundle_analysis_report_size", "super", "2024-06-06T19:07:23", 456], + ] + + for item in measurements_data: + MeasurementFactory( + name=item[0], + owner_id=self.org.pk, + repo_id=self.repo.pk, + branch="feat", + measurable_id=item[1], + commit_sha=self.commit.pk, + timestamp=item[2], + value=item[3], + ) + + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + with open("./services/tests/samples/bundle_with_uuid.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + 
repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=self.head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchMeasurements( + $org: String!, + $repo: String!, + $commit: String! + $filters: BundleAnalysisMeasurementsSetFilters + $orderingDirection: OrderingDirection! + $interval: MeasurementInterval! + $before: DateTime! + $after: DateTime! + $branch: String! + ) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... on BundleAnalysisReport { + bundle(name: "super") { + name + measurements( + filters: $filters + orderingDirection: $orderingDirection + after: $after + interval: $interval + before: $before + branch: $branch + ){ + assetType + name + size { + loadTime { + threeG + highSpeed + } + size { + gzip + uncompress + } + } + change { + loadTime { + threeG + highSpeed + } + size { + gzip + uncompress + } + } + measurements { + avg + min + max + timestamp + } + } + } + } + } + } + } + } + } + } + """ + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + "orderingDirection": "ASC", + "interval": "INTERVAL_1_DAY", + "after": "2024-06-07", + "before": "2024-06-10", + "branch": "feat", + "filters": {}, + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": { + "name": "super", + "measurements": [ + { + "assetType": "ASSET_SIZE", + "name": "asset-*.js", + "size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "ASSET_SIZE", + "name": "asset-*.js", + "size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "ASSET_SIZE", + "name": "asset-*.js", + "size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "FONT_SIZE", + "name": None, + 
"size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "IMAGE_SIZE", + "name": None, + "size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "JAVASCRIPT_SIZE", + "name": None, + "size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "REPORT_SIZE", + "name": None, + "size": { + "loadTime": {"threeG": 1, "highSpeed": 0}, + "size": {"gzip": 0, "uncompress": 123}, + }, + "change": { + "loadTime": {"threeG": -3, "highSpeed": 0}, + "size": {"gzip": 0, "uncompress": -333}, + }, + "measurements": [ + { + "avg": 456.0, + "min": 456.0, + "max": 456.0, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "min": None, + "max": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "min": None, + "max": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 123.0, + "min": 123.0, + "max": 123.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + }, + { + "assetType": "STYLESHEET_SIZE", + "name": None, + "size": None, + "change": None, + "measurements": [], + }, + ], + }, + } + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_report_no_after(self, get_storage_service): + measurements_data = [ + # 2024-06-10 + ["bundle_analysis_report_size", "super", "2024-06-10T19:07:23", 123], + # 2024-06-06 + ["bundle_analysis_report_size", "super", "2024-06-06T19:07:23", 456], + ] + + for item in measurements_data: + MeasurementFactory( + name=item[0], + owner_id=self.org.pk, + repo_id=self.repo.pk, + branch="feat", + measurable_id=item[1], + commit_sha=self.commit.pk, + timestamp=item[2], + value=item[3], + ) + + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + with open("./services/tests/samples/bundle_with_uuid.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + 
report_key=self.head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchMeasurements( + $org: String!, + $repo: String!, + $commit: String! + $filters: BundleAnalysisMeasurementsSetFilters + $orderingDirection: OrderingDirection! + $interval: MeasurementInterval! + $before: DateTime! + $after: DateTime + $branch: String! + ) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... on BundleAnalysisReport { + bundle(name: "super") { + name + measurements( + filters: $filters + orderingDirection: $orderingDirection + after: $after + interval: $interval + before: $before + branch: $branch + ){ + assetType + name + size { + loadTime { + threeG + highSpeed + } + size { + gzip + uncompress + } + } + change { + loadTime { + threeG + highSpeed + } + size { + gzip + uncompress + } + } + measurements { + avg + min + max + timestamp + } + } + } + } + } + } + } + } + } + } + """ + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + "orderingDirection": "ASC", + "interval": "INTERVAL_1_DAY", + "after": None, + "before": "2024-06-10", + "branch": "feat", + "filters": {}, + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": { + "name": "super", + "measurements": [ + { + "assetType": "ASSET_SIZE", + "name": "asset-*.js", + "size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "ASSET_SIZE", + "name": "asset-*.js", + "size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "ASSET_SIZE", + "name": "asset-*.js", + "size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "FONT_SIZE", + "name": None, + "size": None, + "change": None, + "measurements": [], + }, + { + 
"assetType": "IMAGE_SIZE", + "name": None, + "size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "JAVASCRIPT_SIZE", + "name": None, + "size": None, + "change": None, + "measurements": [], + }, + { + "assetType": "REPORT_SIZE", + "name": None, + "size": { + "loadTime": {"threeG": 1, "highSpeed": 0}, + "size": {"gzip": 0, "uncompress": 123}, + }, + "change": { + "loadTime": {"threeG": -3, "highSpeed": 0}, + "size": {"gzip": 0, "uncompress": -333}, + }, + "measurements": [ + { + "avg": 456.0, + "min": 456.0, + "max": 456.0, + "timestamp": "2024-06-06T00:00:00+00:00", + }, + { + "avg": None, + "min": None, + "max": None, + "timestamp": "2024-06-07T00:00:00+00:00", + }, + { + "avg": None, + "min": None, + "max": None, + "timestamp": "2024-06-08T00:00:00+00:00", + }, + { + "avg": None, + "min": None, + "max": None, + "timestamp": "2024-06-09T00:00:00+00:00", + }, + { + "avg": 123.0, + "min": 123.0, + "max": 123.0, + "timestamp": "2024-06-10T00:00:00+00:00", + }, + ], + }, + { + "assetType": "STYLESHEET_SIZE", + "name": None, + "size": None, + "change": None, + "measurements": [], + }, + ], + }, + } diff --git a/graphql_api/tests/test_commit.py b/graphql_api/tests/test_commit.py index c6f48f5c4a..443d3070f2 100644 --- a/graphql_api/tests/test_commit.py +++ b/graphql_api/tests/test_commit.py @@ -11,7 +11,6 @@ from shared.reports.types import LineSession from shared.storage.memory import MemoryStorageService -import services.comparison as comparison from codecov_auth.tests.factories import OwnerFactory from compare.models import CommitComparison from compare.tests.factories import CommitComparisonFactory @@ -89,7 +88,7 @@ def __init__(self): class MockReport(object): def get(self, file, _else): - lines = MockLines() + MockLines() return MockLines() def filter(self, **kwargs): @@ -161,23 +160,22 @@ def test_fetch_commits(self): self.repo_2 = RepositoryFactory( author=self.org, name="test-repo", private=False ) - commits_in_db = [ - CommitFactory( - 
repository=self.repo_2, - commitid=123, - timestamp=datetime.today() - timedelta(days=3), - ), - CommitFactory( - repository=self.repo_2, - commitid=456, - timestamp=datetime.today() - timedelta(days=1), - ), - CommitFactory( - repository=self.repo_2, - commitid=789, - timestamp=datetime.today() - timedelta(days=2), - ), - ] + + CommitFactory( + repository=self.repo_2, + commitid=123, + timestamp=datetime.today() - timedelta(days=3), + ) + CommitFactory( + repository=self.repo_2, + commitid=456, + timestamp=datetime.today() - timedelta(days=1), + ) + CommitFactory( + repository=self.repo_2, + commitid=789, + timestamp=datetime.today() - timedelta(days=2), + ) variables = {"org": self.org.username, "repo": self.repo_2.name} data = self.gql_request(query, variables=variables) @@ -208,7 +206,7 @@ def test_resolve_commit_without_parent(self): } data = self.gql_request(query, variables=variables) commit = data["owner"]["repository"]["commit"] - assert commit["parent"] == None + assert commit["parent"] is None def test_fetch_commit_coverage(self): ReportLevelTotalsFactory(report=self.report, coverage=12) @@ -240,18 +238,19 @@ def test_fetch_commit_build(self): ] def test_fetch_commit_uploads_state(self): - session_one = UploadFactory( + UploadFactory( report=self.report, provider="circleci", state=UploadState.PROCESSED.value ) - session_two = UploadFactory( + UploadFactory( report=self.report, provider="travisci", state=UploadState.ERROR.value ) - session_three = UploadFactory( + UploadFactory( report=self.report, provider="travisci", state=UploadState.COMPLETE.value ) - session_four = UploadFactory( + UploadFactory( report=self.report, provider="travisci", state=UploadState.UPLOADED.value ) + UploadFactory(report=self.report, provider="travisci", state="") query = ( query_commit % """ @@ -278,6 +277,7 @@ def test_fetch_commit_uploads_state(self): {"state": UploadState.ERROR.name}, {"state": UploadState.COMPLETE.name}, {"state": UploadState.UPLOADED.name}, + {"state": 
UploadState.ERROR.name}, ] def test_fetch_commit_uploads(self): @@ -320,7 +320,7 @@ def test_fetch_commit_uploads(self): order_number=4, ) UploadFlagMembershipFactory(report_session=session_e, flag=flag_d) - session_f = UploadFactory( + UploadFactory( report=self.report, upload_type=UploadType.UPLOADED.value, order_number=5, @@ -400,10 +400,10 @@ def test_fetch_commit_uploads_errors(self): session = UploadFactory( report=self.report, provider="circleci", state=UploadState.ERROR.value ) - error_one = UploadErrorFactory( + UploadErrorFactory( report_session=session, error_code=UploadErrorEnum.REPORT_EXPIRED.value ) - error_two = UploadErrorFactory( + UploadErrorFactory( report_session=session, error_code=UploadErrorEnum.FILE_NOT_IN_STORAGE.value ) @@ -459,7 +459,7 @@ def test_yaml_return_default_state_if_default(self): assert data["owner"]["repository"]["commit"]["yamlState"] == "DEFAULT" def test_fetch_commit_ci(self): - session_one = UploadFactory( + UploadFactory( report=self.report, provider="circleci", job_code=123, @@ -707,7 +707,7 @@ def test_fetch_commit_compare_no_parent(self): query_commit % """ compareWithParent { __typename ... on Comparison { state } } - bundleAnalysisCompareWithParent { __typename ... on BundleAnalysisComparison { sizeDelta } } + bundleAnalysisCompareWithParent { __typename ... on BundleAnalysisComparison { bundleData { size { uncompress } } } } """ ) variables = { @@ -818,17 +818,9 @@ def test_bundle_analysis_compare(self, get_storage_service): bundleAnalysisCompareWithParent { __typename ... 
on BundleAnalysisComparison { - sizeDelta - sizeTotal - loadTimeDelta - loadTimeTotal bundles { name changeType - sizeDelta - sizeTotal - loadTimeDelta - loadTimeTotal bundleData { size { uncompress @@ -864,58 +856,34 @@ def test_bundle_analysis_compare(self, get_storage_service): commit = data["owner"]["repository"]["commit"] assert commit["bundleAnalysisCompareWithParent"] == { "__typename": "BundleAnalysisComparison", - "sizeDelta": 36555, - "sizeTotal": 201720, - "loadTimeDelta": 0.1, - "loadTimeTotal": 0.5, "bundles": [ { "name": "b1", "changeType": "changed", - "sizeDelta": 5, - "sizeTotal": 20, - "loadTimeDelta": 0.0, - "loadTimeTotal": 0.0, "bundleData": {"size": {"uncompress": 20}}, "bundleChange": {"size": {"uncompress": 5}}, }, { "name": "b2", "changeType": "changed", - "sizeDelta": 50, - "sizeTotal": 200, - "loadTimeDelta": 0.0, - "loadTimeTotal": 0.0, "bundleData": {"size": {"uncompress": 200}}, "bundleChange": {"size": {"uncompress": 50}}, }, { "name": "b3", "changeType": "added", - "sizeDelta": 1500, - "sizeTotal": 1500, - "loadTimeDelta": 0.0, - "loadTimeTotal": 0.0, "bundleData": {"size": {"uncompress": 1500}}, "bundleChange": {"size": {"uncompress": 1500}}, }, { "name": "b5", "changeType": "changed", - "sizeDelta": 50000, - "sizeTotal": 200000, - "loadTimeDelta": 0.1, - "loadTimeTotal": 0.5, "bundleData": {"size": {"uncompress": 200000}}, "bundleChange": {"size": {"uncompress": 50000}}, }, { "name": "b4", "changeType": "removed", - "sizeDelta": -15000, - "sizeTotal": 0, - "loadTimeDelta": -0.0, - "loadTimeTotal": 0.0, "bundleData": {"size": {"uncompress": 0}}, "bundleChange": {"size": {"uncompress": -15000}}, }, @@ -924,6 +892,104 @@ def test_bundle_analysis_compare(self, get_storage_service): "bundleChange": {"size": {"uncompress": 36555}}, } + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_analysis_compare_with_compare_sha(self, get_storage_service): + """ + This tests creates 3 commits C1 -> 
C2 -> C3 + C1 uses Report1, C2 and C3 uses Report2 + Normally when doing a compare of C3, it would select C2 as its parent + then it would show no change, as expected + However the difference is that in C3's Report2 it has the compareSha set to C1.commitid + Now when doing comparison of C3, it would now select C1 as the parent + therefore show correct comparison in numbers between Report1 and Report2 + """ + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + commit_1 = CommitFactory( + repository=self.repo, + commitid="6ca727b0142bf5625bb82af2555d308862063222", + ) + commit_2 = CommitFactory( + repository=self.repo, parent_commit_id=commit_1.commitid + ) + commit_3 = CommitFactory( + repository=self.repo, parent_commit_id=commit_2.commitid + ) + + commit_report_1 = CommitReportFactory( + commit=commit_1, + report_type=CommitReport.ReportType.BUNDLE_ANALYSIS, + ) + + commit_report_2 = CommitReportFactory( + commit=commit_2, + report_type=CommitReport.ReportType.BUNDLE_ANALYSIS, + ) + + commit_report_3 = CommitReportFactory( + commit=commit_3, + report_type=CommitReport.ReportType.BUNDLE_ANALYSIS, + ) + + with open("./services/tests/samples/base_bundle_report.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=commit_report_1.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + with open("./services/tests/samples/head_bundle_report.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=commit_report_2.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + with open( + "./services/tests/samples/head_bundle_report_with_compare_sha_6ca727b0142bf5625bb82af2555d308862063222.sqlite", + "rb", + ) as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + 
report_key=commit_report_3.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = ( + query_commit + % """ + bundleAnalysisCompareWithParent { + __typename + ... on BundleAnalysisComparison { + bundleData { + size { + uncompress + } + } + bundleChange { + size { + uncompress + } + } + } + } + """ + ) + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": commit_report_3.commit.commitid, + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + assert commit["bundleAnalysisCompareWithParent"] == { + "__typename": "BundleAnalysisComparison", + "bundleData": {"size": {"uncompress": 201720}}, + "bundleChange": {"size": {"uncompress": 36555}}, + } + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") def test_bundle_analysis_sqlite_file_deleted(self, get_storage_service): os.system("rm -rf /tmp/bundle_analysis_*") @@ -958,7 +1024,11 @@ def test_bundle_analysis_sqlite_file_deleted(self, get_storage_service): bundleAnalysisCompareWithParent { __typename ... on BundleAnalysisComparison { - sizeTotal + bundleData { + size { + uncompress + } + } } } """ @@ -1013,7 +1083,11 @@ def test_bundle_analysis_sqlite_file_not_deleted( bundleAnalysisCompareWithParent { __typename ... on BundleAnalysisComparison { - sizeTotal + bundleData { + size { + uncompress + } + } } } """ @@ -1076,7 +1150,7 @@ def test_bundle_analysis_report(self, get_storage_service): storage.write_file(get_bucket_name(), storage_path, f) query = """ - query FetchCommit($org: String!, $repo: String!, $commit: String!, $filters: BundleAnalysisReportFilters) { + query FetchCommit($org: String!, $repo: String!, $commit: String!) { owner(username: $org) { repository(name: $repo) { ... on Repository { @@ -1084,13 +1158,9 @@ def test_bundle_analysis_report(self, get_storage_service): bundleAnalysisReport { __typename ... 
on BundleAnalysisReport { - sizeTotal - loadTimeTotal bundles { name - sizeTotal - loadTimeTotal - assets(filters: $filters) { + assets { normalizedName } asset(name: "not_exist") { @@ -1106,6 +1176,7 @@ def test_bundle_analysis_report(self, get_storage_service): uncompress } } + isCached } bundleData { loadTime { @@ -1119,7 +1190,9 @@ def test_bundle_analysis_report(self, get_storage_service): } bundle(name: "not_exist") { name + isCached } + isCached } ... on MissingHeadReport { message @@ -1136,26 +1209,21 @@ def test_bundle_analysis_report(self, get_storage_service): "org": self.org.username, "repo": self.repo.name, "commit": self.commit.commitid, - "filters": {"moduleExtensions": []}, } data = self.gql_request(query, variables=variables) commit = data["owner"]["repository"]["commit"] assert commit["bundleAnalysisReport"] == { "__typename": "BundleAnalysisReport", - "sizeTotal": 201720, - "loadTimeTotal": 0.5, "bundles": [ { "name": "b1", - "sizeTotal": 20, - "loadTimeTotal": 0.0, "assets": [ - {"normalizedName": "assets/react-*.svg"}, - {"normalizedName": "assets/index-*.css"}, - {"normalizedName": "assets/LazyComponent-*.js"}, {"normalizedName": "assets/index-*.js"}, {"normalizedName": "assets/index-*.js"}, + {"normalizedName": "assets/LazyComponent-*.js"}, + {"normalizedName": "assets/index-*.css"}, + {"normalizedName": "assets/react-*.svg"}, ], "asset": None, "bundleData": { @@ -1168,17 +1236,16 @@ def test_bundle_analysis_report(self, get_storage_service): "uncompress": 20, }, }, + "isCached": False, }, { "name": "b2", - "sizeTotal": 200, - "loadTimeTotal": 0.0, "assets": [ - {"normalizedName": "assets/react-*.svg"}, - {"normalizedName": "assets/index-*.css"}, - {"normalizedName": "assets/LazyComponent-*.js"}, {"normalizedName": "assets/index-*.js"}, {"normalizedName": "assets/index-*.js"}, + {"normalizedName": "assets/LazyComponent-*.js"}, + {"normalizedName": "assets/index-*.css"}, + {"normalizedName": "assets/react-*.svg"}, ], "asset": None, 
"bundleData": { @@ -1191,17 +1258,16 @@ def test_bundle_analysis_report(self, get_storage_service): "uncompress": 200, }, }, + "isCached": False, }, { "name": "b3", - "sizeTotal": 1500, - "loadTimeTotal": 0.0, "assets": [ - {"normalizedName": "assets/react-*.svg"}, - {"normalizedName": "assets/index-*.css"}, - {"normalizedName": "assets/LazyComponent-*.js"}, {"normalizedName": "assets/index-*.js"}, {"normalizedName": "assets/index-*.js"}, + {"normalizedName": "assets/LazyComponent-*.js"}, + {"normalizedName": "assets/index-*.css"}, + {"normalizedName": "assets/react-*.svg"}, ], "asset": None, "bundleData": { @@ -1210,21 +1276,20 @@ def test_bundle_analysis_report(self, get_storage_service): "highSpeed": 0, }, "size": { - "gzip": 1, + "gzip": 0, "uncompress": 1500, }, }, + "isCached": False, }, { "name": "b5", - "sizeTotal": 200000, - "loadTimeTotal": 0.5, "assets": [ - {"normalizedName": "assets/react-*.svg"}, - {"normalizedName": "assets/index-*.css"}, - {"normalizedName": "assets/LazyComponent-*.js"}, {"normalizedName": "assets/index-*.js"}, {"normalizedName": "assets/index-*.js"}, + {"normalizedName": "assets/LazyComponent-*.js"}, + {"normalizedName": "assets/index-*.css"}, + {"normalizedName": "assets/react-*.svg"}, ], "asset": None, "bundleData": { @@ -1237,6 +1302,7 @@ def test_bundle_analysis_report(self, get_storage_service): "uncompress": 200000, }, }, + "isCached": False, }, ], "bundleData": { @@ -1250,10 +1316,13 @@ def test_bundle_analysis_report(self, get_storage_service): }, }, "bundle": None, + "isCached": False, } @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") - def test_bundle_analysis_asset(self, get_storage_service): + def test_bundle_analysis_report_assets_paginated_first_after( + self, get_storage_service + ): storage = MemoryStorageService({}) get_storage_service.return_value = storage @@ -1261,9 +1330,7 @@ def test_bundle_analysis_asset(self, get_storage_service): commit=self.commit, 
report_type=CommitReport.ReportType.BUNDLE_ANALYSIS ) - with open( - "./services/tests/samples/bundle_with_assets_and_modules.sqlite", "rb" - ) as f: + with open("./services/tests/samples/head_bundle_report.sqlite", "rb") as f: storage_path = StoragePaths.bundle_report.path( repo_key=ArchiveService.get_archive_hash(self.repo), report_key=head_commit_report.external_id, @@ -1271,7 +1338,13 @@ def test_bundle_analysis_asset(self, get_storage_service): storage.write_file(get_bucket_name(), storage_path, f) query = """ - query FetchCommit($org: String!, $repo: String!, $commit: String!) { + query FetchCommit( + $org: String!, + $repo: String!, + $commit: String!, + $ordering: AssetOrdering, + $orderingDirection: OrderingDirection + ) { owner(username: $org) { repository(name: $repo) { ... on Repository { @@ -1279,36 +1352,25 @@ def test_bundle_analysis_asset(self, get_storage_service): bundleAnalysisReport { __typename ... on BundleAnalysisReport { - bundle(name: "b5") { - moduleExtensions - moduleCount - asset(name: "assets/LazyComponent-fcbb0922.js") { - name - normalizedName - extension - moduleExtensions - bundleData { - loadTime { - threeG - highSpeed - } - size { - gzip - uncompress + bundle(name: "b1") { + assetsPaginated ( + ordering: $ordering, + orderingDirection: $orderingDirection, + first: 2, + after: "5", + ){ + totalCount + edges { + cursor + node { + normalizedName } } - modules { - name - bundleData { - loadTime { - threeG - highSpeed - } - size { - gzip - uncompress - } - } + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor } } } @@ -1325,113 +1387,770 @@ def test_bundle_analysis_asset(self, get_storage_service): "org": self.org.username, "repo": self.repo.name, "commit": self.commit.commitid, + "ordering": "NAME", + "orderingDirection": "ASC", } data = self.gql_request(query, variables=variables) commit = data["owner"]["repository"]["commit"] - bundle_report = commit["bundleAnalysisReport"]["bundle"] - asset_report = 
bundle_report["asset"] - - assert bundle_report is not None - assert sorted(bundle_report["moduleExtensions"]) == [ - "", - "css", - "html", - "js", - "svg", - "ts", - "tsx", - ] - assert bundle_report["moduleCount"] == 7 - - assert asset_report is not None - assert asset_report["name"] == "assets/LazyComponent-fcbb0922.js" - assert asset_report["normalizedName"] == "assets/LazyComponent-*.js" - assert asset_report["extension"] == "js" - assert set(asset_report["moduleExtensions"]) == set(["", "tsx"]) - assert asset_report["bundleData"] == { - "loadTime": { - "threeG": 320, - "highSpeed": 8, - }, - "size": { - "gzip": 30, - "uncompress": 30000, + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": { + "assetsPaginated": { + "totalCount": 5, + "edges": [ + { + "cursor": "4", + "node": { + "normalizedName": "assets/index-*.js", + }, + }, + { + "cursor": "2", + "node": { + "normalizedName": "assets/index-*.css", + }, + }, + ], + "pageInfo": { + "hasNextPage": True, + "hasPreviousPage": False, + "startCursor": "4", + "endCursor": "2", + }, + } }, } - modules = sorted(asset_report["modules"], key=lambda m: m["name"]) - - assert modules and len(modules) == 3 - assert modules[0] == { - "name": "./src/LazyComponent/LazyComponent", - "bundleData": { - "loadTime": { - "threeG": 64, - "highSpeed": 1, - }, - "size": { - "gzip": 6, - "uncompress": 6000, - }, - }, - } - assert modules[1] == { - "name": "./src/LazyComponent/LazyComponent.tsx", - "bundleData": { - "loadTime": { - "threeG": 53, - "highSpeed": 1, - }, - "size": { - "gzip": 5, - "uncompress": 5000, - }, - }, - } - assert modules[2] == { - "name": "./src/LazyComponent/LazyComponent.tsx?module", - "bundleData": { - "loadTime": { - "threeG": 53, - "highSpeed": 1, - }, - "size": { - "gzip": 4, - "uncompress": 4970, - }, - }, - } + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def 
test_bundle_analysis_report_assets_paginated_first_after_non_existing_cursor( + self, get_storage_service + ): + storage = MemoryStorageService({}) + get_storage_service.return_value = storage - def test_compare_with_parent_missing_change_coverage(self): - CommitComparisonFactory( - base_commit=self.parent_commit, - compare_commit=self.commit, - state=CommitComparison.CommitComparisonStates.PROCESSED, - ) - ReportLevelTotalsFactory( - report=CommitReportFactory(commit=self.parent_commit), - coverage=75.0, - files=0, - lines=0, - hits=0, - misses=0, - partials=0, - branches=0, - methods=0, + head_commit_report = CommitReportFactory( + commit=self.commit, report_type=CommitReport.ReportType.BUNDLE_ANALYSIS ) - query = ( - query_commit % "compareWithParent { ... on Comparison { changeCoverage } }" - ) - variables = { - "org": self.org.username, - "repo": self.repo.name, - "commit": self.commit.commitid, - } - data = self.gql_request(query, variables=variables) - commit = data["owner"]["repository"]["commit"] - assert commit["compareWithParent"]["changeCoverage"] == None + with open("./services/tests/samples/head_bundle_report.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchCommit( + $org: String!, + $repo: String!, + $commit: String!, + $ordering: AssetOrdering, + $orderingDirection: OrderingDirection + ) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... 
on BundleAnalysisReport { + bundle(name: "b1") { + assetsPaginated ( + ordering: $ordering, + orderingDirection: $orderingDirection, + first: 2, + after: "notanumber", + ){ + totalCount + edges { + cursor + node { + normalizedName + } + } + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + } + } + } + } + } + } + } + } + } + """ + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + "ordering": "NAME", + "orderingDirection": "ASC", + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": { + "assetsPaginated": { + "totalCount": 5, + "edges": [ + { + "cursor": "3", + "node": { + "normalizedName": "assets/LazyComponent-*.js", + }, + }, + { + "cursor": "5", + "node": { + "normalizedName": "assets/index-*.js", + }, + }, + ], + "pageInfo": { + "hasNextPage": True, + "hasPreviousPage": False, + "startCursor": "3", + "endCursor": "5", + }, + } + }, + } + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_analysis_report_assets_paginated_last_before( + self, get_storage_service + ): + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + head_commit_report = CommitReportFactory( + commit=self.commit, report_type=CommitReport.ReportType.BUNDLE_ANALYSIS + ) + + with open("./services/tests/samples/head_bundle_report.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchCommit( + $org: String!, + $repo: String!, + $commit: String!, + $ordering: AssetOrdering, + $orderingDirection: OrderingDirection + ) { + owner(username: $org) { + repository(name: $repo) { + ... 
on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... on BundleAnalysisReport { + bundle(name: "b1") { + assetsPaginated ( + ordering: $ordering, + orderingDirection: $orderingDirection, + last: 2, + before: "1", + ){ + totalCount + edges { + cursor + node { + normalizedName + } + } + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + } + } + } + } + } + } + } + } + } + """ + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + "ordering": "NAME", + "orderingDirection": "ASC", + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": { + "assetsPaginated": { + "totalCount": 5, + "edges": [ + { + "cursor": "4", + "node": { + "normalizedName": "assets/index-*.js", + }, + }, + { + "cursor": "2", + "node": { + "normalizedName": "assets/index-*.css", + }, + }, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": True, + "startCursor": "4", + "endCursor": "2", + }, + } + }, + } + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_analysis_report_assets_paginated_last_before_non_existing_cursor( + self, get_storage_service + ): + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + head_commit_report = CommitReportFactory( + commit=self.commit, report_type=CommitReport.ReportType.BUNDLE_ANALYSIS + ) + + with open("./services/tests/samples/head_bundle_report.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchCommit( + $org: String!, + $repo: String!, + $commit: String!, + $ordering: AssetOrdering, + $orderingDirection: 
OrderingDirection + ) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... on BundleAnalysisReport { + bundle(name: "b1") { + assetsPaginated ( + ordering: $ordering, + orderingDirection: $orderingDirection, + last: 2, + before: "99999", + ){ + totalCount + edges { + cursor + node { + normalizedName + } + } + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + } + } + } + } + } + } + } + } + } + """ + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + "ordering": "NAME", + "orderingDirection": "ASC", + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": { + "assetsPaginated": { + "totalCount": 5, + "edges": [ + { + "cursor": "2", + "node": { + "normalizedName": "assets/index-*.css", + }, + }, + { + "cursor": "1", + "node": { + "normalizedName": "assets/react-*.svg", + }, + }, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": True, + "startCursor": "2", + "endCursor": "1", + }, + } + }, + } + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_analysis_report_assets_paginated_before_and_after_error( + self, get_storage_service + ): + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + head_commit_report = CommitReportFactory( + commit=self.commit, report_type=CommitReport.ReportType.BUNDLE_ANALYSIS + ) + + with open("./services/tests/samples/head_bundle_report.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchCommit( + $org: String!, + $repo: 
String!, + $commit: String!, + $ordering: AssetOrdering, + $orderingDirection: OrderingDirection + ) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... on BundleAnalysisReport { + bundle(name: "b1") { + assetsPaginated ( + ordering: $ordering, + orderingDirection: $orderingDirection, + before: "1", + after: "2", + ){ + totalCount + } + } + } + } + } + } + } + } + } + """ + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + } + data = self.gql_request(query, with_errors=True, variables=variables) + commit = data["data"]["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": {"assetsPaginated": None}, + } + + assert ( + data["errors"][0]["message"] + == "After and before can not be used at the same time" + ) + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_analysis_report_assets_paginated_first_and_last_error( + self, get_storage_service + ): + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + head_commit_report = CommitReportFactory( + commit=self.commit, report_type=CommitReport.ReportType.BUNDLE_ANALYSIS + ) + + with open("./services/tests/samples/head_bundle_report.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchCommit( + $org: String!, + $repo: String!, + $commit: String!, + $ordering: AssetOrdering, + $orderingDirection: OrderingDirection + ) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... 
on BundleAnalysisReport { + bundle(name: "b1") { + assetsPaginated ( + ordering: $ordering, + orderingDirection: $orderingDirection, + first: 1, + last: 2, + ){ + totalCount + } + } + } + } + } + } + } + } + } + """ + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + } + data = self.gql_request(query, with_errors=True, variables=variables) + commit = data["data"]["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundle": {"assetsPaginated": None}, + } + + assert ( + data["errors"][0]["message"] + == "First and last can not be used at the same time" + ) + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_analysis_asset(self, get_storage_service): + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + head_commit_report = CommitReportFactory( + commit=self.commit, report_type=CommitReport.ReportType.BUNDLE_ANALYSIS + ) + + with open( + "./services/tests/samples/bundle_with_assets_and_modules.sqlite", "rb" + ) as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchCommit($org: String!, $repo: String!, $commit: String!) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... 
on BundleAnalysisReport { + bundle(name: "b5") { + moduleCount + asset(name: "assets/LazyComponent-fcbb0922.js") { + name + normalizedName + extension + bundleData { + loadTime { + threeG + highSpeed + } + size { + gzip + uncompress + } + } + modules { + name + bundleData { + loadTime { + threeG + highSpeed + } + size { + gzip + uncompress + } + } + } + } + } + } + } + } + } + } + } + } + """ + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + bundle_report = commit["bundleAnalysisReport"]["bundle"] + asset_report = bundle_report["asset"] + + assert bundle_report is not None + assert bundle_report["moduleCount"] == 33 + + assert asset_report is not None + assert asset_report["name"] == "assets/LazyComponent-fcbb0922.js" + assert asset_report["normalizedName"] == "assets/LazyComponent-*.js" + assert asset_report["extension"] == "js" + assert asset_report["bundleData"] == { + "loadTime": { + "threeG": 320, + "highSpeed": 8, + }, + "size": { + "gzip": 30, + "uncompress": 30000, + }, + } + + modules = sorted(asset_report["modules"], key=lambda m: m["name"]) + + assert modules and len(modules) == 3 + assert modules[0] == { + "name": "./src/LazyComponent/LazyComponent", + "bundleData": { + "loadTime": { + "threeG": 64, + "highSpeed": 1, + }, + "size": { + "gzip": 6, + "uncompress": 6000, + }, + }, + } + assert modules[1] == { + "name": "./src/LazyComponent/LazyComponent.tsx", + "bundleData": { + "loadTime": { + "threeG": 53, + "highSpeed": 1, + }, + "size": { + "gzip": 5, + "uncompress": 5000, + }, + }, + } + assert modules[2] == { + "name": "./src/LazyComponent/LazyComponent.tsx?module", + "bundleData": { + "loadTime": { + "threeG": 53, + "highSpeed": 1, + }, + "size": { + "gzip": 4, + "uncompress": 4970, + }, + }, + } + + @patch("shared.bundle_analysis.BundleReport.asset_reports") + 
@patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_analysis_asset_filtering( + self, get_storage_service, asset_reports_mock + ): + storage = MemoryStorageService({}) + + get_storage_service.return_value = storage + asset_reports_mock.return_value = [] + + head_commit_report = CommitReportFactory( + commit=self.commit, report_type=CommitReport.ReportType.BUNDLE_ANALYSIS + ) + + with open( + "./services/tests/samples/bundle_with_assets_and_modules.sqlite", "rb" + ) as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchCommit($org: String!, $repo: String!, $commit: String!, $filters: BundleAnalysisReportFilters) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... 
on BundleAnalysisReport { + bundle(name: "b5", filters: $filters) { + moduleCount + assets { + name + } + } + } + } + } + } + } + } + } + """ + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + "filters": {}, + } + + configurations = [ + # No filters + ( + {"loadTypes": None, "reportGroups": None}, + {"asset_types": None, "chunk_entry": None, "chunk_initial": None}, + ), + ({}, {"asset_types": None, "chunk_entry": None, "chunk_initial": None}), + # Just report groups + ( + {"reportGroups": ["JAVASCRIPT", "FONT"]}, + { + "asset_types": ["JAVASCRIPT", "FONT"], + "chunk_entry": None, + "chunk_initial": None, + }, + ), + # Load types -> chunk_entry cancels out + ( + {"loadTypes": ["ENTRY", "INITIAL"]}, + {"asset_types": None, "chunk_entry": None, "chunk_initial": True}, + ), + # Load types -> chunk_entry = True + ( + {"loadTypes": ["ENTRY"]}, + {"asset_types": None, "chunk_entry": True, "chunk_initial": None}, + ), + # Load types -> chunk_lazy = False + ( + {"loadTypes": ["LAZY"]}, + {"asset_types": None, "chunk_entry": False, "chunk_initial": False}, + ), + # Load types -> chunk_initial cancels out + ( + {"loadTypes": ["LAZY", "INITIAL"]}, + {"asset_types": None, "chunk_entry": False, "chunk_initial": None}, + ), + # Load types -> chunk_initial = True + ( + {"loadTypes": ["INITIAL"]}, + {"asset_types": None, "chunk_entry": False, "chunk_initial": True}, + ), + # Load types -> chunk_initial = False + ( + {"loadTypes": ["LAZY"]}, + {"asset_types": None, "chunk_entry": False, "chunk_initial": False}, + ), + ] + + for config in configurations: + input_d, output_d = config + variables["filters"] = input_d + data = self.gql_request(query, variables=variables) + assert ( + data["owner"]["repository"]["commit"]["bundleAnalysisReport"]["bundle"] + is not None + ) + asset_reports_mock.assert_called_with(**output_d) + + def test_compare_with_parent_missing_change_coverage(self): + CommitComparisonFactory( + 
base_commit=self.parent_commit, + compare_commit=self.commit, + state=CommitComparison.CommitComparisonStates.PROCESSED, + ) + ReportLevelTotalsFactory( + report=CommitReportFactory(commit=self.parent_commit), + coverage=75.0, + files=0, + lines=0, + hits=0, + misses=0, + partials=0, + branches=0, + methods=0, + ) + + query = ( + query_commit % "compareWithParent { ... on Comparison { changeCoverage } }" + ) + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + assert commit["compareWithParent"]["changeCoverage"] is None @patch( "services.profiling.ProfilingSummary.critical_files", new_callable=PropertyMock @@ -1639,8 +2358,8 @@ def test_fetch_commit_status_no_reports(self): } data = self.gql_request(query, variables=variables) commit = data["owner"]["repository"]["commit"] - assert commit["coverageStatus"] == None - assert commit["bundleStatus"] == None + assert commit["coverageStatus"] is None + assert commit["bundleStatus"] is None def test_fetch_commit_status_no_sessions(self): CommitReportFactory( @@ -1663,8 +2382,8 @@ def test_fetch_commit_status_no_sessions(self): } data = self.gql_request(query, variables=variables) commit = data["owner"]["repository"]["commit"] - assert commit["coverageStatus"] == None - assert commit["bundleStatus"] == None + assert commit["coverageStatus"] is None + assert commit["bundleStatus"] is None def test_fetch_commit_status_completed(self): coverage_report = CommitReportFactory( @@ -1767,3 +2486,102 @@ def test_fetch_commit_status_pending(self): commit = data["owner"]["repository"]["commit"] assert commit["coverageStatus"] == CommitStatus.PENDING.value assert commit["bundleStatus"] == CommitStatus.PENDING.value + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_bundle_analysis_report_gzip_size_total(self, get_storage_service): + 
storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + head_commit_report = CommitReportFactory( + commit=self.commit, report_type=CommitReport.ReportType.BUNDLE_ANALYSIS + ) + + with open( + "./services/tests/samples/head_bundle_report_with_gzip_size.sqlite", "rb" + ) as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repo), + report_key=head_commit_report.external_id, + ) + storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + query FetchCommit($org: String!, $repo: String!, $commit: String!) { + owner(username: $org) { + repository(name: $repo) { + ... on Repository { + commit(id: $commit) { + bundleAnalysisReport { + __typename + ... on BundleAnalysisReport { + bundles { + name + bundleData { + size { + gzip + uncompress + } + } + } + } + } + } + } + } + } + } + """ + + variables = { + "org": self.org.username, + "repo": self.repo.name, + "commit": self.commit.commitid, + } + data = self.gql_request(query, variables=variables) + commit = data["owner"]["repository"]["commit"] + + assert commit["bundleAnalysisReport"] == { + "__typename": "BundleAnalysisReport", + "bundles": [ + { + # All assets non compressible + "name": "b1", + "bundleData": { + "size": { + "gzip": 20, + "uncompress": 20, + }, + }, + }, + { + # Some assets non compressible + "name": "b2", + "bundleData": { + "size": { + "gzip": 198, + "uncompress": 200, + }, + }, + }, + { + # All assets non compressible + "name": "b3", + "bundleData": { + "size": { + "gzip": 1495, + "uncompress": 1500, + }, + }, + }, + { + # All assets non compressible + "name": "b5", + "bundleData": { + "size": { + "gzip": 199995, + "uncompress": 200000, + }, + }, + }, + ], + } diff --git a/graphql_api/tests/test_components.py b/graphql_api/tests/test_components.py index 6c3782d0d7..45791c5b17 100644 --- a/graphql_api/tests/test_components.py +++ b/graphql_api/tests/test_components.py @@ -5,7 +5,6 @@ from django.test import 
TransactionTestCase, override_settings from django.utils import timezone from shared.reports.resources import Report, ReportFile, ReportLine -from shared.reports.types import ReportTotals from shared.utils.sessions import Session from codecov_auth.tests.factories import OwnerFactory @@ -112,7 +111,9 @@ def test_no_components(self): "repo": self.repo.name, "sha": self.commit.commitid, } - data = self.gql_request(query_commit_components, variables=variables) + data = self.gql_request( + query_commit_components, variables=variables, owner=OwnerFactory() + ) assert data == { "owner": { "repository": { @@ -148,7 +149,11 @@ def test_components(self, commit_components_mock, full_report_mock): "repo": self.repo.name, "sha": self.commit.commitid, } - data = self.gql_request(query_commit_components, variables=variables) + data = self.gql_request( + query_commit_components, + variables=variables, + owner=OwnerFactory(), + ) assert data == { "owner": { "repository": { @@ -229,7 +234,9 @@ def test_components_filtering(self, commit_components_mock, full_report_mock): "sha": self.commit.commitid, "filter": {"components": ["Python"]}, } - data = self.gql_request(query_commit_components, variables=variables) + data = self.gql_request( + query_commit_components, variables=variables, owner=OwnerFactory() + ) assert data == { "owner": { "repository": { @@ -289,7 +296,9 @@ def test_components_filtering(self, commit_components_mock, full_report_mock): "sha": self.commit.commitid, "filter": {"components": ["C", "Golang"]}, } - data = self.gql_request(query_commit_components, variables=variables) + data = self.gql_request( + query_commit_components, variables=variables, owner=OwnerFactory() + ) assert data == { "owner": { "repository": { @@ -349,7 +358,9 @@ def test_components_filtering_case_insensitive( "sha": self.commit.commitid, "filter": {"components": ["pYtHoN"]}, } - data = self.gql_request(query_commit_components, variables=variables) + data = self.gql_request( + 
query_commit_components, variables=variables, owner=OwnerFactory() + ) assert data == { "owner": { "repository": { diff --git a/graphql_api/tests/test_config.py b/graphql_api/tests/test_config.py index f0ca6cbf57..59378271bf 100644 --- a/graphql_api/tests/test_config.py +++ b/graphql_api/tests/test_config.py @@ -272,7 +272,7 @@ def test_self_hosted_license_returns_null_if_invalid_license(self, license_mock) ) assert data == { "config": { - "selfHostedLicense": {"expirationDate": None}, + "selfHostedLicense": None, }, } diff --git a/graphql_api/tests/test_current_user_ariadne.py b/graphql_api/tests/test_current_user_ariadne.py index ff832e90ce..5fc93b49c9 100644 --- a/graphql_api/tests/test_current_user_ariadne.py +++ b/graphql_api/tests/test_current_user_ariadne.py @@ -3,8 +3,12 @@ from unittest.mock import patch from django.test import TransactionTestCase +from shared.django_apps.codecov_auth.tests.factories import ( + AccountFactory, + OktaSettingsFactory, +) -from codecov_auth.models import Owner, OwnerProfile +from codecov_auth.models import OwnerProfile from codecov_auth.tests.factories import OwnerFactory, UserFactory from core.tests.factories import CommitFactory, RepositoryFactory @@ -183,7 +187,7 @@ def test_fetch_terms_agreement_and_business_email_when_owner_profile_is_null(sel assert data == {"me": {"businessEmail": None, "termsAgreement": False}} def test_fetch_null_terms_agreement_for_user_without_owner(self): - # There is an edge where a owner without user can call the me endpoint + # There is an edge where an owner without user can call the "me" endpoint # via impersonation, in that case return null for terms agreement owner_to_impersonate = OwnerFactory() owner_to_impersonate.user.delete() @@ -206,7 +210,23 @@ def test_fetch_null_terms_agreement_for_user_without_owner(self): def test_fetching_viewable_repositories(self): org_1 = OwnerFactory() org_2 = OwnerFactory() - current_user = OwnerFactory(organizations=[org_1.ownerid]) + + authed_account = 
AccountFactory() + OktaSettingsFactory(account=authed_account, enforced=True) + okta_enforced_authenticated = OwnerFactory(account=authed_account) + + unauthed_account = AccountFactory() + okta_enforced_org_unauth = OwnerFactory(account=unauthed_account) + OktaSettingsFactory(account=unauthed_account, enforced=True) + + current_user = OwnerFactory( + organizations=[ + org_1.ownerid, + okta_enforced_authenticated.ownerid, + okta_enforced_org_unauth.ownerid, + ] + ) + repos_in_db = [ RepositoryFactory(private=True, name="0"), RepositoryFactory(author=org_1, private=False, name="1"), @@ -216,8 +236,22 @@ def test_fetching_viewable_repositories(self): RepositoryFactory(private=True, name="5"), RepositoryFactory(author=current_user, private=True, name="6"), RepositoryFactory(author=current_user, private=False, name="7"), + RepositoryFactory( + author=okta_enforced_authenticated, + private=True, + name="okta_enforced_repo_authed", + ), + RepositoryFactory( + author=okta_enforced_org_unauth, + private=True, + name="okta_enforced_repo_unauthed", + ), + ] + current_user.permission = [ + repos_in_db[2].repoid, + repos_in_db[8].repoid, + repos_in_db[9].repoid, ] - current_user.permission = [repos_in_db[2].repoid] current_user.save() query = """{ me { @@ -231,15 +265,37 @@ def test_fetching_viewable_repositories(self): } } """ - data = self.gql_request(query, owner=current_user) + data = self.gql_request( + query, owner=current_user, okta_signed_in_accounts=[authed_account.id] + ) repos = paginate_connection(data["me"]["viewableRepositories"]) repos_name = [repo["name"] for repo in repos] - assert sorted(repos_name) == [ - "1", # public repo in org of user - "2", # private repo in org of user and in user permission - "6", # personal private repo - "7", # personal public repo - ] + assert ( + sorted(repos_name) + == [ + "1", # public repo in org of user + "2", # private repo in org of user and in user permission + "6", # personal private repo + "7", # personal public repo + 
"okta_enforced_repo_authed", # private repo in org with Okta Enforced permissions + ] + ) + + # Test with impersonation + data = self.gql_request(query, owner=current_user, impersonate_owner=True) + repos = paginate_connection(data["me"]["viewableRepositories"]) + repos_name = [repo["name"] for repo in repos] + assert ( + sorted(repos_name) + == [ + "1", # public repo in org of user + "2", # private repo in org of user and in user permission + "6", # personal private repo + "7", # personal public repo + "okta_enforced_repo_authed", # Okta repo should show up for impersonated users + "okta_enforced_repo_unauthed", # Okta repo should show up for impersonated users + ] + ) def test_fetching_viewable_repositories_ordering(self): current_user = OwnerFactory() @@ -259,7 +315,7 @@ def test_fetching_viewable_repositories_ordering(self): repo_1 = RepositoryFactory(author=current_user, name="A") repo_2 = RepositoryFactory(author=current_user, name="B") - repo_3 = RepositoryFactory(author=current_user, name="C") + RepositoryFactory(author=current_user, name="C") with self.subTest("No ordering (defaults to order by repoid)"): with self.subTest("no ordering Direction"): diff --git a/graphql_api/tests/test_flags.py b/graphql_api/tests/test_flags.py index 0a39226350..50229d90ed 100644 --- a/graphql_api/tests/test_flags.py +++ b/graphql_api/tests/test_flags.py @@ -1,16 +1,14 @@ -from datetime import datetime from unittest.mock import patch import pytest from django.conf import settings from django.test import TransactionTestCase, override_settings from django.utils import timezone -from freezegun import freeze_time from codecov_auth.tests.factories import OwnerFactory from core.tests.factories import CommitFactory, RepositoryFactory from reports.tests.factories import RepositoryFlagFactory -from timeseries.models import Dataset, MeasurementName +from timeseries.models import MeasurementName from timeseries.tests.factories import DatasetFactory, MeasurementFactory from .helper 
import GraphQLTestHelper @@ -102,11 +100,9 @@ def setUp(self): self.commit = CommitFactory(repository=self.repo) def test_fetch_flags_no_measurements(self): - flag1 = RepositoryFlagFactory(repository=self.repo, flag_name="flag1") - flag2 = RepositoryFlagFactory(repository=self.repo, flag_name="flag2") - flag3 = RepositoryFlagFactory( - repository=self.repo, flag_name="flag3", deleted=True - ) + RepositoryFlagFactory(repository=self.repo, flag_name="flag1") + RepositoryFlagFactory(repository=self.repo, flag_name="flag2") + RepositoryFlagFactory(repository=self.repo, flag_name="flag3", deleted=True) variables = { "org": self.org.username, "repo": self.repo.name, @@ -145,11 +141,9 @@ def test_fetch_flags_no_measurements(self): @override_settings(TIMESERIES_ENABLED=False) def test_fetch_flags_timeseries_not_enabled(self): - flag1 = RepositoryFlagFactory(repository=self.repo, flag_name="flag1") - flag2 = RepositoryFlagFactory(repository=self.repo, flag_name="flag2") - flag3 = RepositoryFlagFactory( - repository=self.repo, flag_name="flag3", deleted=True - ) + RepositoryFlagFactory(repository=self.repo, flag_name="flag1") + RepositoryFlagFactory(repository=self.repo, flag_name="flag2") + RepositoryFlagFactory(repository=self.repo, flag_name="flag3", deleted=True) variables = { "org": self.org.username, "repo": self.repo.name, @@ -189,9 +183,7 @@ def test_fetch_flags_timeseries_not_enabled(self): def test_fetch_flags_with_measurements(self): flag1 = RepositoryFlagFactory(repository=self.repo, flag_name="flag1") flag2 = RepositoryFlagFactory(repository=self.repo, flag_name="flag2") - flag3 = RepositoryFlagFactory( - repository=self.repo, flag_name="flag3", deleted=True - ) + RepositoryFlagFactory(repository=self.repo, flag_name="flag3", deleted=True) MeasurementFactory( name="flag_coverage", owner_id=self.org.pk, diff --git a/graphql_api/tests/test_impacted_file_deprecated.py b/graphql_api/tests/test_impacted_file_deprecated.py deleted file mode 100644 index 
8629044368..0000000000 --- a/graphql_api/tests/test_impacted_file_deprecated.py +++ /dev/null @@ -1,715 +0,0 @@ -import hashlib -from dataclasses import dataclass -from unittest.mock import PropertyMock, patch - -from django.test import TransactionTestCase -from shared.torngit.exceptions import ( - TorngitClientGeneralError, - TorngitObjectNotFoundError, -) - -from codecov_auth.tests.factories import OwnerFactory -from compare.models import CommitComparison -from compare.tests.factories import CommitComparisonFactory -from core.tests.factories import CommitFactory, PullFactory, RepositoryFactory -from services.comparison import ComparisonReport, ImpactedFile, MissingComparisonReport - -from .helper import GraphQLTestHelper - -query_impacted_files = """ -query ImpactedFiles( - $org: String! - $repo: String! - $commit: String! -) { - owner(username: $org) { - repository(name: $repo) { - ... on Repository { - commit(id: $commit) { - compareWithParent { - ... on Comparison { - impactedFilesCount - indirectChangedFilesCount - impactedFilesDeprecated { - fileName - headName - baseName - isNewFile - isRenamedFile - isDeletedFile - isCriticalFile - baseCoverage { - percentCovered - } - headCoverage { - percentCovered - } - patchCoverage { - percentCovered - } - changeCoverage - missesCount - } - } - } - } - } - } - } - } -""" - -query_direct_changed_files_count = """ -query ImpactedFiles( - $org: String! - $repo: String! - $commit: String! -) { - owner(username: $org) { - repository(name: $repo) { - ... on Repository { - commit(id: $commit) { - compareWithParent { - ... on Comparison { - directChangedFilesCount - } - } - } - } - } - } -} -""" - -query_impacted_file_through_pull = """ -query ImpactedFile( - $org: String! - $repo: String! - $pull: Int! - $path: String! - $filters: SegmentsFilters -) { - owner(username: $org) { - repository(name: $repo) { - ... on Repository { - pull(id: $pull) { - compareWithBase { - ... 
on Comparison { - state - impactedFile(path: $path) { - headName - baseName - hashedPath - baseCoverage { - percentCovered - } - headCoverage { - percentCovered - } - patchCoverage { - percentCovered - } - segments(filters: $filters) { - ... on SegmentComparisons { - results { - hasUnintendedChanges - } - } - ... on ResolverError { - message - } - } - } - } - } - } - } - } - } -} -""" -mock_data_from_archive = """ -{ - "files": [{ - "head_name": "fileA", - "base_name": "fileA", - "head_coverage": { - "hits": 12, - "misses": 1, - "partials": 1, - "branches": 3, - "sessions": 0, - "complexity": 0, - "complexity_total": 0, - "methods": 5 - }, - "base_coverage": { - "hits": 5, - "misses": 6, - "partials": 1, - "branches": 2, - "sessions": 0, - "complexity": 0, - "complexity_total": 0, - "methods": 4 - }, - "added_diff_coverage": [ - [9,"h"], - [10,"m"] - ], - "unexpected_line_changes": [] - }, - { - "head_name": "fileB", - "base_name": "fileB", - "head_coverage": { - "hits": 12, - "misses": 1, - "partials": 1, - "branches": 3, - "sessions": 0, - "complexity": 0, - "complexity_total": 0, - "methods": 5 - }, - "base_coverage": { - "hits": 5, - "misses": 6, - "partials": 1, - "branches": 2, - "sessions": 0, - "complexity": 0, - "complexity_total": 0, - "methods": 4 - }, - "added_diff_coverage": [ - [9,"h"], - [10,"h"], - [13,"h"], - [14,"h"], - [15,"h"], - [16,"m"], - [17,"h"] - ], - "unexpected_line_changes": [[[1, "h"], [1, "m"]]] - }] -} -""" - - -@dataclass -class MockSegment: - has_diff_changes: bool = False - has_unintended_changes: bool = False - - -class MockFileComparison(object): - def __init__(self): - self.segments = [ - MockSegment(has_unintended_changes=True, has_diff_changes=False), - MockSegment(has_unintended_changes=False, has_diff_changes=True), - MockSegment(has_unintended_changes=True, has_diff_changes=True), - ] - - -class TestImpactedFileDeprecated(GraphQLTestHelper, TransactionTestCase): - def setUp(self): - self.org = 
OwnerFactory(username="codecov") - self.repo = RepositoryFactory(author=self.org, name="gazebo", private=False) - self.author = OwnerFactory() - self.parent_commit = CommitFactory(repository=self.repo) - self.commit = CommitFactory( - repository=self.repo, - totals={"c": "12", "diff": [0, 0, 0, 0, 0, "14"]}, - parent_commit_id=self.parent_commit.commitid, - ) - self.pull = PullFactory( - pullid=44, - repository=self.commit.repository, - head=self.commit.commitid, - base=self.parent_commit.commitid, - compared_to=self.parent_commit.commitid, - ) - self.comparison = CommitComparisonFactory( - base_commit=self.parent_commit, - compare_commit=self.commit, - state=CommitComparison.CommitComparisonStates.PROCESSED, - report_storage_path="v4/test.json", - ) - self.comparison_report = ComparisonReport(self.comparison) - - # mock reports for all tests in this class - self.head_report_patcher = patch( - "services.comparison.Comparison.head_report", new_callable=PropertyMock - ) - self.head_report = self.head_report_patcher.start() - self.head_report.return_value = None - self.addCleanup(self.head_report_patcher.stop) - self.base_report_patcher = patch( - "services.comparison.Comparison.base_report", new_callable=PropertyMock - ) - self.base_report = self.base_report_patcher.start() - self.base_report.return_value = None - self.addCleanup(self.base_report_patcher.stop) - - @patch("services.archive.ArchiveService.read_file") - def test_fetch_impacted_files(self, read_file): - read_file.return_value = mock_data_from_archive - variables = { - "org": self.org.username, - "repo": self.repo.name, - "commit": self.commit.commitid, - } - data = self.gql_request(query_impacted_files, variables=variables) - assert data == { - "owner": { - "repository": { - "commit": { - "compareWithParent": { - "impactedFilesCount": 2, - "indirectChangedFilesCount": 1, - "impactedFilesDeprecated": [ - { - "fileName": "fileA", - "headName": "fileA", - "baseName": "fileA", - "isNewFile": False, - 
"isRenamedFile": False, - "isDeletedFile": False, - "isCriticalFile": False, - "baseCoverage": { - "percentCovered": 41.666666666666664 - }, - "headCoverage": { - "percentCovered": 85.71428571428571 - }, - "patchCoverage": {"percentCovered": 50.0}, - "changeCoverage": 44.047619047619044, - "missesCount": 1, - }, - { - "fileName": "fileB", - "headName": "fileB", - "baseName": "fileB", - "isNewFile": False, - "isRenamedFile": False, - "isDeletedFile": False, - "isCriticalFile": False, - "baseCoverage": { - "percentCovered": 41.666666666666664 - }, - "headCoverage": { - "percentCovered": 85.71428571428571 - }, - "patchCoverage": { - "percentCovered": 85.71428571428571 - }, - "changeCoverage": 44.047619047619044, - "missesCount": 2, - }, - ], - } - } - } - } - } - - @patch("services.task.TaskService.compute_comparisons") - @patch("services.comparison.ComparisonReport.impacted_file") - @patch("services.comparison.Comparison.validate") - @patch("services.comparison.PullRequestComparison.get_file_comparison") - @patch("services.archive.ArchiveService.read_file") - def test_fetch_impacted_file_segments_without_comparison_in_context( - self, - read_file, - mock_get_file_comparison, - mock_compare_validate, - mock_impacted_file, - _, - ): - read_file.return_value = mock_data_from_archive - mock_get_file_comparison.return_value = MockFileComparison() - mock_compare_validate.return_value = True - mock_impacted_file.return_value = ImpactedFile( - **{ - "head_name": "fileB", - "base_name": "fileB", - "head_coverage": { - "hits": 12, - "misses": 1, - "partials": 1, - "branches": 3, - "sessions": 0, - "complexity": 0, - "complexity_total": 0, - "methods": 5, - }, - "base_coverage": { - "hits": 5, - "misses": 6, - "partials": 1, - "branches": 2, - "sessions": 0, - "complexity": 0, - "complexity_total": 0, - "methods": 4, - }, - "added_diff_coverage": [ - [9, "h"], - [10, "m"], - [13, "p"], - [14, "h"], - [15, "h"], - [16, "h"], - [17, "h"], - ], - "unexpected_line_changes": [[[1, 
"h"], [1, "h"]]], - } - ) - self.comparison.delete() - variables = { - "org": self.org.username, - "repo": self.repo.name, - "pull": self.pull.pullid, - "path": "fileB", - } - data = self.gql_request(query_impacted_file_through_pull, variables=variables) - assert data == { - "owner": { - "repository": { - "pull": { - "compareWithBase": { - "state": "pending", - "impactedFile": { - "headName": "fileB", - "baseName": "fileB", - "hashedPath": "eea3f37743bfd3409bec556ab26d4698", - "baseCoverage": {"percentCovered": None}, - "headCoverage": {"percentCovered": None}, - "patchCoverage": {"percentCovered": 71.42857142857143}, - "segments": {"results": []}, - }, - } - } - } - } - } - - @patch("services.comparison.Comparison.validate") - @patch("services.comparison.PullRequestComparison.get_file_comparison") - @patch("services.archive.ArchiveService.read_file") - def test_fetch_impacted_file_with_segments( - self, read_file, mock_get_file_comparison, mock_compare_validate - ): - read_file.return_value = mock_data_from_archive - - mock_get_file_comparison.return_value = MockFileComparison() - mock_compare_validate.return_value = True - variables = { - "org": self.org.username, - "repo": self.repo.name, - "pull": self.pull.pullid, - "path": "fileB", - } - data = self.gql_request(query_impacted_file_through_pull, variables=variables) - assert data == { - "owner": { - "repository": { - "pull": { - "compareWithBase": { - "state": "processed", - "impactedFile": { - "headName": "fileB", - "baseName": "fileB", - "hashedPath": hashlib.md5("fileB".encode()).hexdigest(), - "baseCoverage": {"percentCovered": 41.666666666666664}, - "headCoverage": {"percentCovered": 85.71428571428571}, - "patchCoverage": {"percentCovered": 85.71428571428571}, - "segments": { - "results": [ - {"hasUnintendedChanges": True}, - {"hasUnintendedChanges": False}, - {"hasUnintendedChanges": True}, - ], - }, - }, - } - } - } - } - } - - @patch("services.comparison.Comparison.validate") - 
@patch("services.comparison.PullRequestComparison.get_file_comparison") - @patch("services.archive.ArchiveService.read_file") - def test_fetch_impacted_file_segments_with_indirect_and_direct_changes( - self, read_file, mock_get_file_comparison, mock_compare_validate - ): - read_file.return_value = mock_data_from_archive - - mock_get_file_comparison.return_value = MockFileComparison() - mock_compare_validate.return_value = True - variables = { - "org": self.org.username, - "repo": self.repo.name, - "pull": self.pull.pullid, - "path": "fileA", - "filters": {"hasUnintendedChanges": True}, - } - data = self.gql_request(query_impacted_file_through_pull, variables=variables) - assert data == { - "owner": { - "repository": { - "pull": { - "compareWithBase": { - "state": "processed", - "impactedFile": { - "headName": "fileA", - "baseName": "fileA", - "hashedPath": "5e9f0c9689fb7ec181ea0fb09ad3f74e", - "baseCoverage": {"percentCovered": 41.666666666666664}, - "headCoverage": {"percentCovered": 85.71428571428571}, - "patchCoverage": {"percentCovered": 50.0}, - "segments": { - "results": [ - {"hasUnintendedChanges": True}, - {"hasUnintendedChanges": True}, - ] - }, - }, - } - } - } - } - } - - @patch("services.comparison.Comparison.validate") - @patch("services.comparison.PullRequestComparison.get_file_comparison") - @patch("services.archive.ArchiveService.read_file") - def test_fetch_impacted_file_with_segments_unknown_path( - self, read_file, mock_get_file_comparison, mock_compare_validate - ): - read_file.return_value = mock_data_from_archive - mock_get_file_comparison.side_effect = TorngitObjectNotFoundError(None, None) - mock_compare_validate.return_value = True - - variables = { - "org": self.org.username, - "repo": self.repo.name, - "pull": self.pull.pullid, - "path": "fileA", - } - data = self.gql_request(query_impacted_file_through_pull, variables=variables) - assert data == { - "owner": { - "repository": { - "pull": { - "compareWithBase": { - "state": "processed", - 
"impactedFile": { - "headName": "fileA", - "baseName": "fileA", - "hashedPath": hashlib.md5("fileA".encode()).hexdigest(), - "baseCoverage": {"percentCovered": 41.666666666666664}, - "headCoverage": {"percentCovered": 85.71428571428571}, - "patchCoverage": {"percentCovered": 50.0}, - "segments": {"message": "path does not exist: fileA"}, - }, - } - } - } - } - } - - @patch("services.comparison.Comparison.validate") - @patch("services.comparison.PullRequestComparison.get_file_comparison") - @patch("services.archive.ArchiveService.read_file") - def test_fetch_impacted_file_with_segments_provider_error( - self, read_file, mock_get_file_comparison, mock_compare_validate - ): - read_file.return_value = mock_data_from_archive - mock_get_file_comparison.side_effect = TorngitClientGeneralError( - 500, None, None - ) - mock_compare_validate.return_value = True - - variables = { - "org": self.org.username, - "repo": self.repo.name, - "pull": self.pull.pullid, - "path": "fileA", - } - data = self.gql_request(query_impacted_file_through_pull, variables=variables) - assert data == { - "owner": { - "repository": { - "pull": { - "compareWithBase": { - "state": "processed", - "impactedFile": { - "headName": "fileA", - "baseName": "fileA", - "hashedPath": hashlib.md5("fileA".encode()).hexdigest(), - "baseCoverage": {"percentCovered": 41.666666666666664}, - "headCoverage": {"percentCovered": 85.71428571428571}, - "patchCoverage": {"percentCovered": 50.0}, - "segments": { - "message": "Error fetching data from the provider" - }, - }, - } - } - } - } - } - - @patch("services.comparison.Comparison.validate") - @patch("services.comparison.PullRequestComparison.get_file_comparison") - @patch("services.archive.ArchiveService.read_file") - def test_fetch_impacted_file_with_invalid_comparison( - self, read_file, mock_get_file_comparison, mock_compare_validate - ): - read_file.return_value = mock_data_from_archive - - mock_get_file_comparison.return_value = MockFileComparison() - 
mock_compare_validate.side_effect = MissingComparisonReport() - variables = { - "org": self.org.username, - "repo": self.repo.name, - "pull": self.pull.pullid, - "path": "fileA", - "filters": {"hasUnintendedChanges": False}, - } - data = self.gql_request(query_impacted_file_through_pull, variables=variables) - assert data == { - "owner": { - "repository": { - "pull": { - "compareWithBase": { - "state": "processed", - "impactedFile": { - "headName": "fileA", - "baseName": "fileA", - "hashedPath": "5e9f0c9689fb7ec181ea0fb09ad3f74e", - "baseCoverage": {"percentCovered": 41.666666666666664}, - "headCoverage": {"percentCovered": 85.71428571428571}, - "patchCoverage": {"percentCovered": 50.0}, - "segments": {"results": []}, - }, - } - } - } - } - } - - @patch("services.comparison.Comparison.validate") - @patch("services.comparison.PullRequestComparison.get_file_comparison") - @patch("services.archive.ArchiveService.read_file") - def test_fetch_impacted_file_segments_with_direct_and_indirect_changes( - self, read_file, mock_get_file_comparison, mock_compare_validate - ): - read_file.return_value = mock_data_from_archive - - mock_get_file_comparison.return_value = MockFileComparison() - mock_compare_validate.return_value = True - variables = { - "org": self.org.username, - "repo": self.repo.name, - "pull": self.pull.pullid, - "path": "fileA", - "filters": {"hasUnintendedChanges": False}, - } - data = self.gql_request(query_impacted_file_through_pull, variables=variables) - assert data == { - "owner": { - "repository": { - "pull": { - "compareWithBase": { - "state": "processed", - "impactedFile": { - "headName": "fileA", - "baseName": "fileA", - "hashedPath": "5e9f0c9689fb7ec181ea0fb09ad3f74e", - "baseCoverage": {"percentCovered": 41.666666666666664}, - "headCoverage": {"percentCovered": 85.71428571428571}, - "patchCoverage": {"percentCovered": 50.0}, - "segments": { - "results": [ - {"hasUnintendedChanges": False}, - {"hasUnintendedChanges": True}, - ] - }, - }, - } - } - 
} - } - } - - @patch("services.comparison.Comparison.validate") - @patch("services.comparison.PullRequestComparison.get_file_comparison") - @patch("services.archive.ArchiveService.read_file") - def test_fetch_impacted_file_without_segments_filter( - self, read_file, mock_get_file_comparison, mock_compare_validate - ): - read_file.return_value = mock_data_from_archive - - mock_get_file_comparison.return_value = MockFileComparison() - mock_compare_validate.return_value = True - variables = { - "org": self.org.username, - "repo": self.repo.name, - "pull": self.pull.pullid, - "path": "fileA", - } - data = self.gql_request(query_impacted_file_through_pull, variables=variables) - assert data == { - "owner": { - "repository": { - "pull": { - "compareWithBase": { - "state": "processed", - "impactedFile": { - "headName": "fileA", - "baseName": "fileA", - "hashedPath": "5e9f0c9689fb7ec181ea0fb09ad3f74e", - "baseCoverage": {"percentCovered": 41.666666666666664}, - "headCoverage": {"percentCovered": 85.71428571428571}, - "patchCoverage": {"percentCovered": 50.0}, - "segments": { - "results": [ - {"hasUnintendedChanges": True}, - {"hasUnintendedChanges": False}, - {"hasUnintendedChanges": True}, - ] - }, - }, - } - } - } - } - } - - @patch("services.archive.ArchiveService.read_file") - def test_fetch_direct_changed_files_count(self, read_file): - read_file.return_value = mock_data_from_archive - variables = { - "org": self.org.username, - "repo": self.repo.name, - "commit": self.commit.commitid, - } - data = self.gql_request( - query_direct_changed_files_count, - variables=variables, - ) - assert data == { - "owner": { - "repository": { - "commit": { - "compareWithParent": { - "directChangedFilesCount": 2, - } - } - } - } - } diff --git a/graphql_api/tests/test_invoice.py b/graphql_api/tests/test_invoice.py index ca8d7239fc..000370ba0b 100644 --- a/graphql_api/tests/test_invoice.py +++ b/graphql_api/tests/test_invoice.py @@ -69,6 +69,10 @@ def 
test_invoices_returns_100_recent_invoices(self, mock_list_filtered_invoices) phone } } + taxIds { + type + value + } } } } @@ -101,6 +105,7 @@ def test_invoices_returns_100_recent_invoices(self, mock_list_filtered_invoices) "subtotal": 999, "total": 999, "defaultPaymentMethod": None, + "taxIds": [], } @patch("services.billing.stripe.Invoice.retrieve") @@ -156,6 +161,10 @@ def test_invoice_returns_invoice_by_id(self, mock_retrieve_invoice): phone } } + taxIds { + type + value + } } } } @@ -188,6 +197,7 @@ def test_invoice_returns_invoice_by_id(self, mock_retrieve_invoice): "subtotal": 999, "total": 999, "defaultPaymentMethod": None, + "taxIds": [], } @patch("services.billing.stripe.Invoice.retrieve") @@ -239,6 +249,10 @@ def test_invoice_returns_none_if_no_invoices(self, mock_retrieve_invoice): phone } } + taxIds { + type + value + } } } } diff --git a/graphql_api/tests/test_okta_config.py b/graphql_api/tests/test_okta_config.py new file mode 100644 index 0000000000..227096199e --- /dev/null +++ b/graphql_api/tests/test_okta_config.py @@ -0,0 +1,180 @@ +from django.test import TransactionTestCase + +from codecov_auth.tests.factories import ( + AccountFactory, + OktaSettingsFactory, + OwnerFactory, +) + +from .helper import GraphQLTestHelper + + +class OktaConfigTestCase(GraphQLTestHelper, TransactionTestCase): + def setUp(self): + self.account = AccountFactory(name="Test Account") + self.owner = OwnerFactory( + username="randomOwner", service="github", account=self.account + ) + self.okta_settings = OktaSettingsFactory( + account=self.account, + client_id="test-client-id", + client_secret="test-client-secret", + ) + + def test_fetch_enabled_okta_config(self) -> None: + query = """ + query { + owner(username: "%s"){ + account { + oktaConfig { + enabled + } + } + } + } + """ % (self.owner.username) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert result["owner"]["account"]["oktaConfig"]["enabled"] == True + + def 
test_fetch_disabled_okta_config(self) -> None: + self.okta_settings.enabled = False + self.okta_settings.save() + query = """ + query { + owner(username: "%s"){ + account { + oktaConfig { + enabled + } + } + } + } + """ % (self.owner.username) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert result["owner"]["account"]["oktaConfig"]["enabled"] == False + + def test_fetch_enforced_okta_config(self) -> None: + query = """ + query { + owner(username: "%s"){ + account { + oktaConfig { + enforced + } + } + } + } + """ % (self.owner.username) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert result["owner"]["account"]["oktaConfig"]["enforced"] == False + + def test_fetch_enforced_okta_config_true(self) -> None: + self.okta_settings.enforced = True + self.okta_settings.save() + query = """ + query { + owner(username: "%s"){ + account { + oktaConfig { + enforced + } + } + } + } + """ % (self.owner.username) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert result["owner"]["account"]["oktaConfig"]["enforced"] == True + + def test_fetch_url_okta_config(self) -> None: + query = """ + query { + owner(username: "%s"){ + account{ + oktaConfig { + url + } + } + } + } + """ % (self.owner.username) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert result["owner"]["account"]["oktaConfig"]["url"] == self.okta_settings.url + + def test_fetch_okta_config_client_id(self) -> None: + query = """ + query { + owner(username: "%s"){ + account{ + oktaConfig { + clientId + } + } + } + } + """ % (self.owner.username) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert ( + result["owner"]["account"]["oktaConfig"]["clientId"] + == self.okta_settings.client_id + ) + + def test_fetch_okta_config_client_secret(self) -> None: + query = """ + query { + owner(username: 
"%s"){ + account{ + oktaConfig { + clientSecret + } + } + } + } + """ % (self.owner.username) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert ( + result["owner"]["account"]["oktaConfig"]["clientSecret"] + == self.okta_settings.client_secret + ) + + def test_fetch_non_existent_okta_config(self) -> None: + self.okta_settings.delete() + + query = """ + query { + owner(username: "%s"){ + account{ + oktaConfig { + clientId + clientSecret + url + } + } + } + } + """ % (self.owner.username) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert result["owner"]["account"]["oktaConfig"] is None diff --git a/graphql_api/tests/test_onboarding.py b/graphql_api/tests/test_onboarding.py index c4907ebef8..965e32d22a 100644 --- a/graphql_api/tests/test_onboarding.py +++ b/graphql_api/tests/test_onboarding.py @@ -1,7 +1,7 @@ from django.test import TransactionTestCase from codecov_auth.tests.factories import OwnerFactory -from graphql_api.tests.helper import GraphQLTestHelper, paginate_connection +from graphql_api.tests.helper import GraphQLTestHelper class OnboardingTest(GraphQLTestHelper, TransactionTestCase): diff --git a/graphql_api/tests/test_owner.py b/graphql_api/tests/test_owner.py index 1fe4dd8c27..916efced30 100644 --- a/graphql_api/tests/test_owner.py +++ b/graphql_api/tests/test_owner.py @@ -7,12 +7,14 @@ from freezegun import freeze_time from graphql import GraphQLError from prometheus_client import REGISTRY +from shared.django_apps.codecov_auth.tests.factories import OktaSettingsFactory from shared.django_apps.reports.models import ReportType from shared.upload.utils import UploaderType, insert_coverage_measurement from codecov.commands.exceptions import MissingService, UnauthorizedGuestAccess from codecov_auth.models import OwnerProfile from codecov_auth.tests.factories import ( + AccountFactory, GetAdminProviderAdapter, OwnerFactory, UserFactory, @@ -25,6 +27,7 @@ 
query_repositories = """{ owner(username: "%s") { + delinquent orgUploadToken ownerid isCurrentUserPartOfOrg @@ -48,24 +51,21 @@ class TestOwnerType(GraphQLTestHelper, TransactionTestCase): def setUp(self): - self.owner = OwnerFactory(username="codecov-user", service="github") + self.account = AccountFactory() + self.owner = OwnerFactory( + username="codecov-user", service="github", account=self.account + ) + self.okta_settings = OktaSettingsFactory(account=self.account, enforced=True) random_user = OwnerFactory(username="random-user", service="github") RepositoryFactory( author=self.owner, active=True, activated=True, private=True, name="a" ) RepositoryFactory( - author=self.owner, active=False, private=False, activated=False, name="b" + author=self.owner, active=False, activated=False, private=False, name="b" ) RepositoryFactory( author=random_user, active=True, activated=False, private=True, name="not" ) - RepositoryFactory( - author=random_user, - active=True, - private=False, - activated=True, - name="still-not", - ) def test_fetching_repositories(self): before = REGISTRY.get_sample_value( @@ -81,9 +81,12 @@ def test_fetching_repositories(self): labels={"operation_type": "unknown_type", "operation_name": "owner"}, ) query = query_repositories % (self.owner.username, "", "") - data = self.gql_request(query, owner=self.owner) + data = self.gql_request( + query, owner=self.owner, okta_signed_in_accounts=[self.account.id] + ) assert data == { "owner": { + "delinquent": None, "orgUploadToken": None, "ownerid": self.owner.ownerid, "isCurrentUserPartOfOrg": True, @@ -114,7 +117,9 @@ def test_fetching_repositories(self): def test_fetching_repositories_with_pagination(self): query = query_repositories % (self.owner.username, "(first: 1)", "endCursor") # Check on the first page if we have the repository b - data_page_one = self.gql_request(query, owner=self.owner) + data_page_one = self.gql_request( + query, owner=self.owner, okta_signed_in_accounts=[self.account.id] + 
) connection = data_page_one["owner"]["repositories"] assert connection["edges"][0]["node"] == {"name": "a"} pageInfo = connection["pageInfo"] @@ -126,7 +131,9 @@ def test_fetching_repositories_with_pagination(self): f'(first: 1, after: "{next_cursor}")', "endCursor", ) - data_page_two = self.gql_request(query, owner=self.owner) + data_page_two = self.gql_request( + query, owner=self.owner, okta_signed_in_accounts=[self.account.id] + ) connection = data_page_two["owner"]["repositories"] assert connection["edges"][0]["node"] == {"name": "b"} pageInfo = connection["pageInfo"] @@ -138,7 +145,9 @@ def test_fetching_active_repositories(self): "(filters: { active: true })", "", ) - data = self.gql_request(query, owner=self.owner) + data = self.gql_request( + query, owner=self.owner, okta_signed_in_accounts=[self.account.id] + ) repos = paginate_connection(data["owner"]["repositories"]) assert repos == [{"name": "a"}] @@ -148,7 +157,9 @@ def test_fetching_repositories_by_name(self): '(filters: { term: "a" })', "", ) - data = self.gql_request(query, owner=self.owner) + data = self.gql_request( + query, owner=self.owner, okta_signed_in_accounts=[self.account.id] + ) repos = paginate_connection(data["owner"]["repositories"]) assert repos == [{"name": "a"}] @@ -164,7 +175,9 @@ def test_fetching_repositories_with_ordering(self): "(ordering: NAME, orderingDirection: DESC)", "", ) - data = self.gql_request(query, owner=self.owner) + data = self.gql_request( + query, owner=self.owner, okta_signed_in_accounts=[self.account.id] + ) repos = paginate_connection(data["owner"]["repositories"]) assert repos == [{"name": "b"}, {"name": "a"}] @@ -178,13 +191,27 @@ def test_fetching_repositories_inactive_repositories(self): repos = paginate_connection(data["owner"]["repositories"]) assert repos == [{"name": "b"}] + def test_fetch_account(self) -> None: + query = """{ + owner(username: "%s") { + account { + name + } + } + } + """ % (self.owner.username) + data = self.gql_request(query, 
owner=self.owner) + assert data["owner"]["account"]["name"] == self.account.name + def test_fetching_repositories_active_repositories(self): query = query_repositories % ( self.owner.username, "(filters: { active: true })", "", ) - data = self.gql_request(query, owner=self.owner) + data = self.gql_request( + query, owner=self.owner, okta_signed_in_accounts=[self.account.id] + ) repos = paginate_connection(data["owner"]["repositories"]) assert repos == [{"name": "a"}] @@ -194,7 +221,9 @@ def test_fetching_repositories_activated_repositories(self): "(filters: { activated: true })", "", ) - data = self.gql_request(query, owner=self.owner) + data = self.gql_request( + query, owner=self.owner, okta_signed_in_accounts=[self.account.id] + ) repos = paginate_connection(data["owner"]["repositories"]) assert repos == [{"name": "a"}] @@ -208,6 +237,22 @@ def test_fetching_repositories_deactivated_repositories(self): repos = paginate_connection(data["owner"]["repositories"]) assert repos == [{"name": "b"}] + def test_fetching_repositories_filter_out_okta_enforced(self): + query = query_repositories % ( + self.owner.username, + '(filters: { term: "a" })', + "", + ) + data = self.gql_request(query, owner=self.owner) + repos = paginate_connection(data["owner"]["repositories"]) + assert repos == [] + + def test_fetching_repositories_impersonation_show_okta_enforced(self): + query = query_repositories % (self.owner.username, "", "") + data = self.gql_request(query, owner=self.owner, impersonate_owner=True) + repos = paginate_connection(data["owner"]["repositories"]) + assert repos == [{"name": "a"}, {"name": "b"}] + def test_is_part_of_org_when_unauthenticated(self): query = query_repositories % (self.owner.username, "", "") data = self.gql_request(query) @@ -347,6 +392,11 @@ def test_ownerid(self): data = self.gql_request(query, owner=self.owner) assert data["owner"]["ownerid"] == self.owner.ownerid + def test_delinquent(self): + query = query_repositories % (self.owner.username, 
"", "") + data = self.gql_request(query, owner=self.owner) + assert data["owner"]["delinquent"] == self.owner.delinquent + @patch("codecov_auth.commands.owner.owner.OwnerCommands.get_org_upload_token") def test_get_org_upload_token(self, mocker): mocker.return_value = "upload_token" @@ -366,7 +416,7 @@ def test_when_owner_profile_doesnt_exist(self): } """ % (owner.username) data = self.gql_request(query, owner=owner) - assert data["owner"]["defaultOrgUsername"] == None + assert data["owner"]["defaultOrgUsername"] is None def test_get_default_org_username_for_owner(self): organization = OwnerFactory(username="sample-org", service="github") @@ -398,7 +448,7 @@ def test_owner_without_default_org_returns_null(self): } """ % (owner.username) data = self.gql_request(query, owner=owner) - assert data["owner"]["defaultOrgUsername"] == None + assert data["owner"]["defaultOrgUsername"] is None def test_owner_without_owner_profile_returns_no_default_org(self): owner = OwnerFactory(username="sample-owner", service="github") @@ -410,7 +460,7 @@ def test_owner_without_owner_profile_returns_no_default_org(self): } """ % (owner.username) data = self.gql_request(query, owner=owner) - assert data["owner"]["defaultOrgUsername"] == None + assert data["owner"]["defaultOrgUsername"] is None def test_is_current_user_not_activated(self): owner = OwnerFactory(username="sample-owner", service="github") @@ -689,3 +739,97 @@ def test_fetch_owner_on_unauthenticated_enteprise_guest_access(self): except GraphQLError as e: assert e.message == UnauthorizedGuestAccess.message assert e.extensions["code"] == UnauthorizedGuestAccess.code + + def test_fetch_current_user_is_okta_authenticated(self): + account = AccountFactory() + owner = OwnerFactory(username="sample-owner", service="github", account=account) + owner.save() + + user = OwnerFactory(username="sample-user") + user.organizations = [owner.ownerid] + user.save() + + query = """{ + owner(username: "%s") { + isUserOktaAuthenticated + } + } + 
""" % (owner.username) + + data = self.gql_request(query, owner=user, okta_signed_in_accounts=[account.pk]) + assert data["owner"]["isUserOktaAuthenticated"] == True + + def test_fetch_current_user_is_not_okta_authenticated(self): + account = AccountFactory() + owner = OwnerFactory(username="sample-owner", service="github", account=account) + owner.save() + + user = OwnerFactory(username="sample-user") + user.organizations = [owner.ownerid] + user.save() + + query = """{ + owner(username: "%s") { + isUserOktaAuthenticated + } + } + """ % (owner.username) + + data = self.gql_request(query, owner=user, okta_signed_in_accounts=[]) + assert data["owner"]["isUserOktaAuthenticated"] == False + + def test_fetch_current_user_is_not_okta_authenticated_no_account(self): + owner = OwnerFactory(username="sample-owner", service="github") + owner.save() + + user = OwnerFactory(username="sample-user") + user.organizations = [owner.ownerid] + user.save() + + query = """{ + owner(username: "%s") { + isUserOktaAuthenticated + } + } + """ % (owner.username) + + data = self.gql_request(query, owner=user, okta_signed_in_accounts=[]) + assert data["owner"]["isUserOktaAuthenticated"] == False + + @patch("shared.rate_limits.determine_entity_redis_key") + @patch("shared.rate_limits.determine_if_entity_is_rate_limited") + @override_settings(IS_ENTERPRISE=True, GUEST_ACCESS=False) + def test_fetch_is_github_rate_limited( + self, mock_determine_rate_limit, mock_determine_redis_key + ): + current_org = OwnerFactory( + username="random-plan-user", + service="github", + ) + query = """{ + owner(username: "%s") { + isGithubRateLimited + } + } + + """ % (current_org.username) + mock_determine_redis_key.return_value = "test" + mock_determine_rate_limit.return_value = True + + data = self.gql_request(query, owner=current_org) + assert data["owner"]["isGithubRateLimited"] == True + + def test_fetch_is_github_rate_limited_not_on_gh_service(self): + current_org = OwnerFactory( + 
username="random-plan-user", + service="bitbucket", + ) + query = """{ + owner(username: "%s") { + isGithubRateLimited + } + } + + """ % (current_org.username) + data = self.gql_request(query, owner=current_org, provider="bb") + assert data["owner"]["isGithubRateLimited"] == False diff --git a/graphql_api/tests/test_path_content.py b/graphql_api/tests/test_path_content.py index 7b258c17ff..8023eb32b5 100644 --- a/graphql_api/tests/test_path_content.py +++ b/graphql_api/tests/test_path_content.py @@ -53,7 +53,7 @@ def test_returns_path_content_dir(self): def test_returns_none(self): type = resolve_path_content_type("string") - assert type == None + assert type is None class TestIsCriticalFile(TransactionTestCase): diff --git a/graphql_api/tests/test_plan.py b/graphql_api/tests/test_plan.py index 3a586afa29..e808c4a7cd 100644 --- a/graphql_api/tests/test_plan.py +++ b/graphql_api/tests/test_plan.py @@ -1,8 +1,13 @@ -from datetime import timedelta +from datetime import datetime, timedelta +from unittest.mock import patch +import pytest from django.test import TransactionTestCase from django.utils import timezone from freezegun import freeze_time +from shared.django_apps.codecov_auth.tests.factories import AccountFactory +from shared.license import LicenseInformation +from shared.utils.test_utils import mock_config_helper from codecov_auth.tests.factories import OwnerFactory from plan.constants import PlanName, TrialStatus @@ -11,6 +16,10 @@ class TestPlanType(GraphQLTestHelper, TransactionTestCase): + @pytest.fixture(scope="function", autouse=True) + def inject_mocker(request, mocker): + request.mocker = mocker + def setUp(self): self.current_org = OwnerFactory( username="random-plan-user", @@ -39,6 +48,7 @@ def test_owner_plan_data_when_trialing(self): trialStatus trialEndDate trialStartDate + trialTotalDays marketingName planName value @@ -58,6 +68,7 @@ def test_owner_plan_data_when_trialing(self): "trialStatus": "ONGOING", "trialEndDate": "2023-07-03T00:00:00", 
"trialStartDate": "2023-06-19T00:00:00", + "trialTotalDays": None, "marketingName": "Developer", "planName": "users-trial", "value": "users-trial", @@ -75,6 +86,37 @@ def test_owner_plan_data_when_trialing(self): "planUserCount": 123, } + def test_owner_plan_data_with_account(self): + self.current_org.account = AccountFactory( + plan=PlanName.CODECOV_PRO_YEARLY.value, + plan_seat_count=25, + ) + self.current_org.save() + query = """{ + owner(username: "%s") { + plan { + marketingName + planName + value + tierName + billingRate + baseUnitPrice + planUserCount + } + } + } + """ % (self.current_org.username) + data = self.gql_request(query, owner=self.current_org) + assert data["owner"]["plan"] == { + "marketingName": "Pro", + "planName": "users-pr-inappy", + "value": "users-pr-inappy", + "tierName": "pro", + "billingRate": "annually", + "baseUnitPrice": 10, + "planUserCount": 25, + } + def test_owner_plan_data_has_seats_left(self): current_org = OwnerFactory( username="random-plan-user", @@ -94,3 +136,90 @@ def test_owner_plan_data_has_seats_left(self): """ % (current_org.username) data = self.gql_request(query, owner=current_org) assert data["owner"]["plan"] == {"hasSeatsLeft": True} + + @patch("services.self_hosted.get_current_license") + def test_plan_user_count_for_enterprise_org(self, mocked_license): + """ + If an Org has an enterprise license, number_allowed_users from their license + should be used instead of plan_user_count on the Org object. 
+ """ + mock_enterprise_license = LicenseInformation( + is_valid=True, + message=None, + url="https://codeov.mysite.com", + number_allowed_users=5, + number_allowed_repos=10, + expires=datetime.strptime("2020-05-09 00:00:00", "%Y-%m-%d %H:%M:%S"), + is_trial=False, + is_pr_billing=True, + ) + mocked_license.return_value = mock_enterprise_license + mock_config_helper( + self.mocker, configs={"setup.enterprise_license": mock_enterprise_license} + ) + + enterprise_org = OwnerFactory( + username="random-plan-user", + service="github", + plan=PlanName.CODECOV_PRO_YEARLY.value, + plan_user_count=1, + plan_activated_users=[], + ) + for i in range(4): + new_owner = OwnerFactory() + enterprise_org.plan_activated_users.append(new_owner.ownerid) + enterprise_org.save() + + other_org_in_enterprise = OwnerFactory( + service="github", + plan=PlanName.CODECOV_PRO_YEARLY.value, + plan_user_count=1, + plan_activated_users=[], + ) + for i in range(4): + new_owner = OwnerFactory() + other_org_in_enterprise.plan_activated_users.append(new_owner.ownerid) + other_org_in_enterprise.save() + + query = """{ + owner(username: "%s") { + plan { + planUserCount + hasSeatsLeft + } + } + } + """ % (enterprise_org.username) + data = self.gql_request(query, owner=enterprise_org) + assert data["owner"]["plan"]["planUserCount"] == 5 + assert data["owner"]["plan"]["hasSeatsLeft"] == False + + @patch("services.self_hosted.get_current_license") + def test_plan_user_count_for_enterprise_org_invaild_license(self, mocked_license): + mock_enterprise_license = LicenseInformation( + is_valid=False, + ) + mocked_license.return_value = mock_enterprise_license + mock_config_helper( + self.mocker, configs={"setup.enterprise_license": mock_enterprise_license} + ) + + enterprise_org = OwnerFactory( + username="random-plan-user", + service="github", + plan=PlanName.CODECOV_PRO_YEARLY.value, + plan_user_count=1, + plan_activated_users=[], + ) + query = """{ + owner(username: "%s") { + plan { + planUserCount + 
hasSeatsLeft + } + } + } + """ % (enterprise_org.username) + data = self.gql_request(query, owner=enterprise_org) + assert data["owner"]["plan"]["planUserCount"] == 0 + assert data["owner"]["plan"]["hasSeatsLeft"] == False diff --git a/graphql_api/tests/test_pull.py b/graphql_api/tests/test_pull.py index 88fb65e1a7..c3ebf7f968 100644 --- a/graphql_api/tests/test_pull.py +++ b/graphql_api/tests/test_pull.py @@ -94,7 +94,11 @@ bundleAnalysisCompareWithBase { __typename ... on BundleAnalysisComparison { - sizeDelta + bundleData { + size { + uncompress + } + } } } behindBy @@ -105,7 +109,11 @@ bundleAnalysisCompareWithBase { __typename ... on BundleAnalysisComparison { - sizeDelta + bundleData { + size { + uncompress + } + } } } """ @@ -152,8 +160,9 @@ def test_fetch_list_pull_request(self): assert pull_1.title in pull_titles assert pull_2.title in pull_titles - @freeze_time("2021-02-02") - def test_when_repository_has_null_compared_to(self): + @freeze_time("2021-02-02 00:00:00") + @patch("core.commands.pull.interactors.fetch_pull_request.TaskService") + def test_when_repository_has_null_compared_to(self, mock_task_service): my_pull = PullFactory( repository=self.repository, title="test-null-base", @@ -165,9 +174,10 @@ def test_when_repository_has_null_compared_to(self): ).commitid, compared_to=None, ) - pull = self.fetch_one_pull_request( - my_pull.pullid, pull_request_detail_query_with_bundle_analysis - ) + with freeze_time("2021-02-02 06:00:00"): + pull = self.fetch_one_pull_request( + my_pull.pullid, pull_request_detail_query_with_bundle_analysis + ) assert pull == { "title": "test-null-base", "state": "OPEN", @@ -185,9 +195,13 @@ def test_when_repository_has_null_compared_to(self): "behindBy": None, "behindByCommit": None, } + mock_task_service.return_value.pulls_sync.assert_called_with( + my_pull.repository.repoid, my_pull.pullid + ) - @freeze_time("2021-02-02") - def test_when_repository_has_null_author(self): + @freeze_time("2021-02-02 00:00:00") + 
@patch("core.commands.pull.interactors.fetch_pull_request.TaskService") + def test_when_repository_has_null_author(self, mock_task_service): PullFactory( repository=self.repository, title="dummy-first-pr", @@ -219,9 +233,11 @@ def test_when_repository_has_null_author(self): "behindBy": None, "behindByCommit": None, } + mock_task_service.return_value.pulls_sync.assert_not_called() @freeze_time("2021-02-02") - def test_when_repository_has_null_head(self): + @patch("core.commands.pull.interactors.fetch_pull_request.TaskService") + def test_when_repository_has_null_head_no_parent_report(self, mock_task_service): PullFactory( repository=self.repository, title="dummy-first-pr", @@ -248,11 +264,83 @@ def test_when_repository_has_null_head(self): "__typename": "MissingHeadCommit", }, "bundleAnalysisCompareWithBase": { - "__typename": "MissingHeadCommit", + "__typename": "MissingHeadReport", }, "behindBy": None, "behindByCommit": None, } + mock_task_service.return_value.pulls_sync.assert_not_called() + + @patch("graphql_api.dataloader.bundle_analysis.get_appropriate_storage_service") + def test_when_repository_has_null_head_has_parent_report(self, get_storage_service): + os.system("rm -rf /tmp/bundle_analysis_*") + storage = MemoryStorageService({}) + get_storage_service.return_value = storage + + parent_commit = CommitFactory(repository=self.repository) + + base_commit_report = CommitReportFactory( + commit=parent_commit, + report_type=CommitReport.ReportType.BUNDLE_ANALYSIS, + ) + + my_pull = PullFactory( + repository=self.repository, + title="test-pull-request", + author=self.owner, + head=None, + compared_to=base_commit_report.commit.commitid, + behind_by=23, + behind_by_commit="1089nf898as-jdf09hahs09fgh", + ) + + with open("./services/tests/samples/base_bundle_report.sqlite", "rb") as f: + storage_path = StoragePaths.bundle_report.path( + repo_key=ArchiveService.get_archive_hash(self.repository), + report_key=base_commit_report.external_id, + ) + 
storage.write_file(get_bucket_name(), storage_path, f) + + query = """ + bundleAnalysisCompareWithBase { + __typename + ... on BundleAnalysisComparison { + bundleData { + size { + uncompress + } + } + bundleChange { + size { + uncompress + } + } + } + } + """ + + pull = self.fetch_one_pull_request(my_pull.pullid, query) + + assert pull == { + "bundleAnalysisCompareWithBase": { + "__typename": "BundleAnalysisComparison", + "bundleData": { + "size": { + "uncompress": 165165, + } + }, + "bundleChange": { + "size": { + "uncompress": 0, + } + }, + } + } + + for file in os.listdir("/tmp"): + assert not file.startswith("bundle_analysis_") + + os.system("rm -rf /tmp/bundle_analysis_*") @freeze_time("2021-02-02") def test_when_pr_is_first_pr_in_repo(self): @@ -453,7 +541,11 @@ def test_bundle_analysis_sqlite_file_deleted(self, get_storage_service): bundleAnalysisCompareWithBase { __typename ... on BundleAnalysisComparison { - sizeTotal + bundleData { + size { + uncompress + } + } } } """ @@ -463,7 +555,11 @@ def test_bundle_analysis_sqlite_file_deleted(self, get_storage_service): assert pull == { "bundleAnalysisCompareWithBase": { "__typename": "BundleAnalysisComparison", - "sizeTotal": 201720, + "bundleData": { + "size": { + "uncompress": 201720, + } + }, } } diff --git a/graphql_api/tests/test_pull_comparison.py b/graphql_api/tests/test_pull_comparison.py index d0cf26fbd1..87a3645970 100644 --- a/graphql_api/tests/test_pull_comparison.py +++ b/graphql_api/tests/test_pull_comparison.py @@ -5,7 +5,6 @@ from shared.reports.types import ReportTotals from shared.utils.merge import LineType -import services.comparison as comparison from codecov_auth.tests.factories import OwnerFactory from compare.models import CommitComparison from compare.tests.factories import CommitComparisonFactory, FlagComparisonFactory @@ -409,35 +408,42 @@ def test_pull_comparison_impacted_files(self, files_mock, critical_files): pullId compareWithBase { ... 
on Comparison { - impactedFilesDeprecated { - baseName - headName - isNewFile - isRenamedFile - isDeletedFile - baseCoverage { - percentCovered - fileCount - lineCount - hitsCount - missesCount - partialsCount - } - headCoverage { - percentCovered - fileCount - lineCount - hitsCount - missesCount - partialsCount + impactedFiles(filters:{}) { + ... on ImpactedFiles { + results { + baseName + headName + isNewFile + isRenamedFile + isDeletedFile + baseCoverage { + percentCovered + fileCount + lineCount + hitsCount + missesCount + partialsCount + } + headCoverage { + percentCovered + fileCount + lineCount + hitsCount + missesCount + partialsCount + } + patchCoverage { + percentCovered + fileCount + lineCount + hitsCount + missesCount + partialsCount + } + } } - patchCoverage { - percentCovered - fileCount - lineCount - hitsCount - missesCount - partialsCount + ... on UnknownFlags { + message } } } @@ -473,28 +479,30 @@ def test_pull_comparison_impacted_files(self, files_mock, critical_files): assert res == { "pullId": self.pull.pullid, "compareWithBase": { - "impactedFilesDeprecated": [ - { - "baseName": "foo.py", - "headName": "bar.py", - "isNewFile": False, - "isRenamedFile": True, - "isDeletedFile": False, - "baseCoverage": base_totals, - "headCoverage": head_totals, - "patchCoverage": patch_totals, - }, - { - "baseName": None, - "headName": "baz.py", - "isNewFile": True, - "isRenamedFile": False, - "isDeletedFile": False, - "baseCoverage": base_totals, - "headCoverage": head_totals, - "patchCoverage": patch_totals, - }, - ] + "impactedFiles": { + "results": [ + { + "baseName": "foo.py", + "headName": "bar.py", + "isNewFile": False, + "isRenamedFile": True, + "isDeletedFile": False, + "baseCoverage": base_totals, + "headCoverage": head_totals, + "patchCoverage": patch_totals, + }, + { + "baseName": None, + "headName": "baz.py", + "isNewFile": True, + "isRenamedFile": False, + "isDeletedFile": False, + "baseCoverage": base_totals, + "headCoverage": head_totals, + 
"patchCoverage": patch_totals, + }, + ] + }, }, } @@ -526,10 +534,17 @@ def test_pull_comparison_is_critical_file(self, files_mock, critical_files): pullId compareWithBase { ... on Comparison { - impactedFilesDeprecated { - baseName - headName - isCriticalFile + impactedFiles(filters:{}) { + ... on ImpactedFiles { + results { + baseName + headName + isCriticalFile + } + } + ... on UnknownFlags { + message + } } } } @@ -539,18 +554,20 @@ def test_pull_comparison_is_critical_file(self, files_mock, critical_files): assert res == { "pullId": self.pull.pullid, "compareWithBase": { - "impactedFilesDeprecated": [ - { - "baseName": "foo.py", - "headName": "bar.py", - "isCriticalFile": True, - }, - { - "baseName": None, - "headName": "baz.py", - "isCriticalFile": False, - }, - ] + "impactedFiles": { + "results": [ + { + "baseName": "foo.py", + "headName": "bar.py", + "isCriticalFile": True, + }, + { + "baseName": None, + "headName": "baz.py", + "isCriticalFile": False, + }, + ] + }, }, } @@ -581,10 +598,17 @@ def test_pull_comparison_is_critical_file_returns_false_through_repositories( pullId compareWithBase { ... on Comparison { - impactedFilesDeprecated { - baseName - headName - isCriticalFile + impactedFiles(filters:{}) { + ... on ImpactedFiles { + results { + baseName + headName + isCriticalFile + } + } + ... on UnknownFlags { + message + } } } } @@ -609,13 +633,15 @@ def test_pull_comparison_is_critical_file_returns_false_through_repositories( "pull": { "pullId": 2, "compareWithBase": { - "impactedFilesDeprecated": [ - { - "baseName": "foo.py", - "headName": "bar.py", - "isCriticalFile": False, - } - ] + "impactedFiles": { + "results": [ + { + "baseName": "foo.py", + "headName": "bar.py", + "isCriticalFile": False, + } + ] + }, }, } } @@ -725,26 +751,33 @@ def test_pull_comparison_line_comparisons( pullId compareWithBase { ... on Comparison { - impactedFilesDeprecated { - segments { - ... 
on SegmentComparisons { - results { - header - hasUnintendedChanges - lines { - baseNumber - headNumber - baseCoverage - headCoverage - content - coverageInfo(ignoredUploadIds: [1]) { - hitCount - hitUploadIds + impactedFiles(filters: {}){ + ... on ImpactedFiles { + results { + segments { + ... on SegmentComparisons { + results { + header + hasUnintendedChanges + lines { + baseNumber + headNumber + baseCoverage + headCoverage + content + coverageInfo(ignoredUploadIds: [1]) { + hitCount + hitUploadIds + } + } } } } } } + ... on UnknownFlags { + message + } } } } @@ -754,65 +787,67 @@ def test_pull_comparison_line_comparisons( assert res == { "pullId": self.pull.pullid, "compareWithBase": { - "impactedFilesDeprecated": [ - { - "segments": { - "results": [ - { - "header": "-1,2 +3,4", - "hasUnintendedChanges": False, - "lines": [ - { - "baseNumber": "1", - "headNumber": "1", - "baseCoverage": "H", - "headCoverage": "H", - "content": " line1", - "coverageInfo": { - "hitCount": 1, - "hitUploadIds": [0], - }, - }, - { - "baseNumber": None, - "headNumber": "2", - "baseCoverage": None, - "headCoverage": "H", - "content": "+ line2", - "coverageInfo": { - "hitCount": 1, - "hitUploadIds": [0], + "impactedFiles": { + "results": [ + { + "segments": { + "results": [ + { + "header": "-1,2 +3,4", + "hasUnintendedChanges": False, + "lines": [ + { + "baseNumber": "1", + "headNumber": "1", + "baseCoverage": "H", + "headCoverage": "H", + "content": " line1", + "coverageInfo": { + "hitCount": 1, + "hitUploadIds": [0], + }, }, - }, - ], - } - ] - } - }, - { - "segments": { - "results": [ - { - "header": "-1 +1", - "hasUnintendedChanges": True, - "lines": [ - { - "baseNumber": "1", - "headNumber": "1", - "baseCoverage": "M", - "headCoverage": "H", - "content": " line1", - "coverageInfo": { - "hitCount": 1, - "hitUploadIds": [0], + { + "baseNumber": None, + "headNumber": "2", + "baseCoverage": None, + "headCoverage": "H", + "content": "+ line2", + "coverageInfo": { + "hitCount": 1, + 
"hitUploadIds": [0], + }, }, - } - ], - } - ] - } - }, - ] + ], + } + ] + } + }, + { + "segments": { + "results": [ + { + "header": "-1 +1", + "hasUnintendedChanges": True, + "lines": [ + { + "baseNumber": "1", + "headNumber": "1", + "baseCoverage": "M", + "headCoverage": "H", + "content": " line1", + "coverageInfo": { + "hitCount": 1, + "hitUploadIds": [0], + }, + } + ], + } + ] + } + }, + ] + }, }, } @@ -871,22 +906,29 @@ def test_pull_comparison_coverage_changes( pullId compareWithBase { ... on Comparison { - impactedFilesDeprecated { - segments { - ... on SegmentComparisons { - results { - header - hasUnintendedChanges - lines { - baseNumber - headNumber - baseCoverage - headCoverage - content + impactedFiles(filters: {}){ + ... on ImpactedFiles { + results { + segments { + ... on SegmentComparisons { + results { + header + hasUnintendedChanges + lines { + baseNumber + headNumber + baseCoverage + headCoverage + content + } + } } } } } + ... on UnknownFlags { + message + } } } } @@ -896,27 +938,29 @@ def test_pull_comparison_coverage_changes( assert res == { "pullId": self.pull.pullid, "compareWithBase": { - "impactedFilesDeprecated": [ - { - "segments": { - "results": [ - { - "header": "-1,1 +1,1", - "hasUnintendedChanges": True, - "lines": [ - { - "baseNumber": "1", - "headNumber": "1", - "baseCoverage": "M", - "headCoverage": "H", - "content": " line1", - } - ], - } - ] + "impactedFiles": { + "results": [ + { + "segments": { + "results": [ + { + "header": "-1,1 +1,1", + "hasUnintendedChanges": True, + "lines": [ + { + "baseNumber": "1", + "headNumber": "1", + "baseCoverage": "M", + "headCoverage": "H", + "content": " line1", + } + ], + } + ] + } } - } - ] + ] + }, }, } @@ -941,9 +985,16 @@ def test_pull_comparison_pending(self): headTotals { percentCovered } - impactedFilesDeprecated { - baseName - headName + impactedFiles(filters: {}) { + ... on ImpactedFiles { + results { + baseName + headName + } + } + ... 
on UnknownFlags { + message + } } } } @@ -956,7 +1007,7 @@ def test_pull_comparison_pending(self): "state": "pending", "baseTotals": None, "headTotals": None, - "impactedFilesDeprecated": [], + "impactedFiles": {"results": []}, }, } @@ -975,7 +1026,7 @@ def test_pull_comparison_no_comparison(self, compute_comparisons_mock): res = self._request(query) # it regenerates the comparison as needed - assert res["compareWithBase"] != None + assert res["compareWithBase"] is not None compute_comparisons_mock.assert_called_once diff --git a/graphql_api/tests/test_repository.py b/graphql_api/tests/test_repository.py index 4206ef2fda..307445f1f7 100644 --- a/graphql_api/tests/test_repository.py +++ b/graphql_api/tests/test_repository.py @@ -5,13 +5,13 @@ from freezegun import freeze_time from codecov_auth.tests.factories import OwnerFactory -from core import models from core.tests.factories import ( CommitFactory, PullFactory, RepositoryFactory, RepositoryTokenFactory, ) +from reports.tests.factories import TestFactory, TestInstanceFactory from services.profiling import CriticalFile from .helper import GraphQLTestHelper @@ -302,7 +302,7 @@ def test_repository_resolve_yaml_no_yaml(self): owner=user, variables={"name": repo.name}, ) - assert data["me"]["owner"]["repository"]["yaml"] == None + assert data["me"]["owner"]["repository"]["yaml"] is None def test_repository_resolve_bot(self): user = OwnerFactory() @@ -403,6 +403,36 @@ def test_repository_repository_config_indication_range(self, mocked_useryaml): == 80 ) + @patch("shared.yaml.user_yaml.UserYaml.get_final_yaml") + def test_repository_repository_config_indication_range_float(self, mocked_useryaml): + mocked_useryaml.return_value = {"coverage": {"range": [61.1, 82.2]}} + + repo = RepositoryFactory( + author=self.owner, + active=True, + private=True, + ) + + data = self.gql_request( + query_repository + % "repositoryConfig { indicationRange { upperRange lowerRange } }", + owner=self.owner, + variables={"name": repo.name}, + 
) + + assert ( + data["me"]["owner"]["repository"]["repositoryConfig"]["indicationRange"][ + "lowerRange" + ] + == 61.1 + ) + assert ( + data["me"]["owner"]["repository"]["repositoryConfig"]["indicationRange"][ + "upperRange" + ] + == 82.2 + ) + @patch("services.activation.try_auto_activate") def test_repository_auto_activate(self, try_auto_activate): repo = RepositoryFactory( @@ -558,7 +588,7 @@ def test_repository_get_languages_null(self): author=self.owner, active=True, private=True, languages=None ) res = self.fetch_repository(repo.name) - assert res["languages"] == None + assert res["languages"] is None def test_repository_get_languages_empty(self): repo = RepositoryFactory(author=self.owner, active=True, private=True) @@ -754,3 +784,341 @@ def test_repository_when_is_first_pull_request_false(self) -> None: ) assert data["me"]["owner"]["repository"]["isFirstPullRequest"] == False + + @patch("shared.rate_limits.determine_entity_redis_key") + @patch("shared.rate_limits.determine_if_entity_is_rate_limited") + @override_settings(IS_ENTERPRISE=True, GUEST_ACCESS=False) + def test_fetch_is_github_rate_limited( + self, mock_determine_rate_limit, mock_determine_redis_key + ): + repo = RepositoryFactory( + author=self.owner, + active=True, + private=True, + yaml={"component_management": {}}, + ) + + mock_determine_redis_key.return_value = "test" + mock_determine_rate_limit.return_value = True + + data = self.gql_request( + query_repository + % """ + isGithubRateLimited + """, + owner=self.owner, + variables={"name": repo.name}, + ) + + assert data["me"]["owner"]["repository"]["isGithubRateLimited"] == True + + def test_fetch_is_github_rate_limited_not_on_gh_service(self): + owner = OwnerFactory(service="gitlab") + repo = RepositoryFactory( + author=owner, + author__service="gitlab", + service_id=12345, + active=True, + ) + + data = self.gql_request( + query_repository + % """ + isGithubRateLimited + """, + owner=owner, + variables={"name": repo.name}, + 
provider="gitlab", + ) + + assert data["me"]["owner"]["repository"]["isGithubRateLimited"] == False + + @patch("shared.rate_limits.determine_entity_redis_key") + @patch("shared.rate_limits.determine_if_entity_is_rate_limited") + @patch("logging.Logger.warning") + @override_settings(IS_ENTERPRISE=True, GUEST_ACCESS=False) + def test_fetch_is_github_rate_limited_but_errors( + self, + mock_log_warning, + mock_determine_rate_limit, + mock_determine_redis_key, + ): + repo = RepositoryFactory( + author=self.owner, + active=True, + private=True, + yaml={"component_management": {}}, + ) + + mock_determine_redis_key.side_effect = Exception("some random error lol") + mock_determine_rate_limit.return_value = True + + data = self.gql_request( + query_repository + % """ + isGithubRateLimited + """, + owner=self.owner, + variables={"name": repo.name}, + ) + + assert data["me"]["owner"]["repository"]["isGithubRateLimited"] is None + + mock_log_warning.assert_called_once_with( + "Error when checking rate limit", + extra={ + "repo_id": repo.repoid, + "has_owner": True, + }, + ) + + def test_test_results(self) -> None: + repo = RepositoryFactory(author=self.owner, active=True, private=True) + test = TestFactory(repository=repo) + _test_instance_1 = TestInstanceFactory( + test=test, created_at=datetime.datetime.now(), repoid=repo.repoid + ) + res = self.fetch_repository( + repo.name, """testResults { edges { node { name } } }""" + ) + assert res["testResults"] == {"edges": [{"node": {"name": test.name}}]} + + def test_test_results_no_tests(self) -> None: + repo = RepositoryFactory(author=self.owner, active=True, private=True) + res = self.fetch_repository( + repo.name, """testResults { edges { node { name } } }""" + ) + assert res["testResults"] == {"edges": []} + + def test_branch_filter_on_test_results(self) -> None: + repo = RepositoryFactory(author=self.owner, active=True, private=True) + test = TestFactory(repository=repo) + _test_instance_1 = TestInstanceFactory( + test=test, + 
created_at=datetime.datetime.now(), + repoid=repo.repoid, + branch="main", + ) + _test_instance_2 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + branch="feature", + ) + res = self.fetch_repository( + repo.name, + """testResults(filters: { branch: "main"}) { edges { node { name } } }""", + ) + assert res["testResults"] == {"edges": [{"node": {"name": test.name}}]} + + def test_commits_failed_ordering_on_test_results(self) -> None: + repo = RepositoryFactory(author=self.owner, active=True, private=True) + test = TestFactory(repository=repo) + _test_instance_1 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + commitid="1", + ) + _test_instance_2 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + commitid="2", + ) + test_2 = TestFactory(repository=repo) + _test_instance_3 = TestInstanceFactory( + test=test_2, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + commitid="3", + ) + res = self.fetch_repository( + repo.name, + """testResults(ordering: { parameter: COMMITS_WHERE_FAIL, direction: ASC }) { edges { node { name commitsFailed } } }""", + ) + assert res["testResults"] == { + "edges": [ + {"node": {"name": test_2.name, "commitsFailed": 1}}, + {"node": {"name": test.name, "commitsFailed": 2}}, + ] + } + + def test_desc_commits_failed_ordering_on_test_results(self) -> None: + repo = RepositoryFactory(author=self.owner, active=True, private=True) + test = TestFactory(repository=repo) + _test_instance_1 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + commitid="1", + ) + _test_instance_2 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + commitid="2", + ) + test_2 = TestFactory(repository=repo) + _test_instance_3 = TestInstanceFactory( + test=test_2, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + 
commitid="3", + ) + res = self.fetch_repository( + repo.name, + """testResults(ordering: { parameter: COMMITS_WHERE_FAIL, direction: DESC }) { edges { node { name commitsFailed } } }""", + ) + assert res["testResults"] == { + "edges": [ + {"node": {"name": test.name, "commitsFailed": 2}}, + {"node": {"name": test_2.name, "commitsFailed": 1}}, + ] + } + + def test_avg_duration_ordering_on_test_results(self) -> None: + repo = RepositoryFactory(author=self.owner, active=True, private=True) + test = TestFactory(repository=repo) + _test_instance_1 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + duration_seconds=1, + ) + _test_instance_2 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + duration_seconds=2, + ) + test_2 = TestFactory(repository=repo) + _test_instance_3 = TestInstanceFactory( + test=test_2, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + duration_seconds=3, + ) + res = self.fetch_repository( + repo.name, + """testResults(ordering: { parameter: AVG_DURATION, direction: ASC }) { edges { node { name avgDuration } } }""", + ) + assert res["testResults"] == { + "edges": [ + {"node": {"name": test.name, "avgDuration": 1.5}}, + {"node": {"name": test_2.name, "avgDuration": 3}}, + ] + } + + def test_desc_avg_duration_ordering_on_test_results(self) -> None: + repo = RepositoryFactory(author=self.owner, active=True, private=True) + test = TestFactory(repository=repo) + _test_instance_1 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + duration_seconds=1, + ) + _test_instance_2 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + duration_seconds=2, + ) + test_2 = TestFactory(repository=repo) + _test_instance_3 = TestInstanceFactory( + test=test_2, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + duration_seconds=3, + ) + res = self.fetch_repository( + 
repo.name, + """testResults(ordering: { parameter: AVG_DURATION, direction: DESC }) { edges { node { name avgDuration } } }""", + ) + assert res["testResults"] == { + "edges": [ + {"node": {"name": test_2.name, "avgDuration": 3}}, + {"node": {"name": test.name, "avgDuration": 1.5}}, + ] + } + + def test_failure_rate_ordering_on_test_results(self) -> None: + repo = RepositoryFactory(author=self.owner, active=True, private=True) + test = TestFactory(repository=repo) + _test_instance_1 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + outcome="pass", + ) + _test_instance_2 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + outcome="failure", + ) + test_2 = TestFactory(repository=repo) + _test_instance_3 = TestInstanceFactory( + test=test_2, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + outcome="failure", + ) + _test_instance_4 = TestInstanceFactory( + test=test_2, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + outcome="failure", + ) + res = self.fetch_repository( + repo.name, + """testResults(ordering: { parameter: FAILURE_RATE, direction: ASC }) { edges { node { name failureRate } } }""", + ) + + assert res["testResults"] == { + "edges": [ + {"node": {"name": test.name, "failureRate": 0.5}}, + {"node": {"name": test_2.name, "failureRate": 1.0}}, + ] + } + + def test_desc_failure_rate_ordering_on_test_results(self) -> None: + repo = RepositoryFactory(author=self.owner, active=True, private=True) + test = TestFactory(repository=repo) + _test_instance_1 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + outcome="pass", + ) + _test_instance_2 = TestInstanceFactory( + test=test, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + outcome="failure", + ) + test_2 = TestFactory(repository=repo) + _test_instance_3 = TestInstanceFactory( + test=test_2, + created_at=datetime.datetime.now(), + 
repoid=repo.repoid, + outcome="failure", + ) + _test_instance_4 = TestInstanceFactory( + test=test_2, + created_at=datetime.datetime.now(), + repoid=repo.repoid, + outcome="failure", + ) + res = self.fetch_repository( + repo.name, + """testResults(ordering: { parameter: FAILURE_RATE, direction: DESC }) { edges { node { name failureRate } } }""", + ) + + assert res["testResults"] == { + "edges": [ + {"node": {"name": test_2.name, "failureRate": 1.0}}, + {"node": {"name": test.name, "failureRate": 0.5}}, + ] + } diff --git a/graphql_api/tests/test_repository_encoded_secret_string.py b/graphql_api/tests/test_repository_encoded_secret_string.py deleted file mode 100644 index fd3180433f..0000000000 --- a/graphql_api/tests/test_repository_encoded_secret_string.py +++ /dev/null @@ -1,40 +0,0 @@ -from django.test import TransactionTestCase -from shared.encryption.yaml_secret import yaml_secret_encryptor - -from codecov_auth.tests.factories import OwnerFactory -from core.tests.factories import RepositoryFactory - -from .helper import GraphQLTestHelper - - -class TestEncodedString(TransactionTestCase, GraphQLTestHelper): - def _request(self, variables=None): - query = f""" - query EncodedSecretString($value: String!) {{ - owner(username: "{self.org.username}") {{ - repository(name: "{self.repo.name}") {{ - ... 
on Repository {{ - encodedSecretString(value: $value) {{ - value - }} - }} - }} - }} - }} - """ - data = self.gql_request(query, owner=self.owner, variables=variables) - return data["owner"]["repository"]["encodedSecretString"]["value"] - - def setUp(self): - self.org = OwnerFactory(username="test-org") - self.repo = RepositoryFactory( - name="test-repo", - author=self.org, - private=True, - ) - self.owner = OwnerFactory(permission=[self.repo.pk]) - - def test_encoded_secret_string(self): - res = self._request(variables={"value": "token-1"}) - check_encryptor = yaml_secret_encryptor - assert "token-1" in check_encryptor.decode(res[7:]) diff --git a/graphql_api/tests/test_test_result.py b/graphql_api/tests/test_test_result.py new file mode 100644 index 0000000000..5138d65287 --- /dev/null +++ b/graphql_api/tests/test_test_result.py @@ -0,0 +1,188 @@ +from datetime import UTC, datetime + +from django.test import TransactionTestCase +from freezegun import freeze_time + +from codecov_auth.tests.factories import OwnerFactory +from core.tests.factories import RepositoryFactory +from reports.models import TestInstance +from reports.tests.factories import TestFactory, TestInstanceFactory + +from .helper import GraphQLTestHelper + + +@freeze_time(datetime.now().isoformat()) +class TestResultTestCase(GraphQLTestHelper, TransactionTestCase): + def setUp(self): + self.owner = OwnerFactory(username="randomOwner") + self.repository = RepositoryFactory( + author=self.owner, + ) + self.test = TestFactory( + name="Test\x1fName", + repository=self.repository, + ) + _ = TestInstanceFactory( + test=self.test, + outcome=TestInstance.Outcome.FAILURE.value, + duration_seconds=1.1, + repoid=self.repository.repoid, + created_at=datetime.now(), + ) + _ = TestInstanceFactory( + test=self.test, + outcome=TestInstance.Outcome.FAILURE.value, + duration_seconds=1.3, + repoid=self.repository.repoid, + created_at=datetime.now(), + ) + _ = TestInstanceFactory( + test=self.test, + 
outcome=TestInstance.Outcome.PASS.value, + duration_seconds=1.5, + repoid=self.repository.repoid, + created_at=datetime.now(), + commitid="456123", + ) + + def test_fetch_test_result_name(self) -> None: + query = """ + query { + owner(username: "%s") { + repository(name: "%s") { + ... on Repository { + testResults { + edges { + node { + name + } + } + } + } + } + } + } + """ % (self.owner.username, self.repository.name) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert result["owner"]["repository"]["testResults"]["edges"][0]["node"][ + "name" + ] == self.test.name.replace("\x1f", " ") + + def test_fetch_test_result_updated_at(self) -> None: + query = """ + query { + owner(username: "%s") { + repository(name: "%s") { + ... on Repository { + testResults { + edges { + node { + updatedAt + } + } + } + } + } + } + } + """ % (self.owner.username, self.repository.name) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert ( + result["owner"]["repository"]["testResults"]["edges"][0]["node"][ + "updatedAt" + ] + == datetime.now(UTC).isoformat() + ) + + def test_fetch_test_result_commits_failed(self) -> None: + query = """ + query { + owner(username: "%s") { + repository(name: "%s") { + ... on Repository { + testResults { + edges { + node { + commitsFailed + } + } + } + } + } + } + } + """ % (self.owner.username, self.repository.name) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert ( + result["owner"]["repository"]["testResults"]["edges"][0]["node"][ + "commitsFailed" + ] + == 1 + ) + + def test_fetch_test_result_failure_rate(self) -> None: + query = """ + query { + owner(username: "%s") { + repository(name: "%s") { + ... 
on Repository { + testResults { + edges { + node { + failureRate + } + } + } + } + } + } + } + """ % (self.owner.username, self.repository.name) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert ( + result["owner"]["repository"]["testResults"]["edges"][0]["node"][ + "failureRate" + ] + == 2 / 3 + ) + + def test_fetch_test_result_avg_duration(self) -> None: + query = """ + query { + owner(username: "%s") { + repository(name: "%s") { + ... on Repository { + testResults { + edges { + node { + avgDuration + } + } + } + } + } + } + } + """ % (self.owner.username, self.repository.name) + + result = self.gql_request(query, owner=self.owner) + + assert "errors" not in result + assert ( + result["owner"]["repository"]["testResults"]["edges"][0]["node"][ + "avgDuration" + ] + == 1.3 + ) diff --git a/graphql_api/tests/test_user.py b/graphql_api/tests/test_user.py index 98170d8dd2..319b7ac9db 100644 --- a/graphql_api/tests/test_user.py +++ b/graphql_api/tests/test_user.py @@ -85,4 +85,4 @@ def test_query_user_resolver(self): def test_query_null_user_customer_intent_resolver(self): null_user = OwnerFactory(user=None, service_id=4) data = resolve_customer_intent(null_user, None) - assert data == None + assert data is None diff --git a/graphql_api/tests/test_views.py b/graphql_api/tests/test_views.py index ec158f0a99..17475fed75 100644 --- a/graphql_api/tests/test_views.py +++ b/graphql_api/tests/test_views.py @@ -1,5 +1,5 @@ import json -from unittest.mock import patch +from unittest.mock import Mock, call, patch from ariadne import ObjectType, make_executable_schema from ariadne.validation import cost_directive @@ -39,7 +39,7 @@ def generate_cost_test_schema(): return make_executable_schema([types, cost_directive], query_bindable) -class ArianeViewTestCase(GraphQLTestHelper, TestCase): +class AriadneViewTestCase(GraphQLTestHelper, TestCase): async def do_query(self, schema, query="{ failing }"): view = 
AsyncGraphqlView.as_view(schema=schema) request = RequestFactory().post( @@ -200,3 +200,40 @@ async def test_query_metrics_extension_set_type_and_name_timeout( ) assert extension.operation_type == "unknown_type" assert extension.operation_name == "unknown_name" + + @patch("sentry_sdk.metrics.incr") + @patch("graphql_api.views.AsyncGraphqlView._check_ratelimit") + async def test_when_rate_limit_reached( + self, mocked_check_ratelimit, mocked_sentry_incr + ): + schema = generate_cost_test_schema() + mocked_check_ratelimit.return_value = True + response = await self.do_query(schema, " { stuff }") + + assert response["status"] == 429 + assert ( + response["detail"] + == "It looks like you've hit the rate limit of 300 req/min. Try again later." + ) + + expected_calls = [ + call("graphql.info.request_made", tags={"path": "/graphql/gh"}), + call("graphql.error.rate_limit", tags={"path": "/graphql/gh"}), + ] + mocked_sentry_incr.assert_has_calls(expected_calls) + + def test_client_ip_from_x_forwarded_for(self): + view = AsyncGraphqlView() + request = Mock() + request.META = {"HTTP_X_FORWARDED_FOR": "127.0.0.1,blah", "REMOTE_ADDR": "lol"} + + result = view.get_client_ip(request) + assert result == "127.0.0.1" + + def test_client_ip_from_remote_addr(self): + view = AsyncGraphqlView() + request = Mock() + request.META = {"HTTP_X_FORWARDED_FOR": None, "REMOTE_ADDR": "lol"} + + result = view.get_client_ip(request) + assert result == "lol" diff --git a/graphql_api/types/__init__.py b/graphql_api/types/__init__.py index 0f2dacc933..6753d88454 100644 --- a/graphql_api/types/__init__.py +++ b/graphql_api/types/__init__.py @@ -1,8 +1,8 @@ -from ariadne import load_schema_from_path from ariadne.validation import cost_directive from ariadne_django.scalars import datetime_scalar from ..helpers.ariadne import ariadne_load_local_graphql +from .account import account, account_bindable from .branch import branch, branch_bindable from .bundle_analysis import ( bundle_analysis, @@ -38,6 
+38,7 @@ from .me import me, me_bindable, tracking_metadata_bindable from .measurement import measurement, measurement_bindable from .mutation import mutation, mutation_resolvers +from .okta_config import okta_config, okta_config_bindable from .owner import owner, owner_bindable from .path_contents import ( path_content, @@ -63,6 +64,7 @@ ) from .self_hosted_license import self_hosted_license, self_hosted_license_bindable from .session import session, session_bindable +from .test_results import test_result_bindable, test_results from .upload import upload, upload_bindable, upload_error_bindable from .user import user, user_bindable from .user_token import user_token, user_token_bindable @@ -109,6 +111,9 @@ upload, user_token, user, + account, + okta_config, + test_results, ] bindables = [ @@ -163,4 +168,7 @@ upload_error_bindable, user_bindable, user_token_bindable, + account_bindable, + okta_config_bindable, + test_result_bindable, ] diff --git a/graphql_api/types/account/__init__.py b/graphql_api/types/account/__init__.py new file mode 100644 index 0000000000..8e5898e6c5 --- /dev/null +++ b/graphql_api/types/account/__init__.py @@ -0,0 +1,10 @@ +from shared.license import get_current_license + +from graphql_api.helpers.ariadne import ariadne_load_local_graphql + +from .account import account_bindable + +account = ariadne_load_local_graphql(__file__, "account.graphql") + + +__all__ = ["get_current_license", "account_bindable"] diff --git a/graphql_api/types/account/account.graphql b/graphql_api/types/account/account.graphql new file mode 100644 index 0000000000..4866830b97 --- /dev/null +++ b/graphql_api/types/account/account.graphql @@ -0,0 +1,4 @@ +type Account { + name: String! 
+ oktaConfig: OktaConfig +} diff --git a/graphql_api/types/account/account.py b/graphql_api/types/account/account.py new file mode 100644 index 0000000000..6009125967 --- /dev/null +++ b/graphql_api/types/account/account.py @@ -0,0 +1,17 @@ +from ariadne import ObjectType + +from codecov.db import sync_to_async +from codecov_auth.models import Account, OktaSettings + +account_bindable = ObjectType("Account") + + +@account_bindable.field("name") +def resolve_name(account: Account, info) -> str: + return account.name + + +@account_bindable.field("oktaConfig") +@sync_to_async +def resolve_okta_config(account: Account, info) -> OktaSettings: + return OktaSettings.objects.filter(account_id=account.pk).first() diff --git a/graphql_api/types/branch/__init__.py b/graphql_api/types/branch/__init__.py index e306f5b3c2..27a3954e89 100644 --- a/graphql_api/types/branch/__init__.py +++ b/graphql_api/types/branch/__init__.py @@ -3,3 +3,6 @@ from .branch import branch_bindable branch = ariadne_load_local_graphql(__file__, "branch.graphql") + + +__all__ = ["branch_bindable"] diff --git a/graphql_api/types/branch/branch.py b/graphql_api/types/branch/branch.py index c2b209f671..5d76428d36 100644 --- a/graphql_api/types/branch/branch.py +++ b/graphql_api/types/branch/branch.py @@ -2,7 +2,6 @@ from ariadne import ObjectType -from codecov.db import sync_to_async from core.models import Branch, Commit from graphql_api.dataloader.commit import CommitLoader diff --git a/graphql_api/types/bundle_analysis/__init__.py b/graphql_api/types/bundle_analysis/__init__.py index 10b6fdaba3..4b7ae42be4 100644 --- a/graphql_api/types/bundle_analysis/__init__.py +++ b/graphql_api/types/bundle_analysis/__init__.py @@ -19,3 +19,16 @@ bundle_analysis = ariadne_load_local_graphql(__file__, "base.graphql") bundle_analysis_comparison = ariadne_load_local_graphql(__file__, "comparison.graphql") bundle_analysis_report = ariadne_load_local_graphql(__file__, "report.graphql") + + +__all__ = [ + 
"bundle_asset_bindable", + "bundle_data_bindable", + "bundle_module_bindable", + "bundle_report_bindable", + "bundle_analysis_comparison_bindable", + "bundle_analysis_comparison_result_bindable", + "bundle_comparison_bindable", + "bundle_analysis_report_bindable", + "bundle_analysis_report_result_bindable", +] diff --git a/graphql_api/types/bundle_analysis/base.graphql b/graphql_api/types/bundle_analysis/base.graphql index b746012da1..9af48f5a68 100644 --- a/graphql_api/types/bundle_analysis/base.graphql +++ b/graphql_api/types/bundle_analysis/base.graphql @@ -7,6 +7,20 @@ enum BundleAnalysisMeasurementsAssetType { ASSET_SIZE } +enum BundleReportGroups { + JAVASCRIPT + STYLESHEET + FONT + IMAGE + UNKNOWN +} + +enum BundleLoadTypes { + ENTRY + INITIAL + LAZY +} + type BundleSize { gzip: Int! uncompress: Int! @@ -32,34 +46,39 @@ type BundleAsset { name: String! extension: String! normalizedName: String! - moduleExtensions: [String!]! modules: [BundleModule]! bundleData: BundleData! measurements( interval: MeasurementInterval! before: DateTime! - after: DateTime! + after: DateTime branch: String ): BundleAnalysisMeasurements } type BundleReport { name: String! - sizeTotal: Int! - loadTimeTotal: Float! - moduleExtensions: [String!]! moduleCount: Int! - assets(filters: BundleAnalysisReportFilters): [BundleAsset]! + assets: [BundleAsset]! asset(name: String!): BundleAsset bundleData: BundleData! measurements( interval: MeasurementInterval! before: DateTime! - after: DateTime! + after: DateTime branch: String orderingDirection: OrderingDirection filters: BundleAnalysisMeasurementsSetFilters ): [BundleAnalysisMeasurements!] + isCached: Boolean! + assetsPaginated( + ordering: AssetOrdering + orderingDirection: OrderingDirection + first: Int + after: String + last: Int + before: String + ): AssetConnection } type BundleAnalysisMeasurements{ @@ -68,4 +87,15 @@ type BundleAnalysisMeasurements{ size: BundleData change: BundleData measurements: [Measurement!] 
+} + +type AssetConnection { + edges: [AssetEdge]! + totalCount: Int! + pageInfo: PageInfo! +} + +type AssetEdge { + cursor: String! + node: BundleAsset! } \ No newline at end of file diff --git a/graphql_api/types/bundle_analysis/base.py b/graphql_api/types/bundle_analysis/base.py index 0f06c4d062..2e67d5ac4e 100644 --- a/graphql_api/types/bundle_analysis/base.py +++ b/graphql_api/types/bundle_analysis/base.py @@ -1,11 +1,12 @@ from datetime import datetime -from typing import List, Mapping, Optional +from typing import Dict, List, Mapping, Optional, Union from ariadne import ObjectType, convert_kwargs_to_snake_case from graphql import GraphQLResolveInfo +from codecov.commands.exceptions import ValidationError from codecov.db import sync_to_async -from graphql_api.types.enums import OrderingDirection +from graphql_api.types.enums import AssetOrdering, OrderingDirection from services.bundle_analysis import ( AssetReport, BundleAnalysisMeasurementData, @@ -25,6 +26,16 @@ bundle_report_bindable = ObjectType("BundleReport") +def _find_index_by_cursor(assets: List, cursor: str) -> int: + try: + for i, asset in enumerate(assets): + if asset.id == int(cursor): + return i + except ValueError: + pass + return -1 + + # ============= Bundle Data Bindable ============= @@ -83,7 +94,7 @@ def resolve_extension(bundle_asset: AssetReport, info: GraphQLResolveInfo) -> st def resolve_bundle_asset_bundle_data( bundle_asset: AssetReport, info: GraphQLResolveInfo ) -> BundleData: - return BundleData(bundle_asset.size_total) + return BundleData(bundle_asset.size_total, bundle_asset.gzip_size_total) @bundle_asset_bindable.field("modules") @@ -93,13 +104,6 @@ def resolve_modules( return bundle_asset.modules -@bundle_asset_bindable.field("moduleExtensions") -def resolve_bundle_asset_module_extensions( - bundle_asset: AssetReport, info: GraphQLResolveInfo -) -> List[str]: - return bundle_asset.module_extensions - - @bundle_asset_bindable.field("measurements") @convert_kwargs_to_snake_case 
@sync_to_async @@ -108,14 +112,14 @@ def resolve_asset_report_measurements( info: GraphQLResolveInfo, interval: Interval, before: datetime, - after: datetime, + after: Optional[datetime] = None, branch: Optional[str] = None, ) -> Optional[BundleAnalysisMeasurementData]: bundle_analysis_measurements = BundleAnalysisMeasurementsService( repository=info.context["commit"].repository, interval=interval, - after=after, before=before, + after=after, branch=branch, ) return bundle_analysis_measurements.compute_asset(bundle_asset) @@ -129,27 +133,6 @@ def resolve_name(bundle_report: BundleReport, info: GraphQLResolveInfo) -> str: return bundle_report.name -# TODO: depreacted with Issue 1199 -@bundle_report_bindable.field("sizeTotal") -def resolve_size_total(bundle_report: BundleReport, info: GraphQLResolveInfo) -> int: - return bundle_report.size_total - - -# TODO: depreacted with Issue 1199 -@bundle_report_bindable.field("loadTimeTotal") -def resolve_load_time_total( - bundle_report: BundleReport, info: GraphQLResolveInfo -) -> float: - return bundle_report.load_time_total - - -@bundle_report_bindable.field("moduleExtensions") -def resolve_module_extensions( - bundle_report: BundleReport, info: GraphQLResolveInfo -) -> List[str]: - return bundle_report.module_extensions - - @bundle_report_bindable.field("moduleCount") def resolve_module_count(bundle_report: BundleReport, info: GraphQLResolveInfo) -> int: return bundle_report.module_count @@ -159,10 +142,72 @@ def resolve_module_count(bundle_report: BundleReport, info: GraphQLResolveInfo) def resolve_assets( bundle_report: BundleReport, info: GraphQLResolveInfo, - filters: Optional[Mapping] = None, ) -> List[AssetReport]: - extensions_filter = filters.get("moduleExtensions", None) if filters else None - return list(bundle_report.assets(extensions_filter)) + return list(bundle_report.assets()) + + +@bundle_report_bindable.field("assetsPaginated") +def resolve_assets_paginated( + bundle_report: BundleReport, + info: 
GraphQLResolveInfo, + ordering: AssetOrdering = AssetOrdering.SIZE, + ordering_direction: OrderingDirection = OrderingDirection.DESC, + first: Optional[int] = None, + after: Optional[str] = None, + last: Optional[int] = None, + before: Optional[str] = None, +) -> Union[Dict[str, object], ValidationError]: + if first is not None and last is not None: + return ValidationError("First and last can not be used at the same time") + if after is not None and before is not None: + return ValidationError("After and before can not be used at the same time") + + # All filtered assets before pagination + assets = list( + bundle_report.assets( + ordering=ordering.value, + ordering_desc=ordering_direction.value == OrderingDirection.DESC.value, + ) + ) + + total_count, has_next_page, has_previous_page = len(assets), False, False + start_cursor, end_cursor = None, None + + # Apply cursors to edges + if after is not None: + after_edge = _find_index_by_cursor(assets, after) + if after_edge > -1: + assets = assets[after_edge + 1 :] + + if before is not None: + before_edge = _find_index_by_cursor(assets, before) + if before_edge > -1: + assets = assets[:before_edge] + + # Slice edges by return size + if first is not None and first >= 0: + if len(assets) > first: + assets = assets[:first] + has_next_page = True + + if last is not None and last >= 0: + if len(assets) > last: + assets = assets[len(assets) - last :] + has_previous_page = True + + if assets: + start_cursor, end_cursor = assets[0].id, assets[-1].id + + return { + "edges": [{"cursor": asset.id, "node": asset} for asset in assets], + "total_count": total_count, + "page_info": { + "has_next_page": has_next_page, + "has_previous_page": has_previous_page, + "start_cursor": start_cursor, + "end_cursor": end_cursor, + }, + } @bundle_report_bindable.field("asset") @@ -176,7 +221,10 @@ def resolve_asset( def resolve_bundle_data( bundle_report: BundleReport, info: GraphQLResolveInfo ) -> BundleData: - return 
BundleData(bundle_report.size_total) + return BundleData( + bundle_report.size_total, + bundle_report.gzip_size_total, + ) @bundle_report_bindable.field("measurements") @@ -187,7 +235,7 @@ def resolve_bundle_report_measurements( info: GraphQLResolveInfo, interval: Interval, before: datetime, - after: datetime, + after: Optional[datetime] = None, branch: Optional[str] = None, filters: Mapping = {}, ordering_direction: Optional[OrderingDirection] = OrderingDirection.ASC, @@ -202,8 +250,8 @@ def resolve_bundle_report_measurements( bundle_analysis_measurements = BundleAnalysisMeasurementsService( repository=info.context["commit"].repository, interval=interval, - after=after, before=before, + after=after, branch=branch, ) @@ -218,3 +266,10 @@ def resolve_bundle_report_measurements( key=lambda c: c.asset_type, reverse=ordering_direction == OrderingDirection.DESC, ) + + +@bundle_report_bindable.field("isCached") +def resolve_bundle_report_is_cached( + bundle_report: BundleReport, info: GraphQLResolveInfo +) -> bool: + return bundle_report.is_cached diff --git a/graphql_api/types/bundle_analysis/comparison.graphql b/graphql_api/types/bundle_analysis/comparison.graphql index fc07cd6423..c35f5f0a1a 100644 --- a/graphql_api/types/bundle_analysis/comparison.graphql +++ b/graphql_api/types/bundle_analysis/comparison.graphql @@ -8,10 +8,6 @@ union BundleAnalysisComparisonResult = type BundleAnalysisComparison { bundles: [BundleComparison]! - sizeDelta: Int! - sizeTotal: Int! - loadTimeDelta: Float! - loadTimeTotal: Float! bundleData: BundleData! bundleChange: BundleData! } @@ -19,10 +15,6 @@ type BundleAnalysisComparison { type BundleComparison { name: String! changeType: String! - sizeDelta: Int! - sizeTotal: Int! - loadTimeDelta: Float! - loadTimeTotal: Float! bundleData: BundleData! bundleChange: BundleData! 
} \ No newline at end of file diff --git a/graphql_api/types/bundle_analysis/comparison.py b/graphql_api/types/bundle_analysis/comparison.py index a736be3da4..cfd1987ea5 100644 --- a/graphql_api/types/bundle_analysis/comparison.py +++ b/graphql_api/types/bundle_analysis/comparison.py @@ -34,34 +34,6 @@ def resolve_bundle_analysis_comparison_result_type(obj, *_): return "MissingBaseReport" -@bundle_analysis_comparison_bindable.field("sizeDelta") -def resolve_ba_comparison_size_delta( - bundles_analysis_comparison: BundleAnalysisComparison, info -): - return bundles_analysis_comparison.size_delta - - -@bundle_analysis_comparison_bindable.field("sizeTotal") -def resolve_ba_comparison_size_total( - bundles_analysis_comparison: BundleAnalysisComparison, info -): - return bundles_analysis_comparison.size_total - - -@bundle_analysis_comparison_bindable.field("loadTimeDelta") -def resolve_ba_comparison_load_time_delta( - bundles_analysis_comparison: BundleAnalysisComparison, info -): - return bundles_analysis_comparison.load_time_delta - - -@bundle_analysis_comparison_bindable.field("loadTimeTotal") -def resolve_ba_comparison_load_time_total( - bundles_analysis_comparison: BundleAnalysisComparison, info -): - return bundles_analysis_comparison.load_time_total - - @bundle_analysis_comparison_bindable.field("bundles") def resolve_ba_comparison_bundles( bundles_analysis_comparison: BundleAnalysisComparison, info @@ -93,26 +65,6 @@ def resolve_change_type(bundle_comparison: BundleComparison, info): return bundle_comparison.change_type -@bundle_comparison_bindable.field("sizeDelta") -def resolve_size_delta(bundle_comparison: BundleComparison, info): - return bundle_comparison.size_delta - - -@bundle_comparison_bindable.field("sizeTotal") -def resolve_size_total(bundle_comparison: BundleComparison, info): - return bundle_comparison.size_total - - -@bundle_comparison_bindable.field("loadTimeDelta") -def resolve_load_time_delta(bundle_comparison: BundleComparison, info): - return 
bundle_comparison.load_time_delta - - -@bundle_comparison_bindable.field("loadTimeTotal") -def resolve_load_time_total(bundle_comparison: BundleComparison, info): - return bundle_comparison.load_time_total - - @bundle_comparison_bindable.field("bundleData") def resolve_bundle_data(bundle_comparison: BundleComparison, info) -> BundleData: return BundleData(bundle_comparison.size_total) diff --git a/graphql_api/types/bundle_analysis/report.graphql b/graphql_api/types/bundle_analysis/report.graphql index 383ee4f279..4dcfbc8ee7 100644 --- a/graphql_api/types/bundle_analysis/report.graphql +++ b/graphql_api/types/bundle_analysis/report.graphql @@ -3,9 +3,8 @@ union BundleAnalysisReportResult = | MissingHeadReport type BundleAnalysisReport { - sizeTotal: Int! - loadTimeTotal: Float! bundles: [BundleReport]! bundleData: BundleData! - bundle(name: String!): BundleReport + bundle(name: String!, filters: BundleAnalysisReportFilters): BundleReport + isCached: Boolean! } diff --git a/graphql_api/types/bundle_analysis/report.py b/graphql_api/types/bundle_analysis/report.py index 8dc2bcf7f7..a57f5cf4cf 100644 --- a/graphql_api/types/bundle_analysis/report.py +++ b/graphql_api/types/bundle_analysis/report.py @@ -1,8 +1,10 @@ -from typing import List +from typing import Any, List, Optional, Union from ariadne import ObjectType, UnionType +from graphql import GraphQLResolveInfo from graphql_api.types.comparison.comparison import MissingHeadReport +from graphql_api.types.enums import BundleLoadTypes from services.bundle_analysis import BundleAnalysisReport, BundleData, BundleReport bundle_analysis_report_result_bindable = UnionType("BundleAnalysisReportResult") @@ -10,43 +12,74 @@ @bundle_analysis_report_result_bindable.type_resolver -def resolve_bundle_analysis_report_result_type(obj, *_): +def resolve_bundle_analysis_report_result_type( + obj: Union[BundleAnalysisReport, MissingHeadReport], *_: Any +) -> str: if isinstance(obj, BundleAnalysisReport): return "BundleAnalysisReport" 
elif isinstance(obj, MissingHeadReport): return "MissingHeadReport" -# TODO: depreacted with Issue 1199 -@bundle_analysis_report_bindable.field("sizeTotal") -def resolve_size_total(bundles_analysis_report: BundleAnalysisReport, info) -> int: - return bundles_analysis_report.size_total - - -# TODO: depreacted with Issue 1199 -@bundle_analysis_report_bindable.field("loadTimeTotal") -def resolve_load_time_total( - bundles_analysis_report: BundleAnalysisReport, info -) -> float: - return bundles_analysis_report.load_time_total - - @bundle_analysis_report_bindable.field("bundles") def resolve_bundles( - bundles_analysis_report: BundleAnalysisReport, info + bundles_analysis_report: BundleAnalysisReport, info: GraphQLResolveInfo ) -> List[BundleReport]: return bundles_analysis_report.bundles @bundle_analysis_report_bindable.field("bundle") def resolve_bundle( - bundles_analysis_report: BundleAnalysisReport, info, name: str -) -> BundleReport: - return bundles_analysis_report.bundle(name) + bundles_analysis_report: BundleAnalysisReport, + info: GraphQLResolveInfo, + name: str, + filters: dict[str, list[str]] = {}, +) -> Optional[BundleReport]: + asset_types = None + if filters.get("report_groups"): + asset_types = filters.get("report_groups") + + chunk_entry, chunk_initial = None, None + if filters.get("load_types"): + load_types = filters.get("load_types") + + # Compute chunk entry boolean + if BundleLoadTypes.ENTRY in load_types and ( + BundleLoadTypes.INITIAL in load_types or BundleLoadTypes.LAZY in load_types + ): + chunk_entry = None + elif BundleLoadTypes.ENTRY in load_types: + chunk_entry = True + elif ( + BundleLoadTypes.INITIAL in load_types or BundleLoadTypes.LAZY in load_types + ): + chunk_entry = False + + # Compute chunk initial boolean + if BundleLoadTypes.INITIAL in load_types and BundleLoadTypes.LAZY in load_types: + chunk_initial = None + elif BundleLoadTypes.INITIAL in load_types: + chunk_initial = True + elif BundleLoadTypes.LAZY in load_types: + 
chunk_initial = False + + return bundles_analysis_report.bundle( + name, + { + "asset_types": asset_types, + "chunk_entry": chunk_entry, + "chunk_initial": chunk_initial, + }, + ) @bundle_analysis_report_bindable.field("bundleData") def resolve_bundle_data( - bundles_analysis_report: BundleAnalysisReport, info + bundles_analysis_report: BundleAnalysisReport, info: GraphQLResolveInfo ) -> BundleData: return BundleData(bundles_analysis_report.size_total) + + +@bundle_analysis_report_bindable.field("isCached") +def resolve_is_cached(bundle_report: BundleReport, info: GraphQLResolveInfo) -> bool: + return bundle_report.is_cached diff --git a/graphql_api/types/commit/__init__.py b/graphql_api/types/commit/__init__.py index 3a2f45db54..60126e4eb9 100644 --- a/graphql_api/types/commit/__init__.py +++ b/graphql_api/types/commit/__init__.py @@ -3,3 +3,6 @@ from .commit import commit_bindable commit = ariadne_load_local_graphql(__file__, "commit.graphql") + + +__all__ = ["commit_bindable"] diff --git a/graphql_api/types/commit/commit.py b/graphql_api/types/commit/commit.py index 096a38c662..a6683a995d 100644 --- a/graphql_api/types/commit/commit.py +++ b/graphql_api/types/commit/commit.py @@ -1,3 +1,4 @@ +import logging from typing import List, Optional import sentry_sdk @@ -27,7 +28,6 @@ ) from graphql_api.types.comparison.comparison import ( MissingBaseCommit, - MissingBaseReport, MissingHeadReport, ) from graphql_api.types.enums import ( @@ -44,7 +44,10 @@ from services.path import ReportPaths from services.profiling import CriticalFile, ProfilingSummary from services.report import ReadOnlyReport -from services.yaml import YamlStates, get_yaml_state +from services.yaml import ( + YamlStates, + get_yaml_state, +) commit_bindable = ObjectType("Commit") @@ -52,6 +55,8 @@ commit_bindable.set_alias("pullId", "pullid") commit_bindable.set_alias("branchName", "branch") +log = logging.getLogger(__name__) + @commit_bindable.field("coverageFile") @sync_to_async @@ -306,9 +311,9 @@ 
def resolve_path_contents(commit: Commit, info, path: str = None, filters=None): @commit_bindable.field("errors") -async def resolve_errors(commit, info, errorType): +async def resolve_errors(commit, info, error_type): command = info.context["executor"].get_command("commit") - queryset = await command.get_commit_errors(commit, error_type=errorType) + queryset = await command.get_commit_errors(commit, error_type=error_type) return await queryset_to_connection( queryset, ordering=("updated_at",), @@ -325,9 +330,9 @@ async def resolve_total_uploads(commit, info): @commit_bindable.field("components") @sync_to_async def resolve_components(commit: Commit, info, filters=None) -> List[Component]: - request = info.context["request"] info.context["component_commit"] = commit - all_components = components_service.commit_components(commit, request.user) + current_owner = info.context["request"].current_owner + all_components = components_service.commit_components(commit, current_owner) if filters and filters.get("components"): return components_service.filter_components_by_name( diff --git a/graphql_api/types/comparison/__init__.py b/graphql_api/types/comparison/__init__.py index a739e59196..6344af7cbd 100644 --- a/graphql_api/types/comparison/__init__.py +++ b/graphql_api/types/comparison/__init__.py @@ -3,3 +3,6 @@ from .comparison import comparison_bindable, comparison_result_bindable comparison = ariadne_load_local_graphql(__file__, "comparison.graphql") + + +__all__ = ["comparison_bindable", "comparison_result_bindable"] diff --git a/graphql_api/types/comparison/comparison.graphql b/graphql_api/types/comparison/comparison.graphql index 70dc3ec93f..8ae51b0c22 100644 --- a/graphql_api/types/comparison/comparison.graphql +++ b/graphql_api/types/comparison/comparison.graphql @@ -2,7 +2,6 @@ type Comparison { state: String! impactedFile(path: String!): ImpactedFile impactedFiles(filters: ImpactedFilesFilters): ImpactedFilesResult! 
- impactedFilesDeprecated(filters: ImpactedFilesFilters): [ImpactedFile]! impactedFilesCount: Int! indirectChangedFilesCount: Int! patchTotals: CoverageTotals diff --git a/graphql_api/types/comparison/comparison.py b/graphql_api/types/comparison/comparison.py index 66c85c55e6..5aeb006a05 100644 --- a/graphql_api/types/comparison/comparison.py +++ b/graphql_api/types/comparison/comparison.py @@ -1,6 +1,7 @@ from asyncio import gather from typing import List, Optional +import sentry_sdk from ariadne import ObjectType, UnionType, convert_kwargs_to_snake_case from graphql.type.definition import GraphQLResolveInfo @@ -24,7 +25,6 @@ ComparisonReport, FirstPullRequest, ImpactedFile, - MissingComparisonReport, ) comparison_bindable = ObjectType("Comparison") @@ -35,18 +35,6 @@ def resolve_state(comparison: ComparisonReport, info: GraphQLResolveInfo) -> str return comparison.commit_comparison.state -@comparison_bindable.field("impactedFilesDeprecated") -@convert_kwargs_to_snake_case -@sync_to_async -def resolve_impacted_files_deprecated( - comparison_report: ComparisonReport, info: GraphQLResolveInfo, filters=None -) -> List[ImpactedFile]: - command: CompareCommands = info.context["executor"].get_command("compare") - comparison: Comparison = info.context.get("comparison", None) - - return command.fetch_impacted_files(comparison_report, comparison, filters) - - @comparison_bindable.field("impactedFiles") @convert_kwargs_to_snake_case @sync_to_async @@ -250,20 +238,21 @@ def resolve_flag_comparisons_count( return get_flag_comparisons(comparison.commit_comparison).count() +@sentry_sdk.trace @comparison_bindable.field("hasDifferentNumberOfHeadAndBaseReports") @sync_to_async def resolve_has_different_number_of_head_and_base_reports( comparison: ComparisonReport, info: GraphQLResolveInfo, **kwargs, # type: ignore -) -> False: +) -> bool: # TODO: can we remove the need for `info.context["comparison"]` here? 
if "comparison" not in info.context: return False comparison: Comparison = info.context["comparison"] try: comparison.validate() - except MissingComparisonReport: + except Exception: return False return comparison.has_different_number_of_head_and_base_sessions diff --git a/graphql_api/types/component/__init__.py b/graphql_api/types/component/__init__.py index d306ee4fe0..bd81cc7112 100644 --- a/graphql_api/types/component/__init__.py +++ b/graphql_api/types/component/__init__.py @@ -3,3 +3,5 @@ from .component import component_bindable component = ariadne_load_local_graphql(__file__, "component.graphql") + +__all__ = ["component_bindable"] diff --git a/graphql_api/types/component_comparison/__init__.py b/graphql_api/types/component_comparison/__init__.py index 421748c1d2..43598a549f 100644 --- a/graphql_api/types/component_comparison/__init__.py +++ b/graphql_api/types/component_comparison/__init__.py @@ -5,3 +5,6 @@ component_comparison = ariadne_load_local_graphql( __file__, "component_comparison.graphql" ) + + +__all__ = ["component_comparison_bindable"] diff --git a/graphql_api/types/component_comparison/component_comparison.py b/graphql_api/types/component_comparison/component_comparison.py index 95e58adabd..f9dcef5497 100644 --- a/graphql_api/types/component_comparison/component_comparison.py +++ b/graphql_api/types/component_comparison/component_comparison.py @@ -1,5 +1,3 @@ -from typing import List - from ariadne import ObjectType from shared.reports.types import ReportTotals diff --git a/graphql_api/types/config/__init__.py b/graphql_api/types/config/__init__.py index d928f9d4b4..1bfee13154 100644 --- a/graphql_api/types/config/__init__.py +++ b/graphql_api/types/config/__init__.py @@ -5,3 +5,6 @@ from .config import config_bindable config = ariadne_load_local_graphql(__file__, "config.graphql") + + +__all__ = ["get_current_license", "config_bindable"] diff --git a/graphql_api/types/config/config.py b/graphql_api/types/config/config.py index 
fcf892d058..87c57ef627 100644 --- a/graphql_api/types/config/config.py +++ b/graphql_api/types/config/config.py @@ -109,7 +109,7 @@ def resolve_self_hosted_license(_, info): license = self_hosted.get_current_license() if not license.is_valid: - None + return None return license diff --git a/graphql_api/types/coverage_totals/__init__.py b/graphql_api/types/coverage_totals/__init__.py index e5a9a4f299..4b7d597f06 100644 --- a/graphql_api/types/coverage_totals/__init__.py +++ b/graphql_api/types/coverage_totals/__init__.py @@ -3,3 +3,6 @@ from .coverage_totals import coverage_totals_bindable coverage_totals = ariadne_load_local_graphql(__file__, "coverage_totals.graphql") + + +__all__ = ["coverage_totals_bindable"] diff --git a/graphql_api/types/enums/__init__.py b/graphql_api/types/enums/__init__.py index f92f08d53f..e43ca662c2 100644 --- a/graphql_api/types/enums/__init__.py +++ b/graphql_api/types/enums/__init__.py @@ -1,4 +1,6 @@ from .enums import ( + AssetOrdering, + BundleLoadTypes, CommitErrorCode, CommitErrorGeneralType, CommitStatus, @@ -11,8 +13,31 @@ PullRequestState, RepositoryOrdering, SyncProvider, + TestResultsOrderingParameter, TypeProjectOnboarding, UploadErrorEnum, UploadState, UploadType, ) + +__all__ = [ + "AssetOrdering", + "BundleLoadTypes", + "CommitErrorCode", + "CommitErrorGeneralType", + "CommitStatus", + "CoverageLine", + "GoalOnboarding", + "LoginProvider", + "OrderingDirection", + "OrderingParameter", + "PathContentDisplayType", + "PullRequestState", + "RepositoryOrdering", + "SyncProvider", + "TestResultsOrderingParameter", + "TypeProjectOnboarding", + "UploadErrorEnum", + "UploadState", + "UploadType", +] diff --git a/graphql_api/types/enums/asset_ordering.graphql b/graphql_api/types/enums/asset_ordering.graphql new file mode 100644 index 0000000000..fb11939dbf --- /dev/null +++ b/graphql_api/types/enums/asset_ordering.graphql @@ -0,0 +1,5 @@ +enum AssetOrdering { + NAME + SIZE + TYPE +} \ No newline at end of file diff --git 
a/graphql_api/types/enums/enum_types.py b/graphql_api/types/enums/enum_types.py index 5565109474..5dd6411ba4 100644 --- a/graphql_api/types/enums/enum_types.py +++ b/graphql_api/types/enums/enum_types.py @@ -11,6 +11,8 @@ from timeseries.models import MeasurementName from .enums import ( + AssetOrdering, + BundleLoadTypes, CoverageLine, GoalOnboarding, LoginProvider, @@ -20,6 +22,7 @@ PullRequestState, RepositoryOrdering, SyncProvider, + TestResultsOrderingParameter, TypeProjectOnboarding, UploadErrorEnum, UploadState, @@ -48,4 +51,7 @@ EnumType("TierName", TierName), EnumType("TrialStatus", TrialStatus), EnumType("YamlStates", YamlStates), + EnumType("BundleLoadTypes", BundleLoadTypes), + EnumType("TestResultsOrderingParameter", TestResultsOrderingParameter), + EnumType("AssetOrdering", AssetOrdering), ] diff --git a/graphql_api/types/enums/enums.py b/graphql_api/types/enums/enums.py index 5e9182aa41..1fed70b76f 100644 --- a/graphql_api/types/enums/enums.py +++ b/graphql_api/types/enums/enums.py @@ -1,4 +1,5 @@ import enum +from typing import Self class OrderingParameter(enum.Enum): @@ -10,6 +11,13 @@ class OrderingParameter(enum.Enum): LINES = "lines" +class TestResultsOrderingParameter(enum.Enum): + AVG_DURATION = "avg_duration" + FAILURE_RATE = "failure_rate" + COMMITS_WHERE_FAIL = "commits_where_fail" + UPDATED_AT = "updated_at" + + class PathContentDisplayType(enum.Enum): TREE = "tree" LIST = "list" @@ -103,12 +111,12 @@ class CommitErrorCode(enum.Enum): yaml_unknown_error = ("yaml_unknown_error", CommitErrorGeneralType.yaml_error) repo_bot_invalid = ("repo_bot_invalid", CommitErrorGeneralType.bot_error) - def __init__(self, db_string, error_type): + def __init__(self, db_string: str, error_type: CommitErrorGeneralType): self.db_string = db_string self.error_type = error_type @classmethod - def get_codes_from_type(cls, error_type): + def get_codes_from_type(cls, error_type: CommitErrorGeneralType) -> list[Self]: return [item for item in cls if item.error_type 
== error_type] @@ -116,3 +124,15 @@ class CommitStatus(enum.Enum): COMPLETED = "COMPLETED" ERROR = "ERROR" PENDING = "PENDING" + + +class BundleLoadTypes(enum.Enum): + ENTRY = "ENTRY" + INITIAL = "INITIAL" + LAZY = "LAZY" + + +class AssetOrdering(enum.Enum): + NAME = "name" + SIZE = "size" + TYPE = "asset_type" diff --git a/graphql_api/types/enums/test_results_ordering_parameter.graphql b/graphql_api/types/enums/test_results_ordering_parameter.graphql new file mode 100644 index 0000000000..bf5ae033f7 --- /dev/null +++ b/graphql_api/types/enums/test_results_ordering_parameter.graphql @@ -0,0 +1,6 @@ +enum TestResultsOrderingParameter { + AVG_DURATION + FAILURE_RATE + COMMITS_WHERE_FAIL + UPDATED_AT +} diff --git a/graphql_api/types/errors/__init__.py b/graphql_api/types/errors/__init__.py index d692768c5f..2351d803c6 100644 --- a/graphql_api/types/errors/__init__.py +++ b/graphql_api/types/errors/__init__.py @@ -8,3 +8,14 @@ ProviderError, UnknownPath, ) + +__all__ = [ + "MissingBaseCommit", + "MissingBaseReport", + "MissingComparison", + "MissingCoverage", + "MissingHeadCommit", + "MissingHeadReport", + "ProviderError", + "UnknownPath", +] diff --git a/graphql_api/types/file/__init__.py b/graphql_api/types/file/__init__.py index e546088f19..09e7e120d9 100644 --- a/graphql_api/types/file/__init__.py +++ b/graphql_api/types/file/__init__.py @@ -3,3 +3,6 @@ from .file import file_bindable commit_file = ariadne_load_local_graphql(__file__, "file.graphql") + + +__all__ = ["file_bindable"] diff --git a/graphql_api/types/file/file.py b/graphql_api/types/file/file.py index f92fa0e706..21a51a14f3 100644 --- a/graphql_api/types/file/file.py +++ b/graphql_api/types/file/file.py @@ -1,6 +1,4 @@ import hashlib -import math -from fractions import Fraction from ariadne import ObjectType from shared.utils.merge import LineType, line_type diff --git a/graphql_api/types/flag/__init__.py b/graphql_api/types/flag/__init__.py index 46ae0e95d2..97313d98c6 100644 --- 
a/graphql_api/types/flag/__init__.py +++ b/graphql_api/types/flag/__init__.py @@ -5,3 +5,6 @@ flag = ariadne_load_local_graphql(__file__, "flag.graphql") flag += build_connection_graphql("FlagConnection", "Flag") + + +__all__ = ["flag_bindable"] diff --git a/graphql_api/types/flag_comparison/__init__.py b/graphql_api/types/flag_comparison/__init__.py index ba9ac3fc53..2589d8e9bb 100644 --- a/graphql_api/types/flag_comparison/__init__.py +++ b/graphql_api/types/flag_comparison/__init__.py @@ -4,3 +4,6 @@ from .flag_comparison import flag_comparison_bindable flag_comparison = ariadne_load_local_graphql(__file__, "flag_comparison.graphql") + + +__all__ = ["flag_comparison_bindable", "build_connection_graphql"] diff --git a/graphql_api/types/impacted_file/__init__.py b/graphql_api/types/impacted_file/__init__.py index 8272b18a3a..2412e2ff20 100644 --- a/graphql_api/types/impacted_file/__init__.py +++ b/graphql_api/types/impacted_file/__init__.py @@ -3,3 +3,6 @@ from .impacted_file import impacted_file_bindable, impacted_files_result_bindable impacted_file = ariadne_load_local_graphql(__file__, "impacted_file.graphql") + + +__all__ = ["impacted_file_bindable", "impacted_files_result_bindable"] diff --git a/graphql_api/types/impacted_file/impacted_file.py b/graphql_api/types/impacted_file/impacted_file.py index 38dc173565..3dc5e87476 100644 --- a/graphql_api/types/impacted_file/impacted_file.py +++ b/graphql_api/types/impacted_file/impacted_file.py @@ -1,6 +1,7 @@ import hashlib from typing import List, Union +import sentry_sdk from ariadne import ObjectType, UnionType, convert_kwargs_to_snake_case from shared.reports.types import ReportTotals from shared.torngit.exceptions import TorngitClientError @@ -63,6 +64,7 @@ def resolve_hashed_path(impacted_file: ImpactedFile, info) -> str: return md5_path.hexdigest() +@sentry_sdk.trace @impacted_file_bindable.field("segments") @sync_to_async @convert_kwargs_to_snake_case diff --git 
a/graphql_api/types/inputs/bundle_analysis_filters.graphql b/graphql_api/types/inputs/bundle_analysis_filters.graphql index 4876c60e38..08e2cec746 100644 --- a/graphql_api/types/inputs/bundle_analysis_filters.graphql +++ b/graphql_api/types/inputs/bundle_analysis_filters.graphql @@ -1,5 +1,6 @@ input BundleAnalysisReportFilters { - moduleExtensions: [String!] + reportGroups: [BundleReportGroups!] + loadTypes: [BundleLoadTypes!] } input BundleAnalysisMeasurementsSetFilters { diff --git a/graphql_api/types/inputs/encode_secret_string.graphql b/graphql_api/types/inputs/encode_secret_string.graphql new file mode 100644 index 0000000000..2ea2e0910a --- /dev/null +++ b/graphql_api/types/inputs/encode_secret_string.graphql @@ -0,0 +1,4 @@ +input EncodeSecretStringInput { + repoName: String! + value: String! +} \ No newline at end of file diff --git a/graphql_api/types/inputs/test_results_filters.graphql b/graphql_api/types/inputs/test_results_filters.graphql new file mode 100644 index 0000000000..b61079a57d --- /dev/null +++ b/graphql_api/types/inputs/test_results_filters.graphql @@ -0,0 +1,8 @@ +input TestResultsFilters { + branch: String +} + +input TestResultsOrdering { + direction: OrderingDirection + parameter: TestResultsOrderingParameter +} diff --git a/graphql_api/types/invoice/__init__.py b/graphql_api/types/invoice/__init__.py index 96e50ca471..1b1a23413e 100644 --- a/graphql_api/types/invoice/__init__.py +++ b/graphql_api/types/invoice/__init__.py @@ -3,3 +3,6 @@ from .invoice import invoice_bindable invoice = ariadne_load_local_graphql(__file__, "invoice.graphql") + + +__all__ = ["invoice_bindable"] diff --git a/graphql_api/types/invoice/invoice.graphql b/graphql_api/types/invoice/invoice.graphql index 09b5169d70..1c6b2ae1ec 100644 --- a/graphql_api/types/invoice/invoice.graphql +++ b/graphql_api/types/invoice/invoice.graphql @@ -17,6 +17,7 @@ type Invoice { status: String subtotal: Float! total: Float! 
+ taxIds: [TaxInfo] } type LineItem { @@ -57,3 +58,8 @@ type Address { postalCode: String state: String } + +type TaxInfo { + type: String! + value: String! +} diff --git a/graphql_api/types/invoice/invoice.py b/graphql_api/types/invoice/invoice.py index 936e5a57cd..fedb1ea857 100644 --- a/graphql_api/types/invoice/invoice.py +++ b/graphql_api/types/invoice/invoice.py @@ -1,6 +1,3 @@ -from functools import cached_property -from typing import List, Optional - from ariadne import ObjectType from graphql import GraphQLResolveInfo from stripe import ( @@ -115,3 +112,8 @@ def resolve_invoice_default_payment_method( invoice: Invoice, info: GraphQLResolveInfo ) -> PaymentMethod | None: return invoice["default_payment_method"] + + +@invoice_bindable.field("taxIds") +def resolve_invoice_tax_ids(invoice: Invoice, info: GraphQLResolveInfo) -> list: + return invoice["customer_tax_ids"] diff --git a/graphql_api/types/line_comparison/__init__.py b/graphql_api/types/line_comparison/__init__.py index eb6a04e9d0..b8f87ab586 100644 --- a/graphql_api/types/line_comparison/__init__.py +++ b/graphql_api/types/line_comparison/__init__.py @@ -3,3 +3,6 @@ from .line_comparison import line_comparison_bindable line_comparison = ariadne_load_local_graphql(__file__, "line_comparison.graphql") + + +__all__ = ["line_comparison_bindable"] diff --git a/graphql_api/types/me/__init__.py b/graphql_api/types/me/__init__.py index b914769f2c..f589aad98d 100644 --- a/graphql_api/types/me/__init__.py +++ b/graphql_api/types/me/__init__.py @@ -1 +1,3 @@ from .me import me, me_bindable, tracking_metadata_bindable + +__all__ = ["me", "me_bindable", "tracking_metadata_bindable"] diff --git a/graphql_api/types/me/me.py b/graphql_api/types/me/me.py index 6a6a16d60a..954c1a8b05 100644 --- a/graphql_api/types/me/me.py +++ b/graphql_api/types/me/me.py @@ -1,9 +1,11 @@ from typing import Optional from ariadne import ObjectType, convert_kwargs_to_snake_case +from graphql import GraphQLResolveInfo from codecov.db 
import sync_to_async -from codecov_auth.models import Owner, OwnerProfile, User +from codecov_auth.models import Owner, OwnerProfile +from codecov_auth.views.okta_cloud import OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY from graphql_api.actions.owner import ( get_owner_login_sessions, get_user_tokens, @@ -42,13 +44,23 @@ def resolve_owner(user, _): @convert_kwargs_to_snake_case def resolve_viewable_repositories( current_user, - _, + info: GraphQLResolveInfo, filters=None, ordering=RepositoryOrdering.ID, ordering_direction=OrderingDirection.ASC, **kwargs, ): - queryset = search_repos(current_user, filters) + okta_authenticated_accounts: list[int] = info.context["request"].session.get( + OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY, [] + ) + is_impersonation = info.context["request"].impersonation + # If the user is impersonating another user, we want to show all the Okta repos. + # This means we do not want to filter out the Okta enforced repos + exclude_okta_enforced_repos = not is_impersonation + + queryset = search_repos( + current_user, filters, okta_authenticated_accounts, exclude_okta_enforced_repos + ) return queryset_to_connection( queryset, ordering=(ordering, RepositoryOrdering.ID), diff --git a/graphql_api/types/measurement/__init__.py b/graphql_api/types/measurement/__init__.py index 5b756774f7..8ce7990a1b 100644 --- a/graphql_api/types/measurement/__init__.py +++ b/graphql_api/types/measurement/__init__.py @@ -3,3 +3,6 @@ from .measurement import measurement_bindable measurement = ariadne_load_local_graphql(__file__, "measurement.graphql") + + +__all__ = ["measurement_bindable"] diff --git a/graphql_api/types/mutation/__init__.py b/graphql_api/types/mutation/__init__.py index 42ae08e70f..c4045976fe 100644 --- a/graphql_api/types/mutation/__init__.py +++ b/graphql_api/types/mutation/__init__.py @@ -7,17 +7,20 @@ from .delete_component_measurements import gql_delete_component_measurements from .delete_flag import gql_delete_flag from .delete_session import 
gql_delete_session +from .encode_secret_string import gql_encode_secret_string from .erase_repository import gql_erase_repository -from .mutation import mutation_resolvers +from .mutation import mutation_resolvers # noqa: F401 from .onboard_user import gql_onboard_user from .regenerate_org_upload_token import gql_regenerate_org_upload_token from .regenerate_repository_token import gql_regenerate_repository_token from .regenerate_repository_upload_token import gql_regenerate_repository_upload_token from .revoke_user_token import gql_revoke_user_token +from .save_okta_config import gql_save_okta_config from .save_sentry_state import gql_save_sentry_state from .save_terms_agreement import gql_save_terms_agreement from .set_yaml_on_owner import gql_set_yaml_on_owner from .start_trial import gql_start_trial +from .store_event_metrics import gql_store_event_metrics from .sync_with_git_provider import gql_sync_with_git_provider from .update_default_organization import gql_update_default_organization from .update_profile import gql_update_profile @@ -47,3 +50,6 @@ mutation = mutation + gql_update_repository mutation = mutation + gql_update_self_hosted_settings mutation = mutation + gql_regenerate_repository_upload_token +mutation = mutation + gql_encode_secret_string +mutation = mutation + gql_store_event_metrics +mutation = mutation + gql_save_okta_config diff --git a/graphql_api/types/mutation/activate_measurements/__init__.py b/graphql_api/types/mutation/activate_measurements/__init__.py index f56ebd6a47..104281a26d 100644 --- a/graphql_api/types/mutation/activate_measurements/__init__.py +++ b/graphql_api/types/mutation/activate_measurements/__init__.py @@ -8,3 +8,6 @@ gql_activate_measurements = ariadne_load_local_graphql( __file__, "activate_measurements.graphql" ) + + +__all__ = ["error_activate_measurements", "resolve_activate_measurements"] diff --git a/graphql_api/types/mutation/activate_measurements/activate_measurements.py 
b/graphql_api/types/mutation/activate_measurements/activate_measurements.py index 5a18871722..88aad675b2 100644 --- a/graphql_api/types/mutation/activate_measurements/activate_measurements.py +++ b/graphql_api/types/mutation/activate_measurements/activate_measurements.py @@ -6,7 +6,6 @@ resolve_union_error_type, wrap_error_handling_mutation, ) -from timeseries.models import MeasurementName @wrap_error_handling_mutation @@ -15,8 +14,8 @@ async def resolve_activate_measurements(_, info, input): command: RepositoryCommands = info.context["executor"].get_command("repository") await command.activate_measurements( owner_name=input.get("owner"), - repo_name=input.get("repoName"), - measurement_type=input.get("measurementType"), + repo_name=input.get("repo_name"), + measurement_type=input.get("measurement_type"), ) return None diff --git a/graphql_api/types/mutation/cancel_trial/__init__.py b/graphql_api/types/mutation/cancel_trial/__init__.py index 1e0136c321..826a34ff61 100644 --- a/graphql_api/types/mutation/cancel_trial/__init__.py +++ b/graphql_api/types/mutation/cancel_trial/__init__.py @@ -3,3 +3,6 @@ from .cancel_trial import error_cancel_trial, resolve_cancel_trial gql_cancel_trial = ariadne_load_local_graphql(__file__, "cancel_trial.graphql") + + +__all__ = ["error_cancel_trial", "resolve_cancel_trial"] diff --git a/graphql_api/types/mutation/cancel_trial/cancel_trial.py b/graphql_api/types/mutation/cancel_trial/cancel_trial.py index 0f0e906955..b6f0f59d36 100644 --- a/graphql_api/types/mutation/cancel_trial/cancel_trial.py +++ b/graphql_api/types/mutation/cancel_trial/cancel_trial.py @@ -12,7 +12,7 @@ @require_authenticated async def resolve_cancel_trial(_, info, input) -> None: command: OwnerCommands = info.context["executor"].get_command("owner") - await command.cancel_trial(input.get("orgUsername")) + await command.cancel_trial(input.get("org_username")) return None diff --git a/graphql_api/types/mutation/create_api_token/__init__.py 
b/graphql_api/types/mutation/create_api_token/__init__.py index 3c32566340..7b1ff06527 100644 --- a/graphql_api/types/mutation/create_api_token/__init__.py +++ b/graphql_api/types/mutation/create_api_token/__init__.py @@ -3,3 +3,6 @@ from .create_api_token import error_create_api_token, resolve_create_api_token gql_create_api_token = ariadne_load_local_graphql(__file__, "create_api_token.graphql") + + +__all__ = ["error_create_api_token", "resolve_create_api_token"] diff --git a/graphql_api/types/mutation/create_user_token/__init__.py b/graphql_api/types/mutation/create_user_token/__init__.py index 1217132877..dd8b1a10b2 100644 --- a/graphql_api/types/mutation/create_user_token/__init__.py +++ b/graphql_api/types/mutation/create_user_token/__init__.py @@ -5,3 +5,5 @@ gql_create_user_token = ariadne_load_local_graphql( __file__, "create_user_token.graphql" ) + +__all__ = ["error_create_user_token", "resolve_create_user_token"] diff --git a/graphql_api/types/mutation/create_user_token/create_user_token.py b/graphql_api/types/mutation/create_user_token/create_user_token.py index bc75549b67..b012bd96c8 100644 --- a/graphql_api/types/mutation/create_user_token/create_user_token.py +++ b/graphql_api/types/mutation/create_user_token/create_user_token.py @@ -11,7 +11,7 @@ async def resolve_create_user_token(_, info, input): command = info.context["executor"].get_command("owner") user_token = await command.create_user_token( name=input.get("name"), - token_type=input.get("tokenType"), + token_type=input.get("token_type"), ) return { "token": user_token, diff --git a/graphql_api/types/mutation/delete_component_measurements/__init__.py b/graphql_api/types/mutation/delete_component_measurements/__init__.py index e06f8f50ac..656d8689c0 100644 --- a/graphql_api/types/mutation/delete_component_measurements/__init__.py +++ b/graphql_api/types/mutation/delete_component_measurements/__init__.py @@ -8,3 +8,9 @@ gql_delete_component_measurements = ariadne_load_local_graphql( __file__, 
"delete_component_measurements.graphql" ) + + +__all__ = [ + "error_delete_component_measurements", + "resolve_delete_component_measurements", +] diff --git a/graphql_api/types/mutation/delete_flag/__init__.py b/graphql_api/types/mutation/delete_flag/__init__.py index a5167f9026..b4aead7635 100644 --- a/graphql_api/types/mutation/delete_flag/__init__.py +++ b/graphql_api/types/mutation/delete_flag/__init__.py @@ -3,3 +3,6 @@ from .delete_flag import error_delete_flag, resolve_delete_flag gql_delete_flag = ariadne_load_local_graphql(__file__, "delete_flag.graphql") + + +__all__ = ["error_delete_flag", "resolve_delete_flag"] diff --git a/graphql_api/types/mutation/delete_session/__init__.py b/graphql_api/types/mutation/delete_session/__init__.py index 97ed927505..a33094e0ca 100644 --- a/graphql_api/types/mutation/delete_session/__init__.py +++ b/graphql_api/types/mutation/delete_session/__init__.py @@ -3,3 +3,6 @@ from .delete_session import error_delete_session, resolve_delete_session gql_delete_session = ariadne_load_local_graphql(__file__, "delete_session.graphql") + + +__all__ = ["error_delete_session", "resolve_delete_session"] diff --git a/graphql_api/types/mutation/encode_secret_string/__init__.py b/graphql_api/types/mutation/encode_secret_string/__init__.py new file mode 100644 index 0000000000..110dd77c94 --- /dev/null +++ b/graphql_api/types/mutation/encode_secret_string/__init__.py @@ -0,0 +1,12 @@ +from graphql_api.helpers.ariadne import ariadne_load_local_graphql + +from .encode_secret_string import ( + error_encode_secret_string, + resolve_encode_secret_string, +) + +gql_encode_secret_string = ariadne_load_local_graphql( + __file__, "encode_secret_string.graphql" +) + +__all__ = ["error_encode_secret_string", "resolve_encode_secret_string"] diff --git a/graphql_api/types/mutation/encode_secret_string/encode_secret_string.graphql b/graphql_api/types/mutation/encode_secret_string/encode_secret_string.graphql new file mode 100644 index 
0000000000..03662a7167 --- /dev/null +++ b/graphql_api/types/mutation/encode_secret_string/encode_secret_string.graphql @@ -0,0 +1,6 @@ +union EncodeSecretStringError = ValidationError | UnauthenticatedError + +type EncodeSecretStringPayload { + error: EncodeSecretStringError + value: String +} \ No newline at end of file diff --git a/graphql_api/types/mutation/encode_secret_string/encode_secret_string.py b/graphql_api/types/mutation/encode_secret_string/encode_secret_string.py new file mode 100644 index 0000000000..caaa1e5be9 --- /dev/null +++ b/graphql_api/types/mutation/encode_secret_string/encode_secret_string.py @@ -0,0 +1,24 @@ +from ariadne import UnionType + +from graphql_api.helpers.mutation import ( + require_authenticated, + resolve_union_error_type, + wrap_error_handling_mutation, +) + + +@wrap_error_handling_mutation +@require_authenticated +async def resolve_encode_secret_string(_, info, input) -> dict: + command = info.context["executor"].get_command("repository") + repo_name = input.get("repo_name") + value = input.get("value") + current_owner = info.context["request"].current_owner + value = command.encode_secret_string( + repo_name=repo_name, owner=current_owner, value=value + ) + return {"value": value} + + +error_encode_secret_string = UnionType("EncodeSecretStringError") +error_encode_secret_string.type_resolver(resolve_union_error_type) diff --git a/graphql_api/types/mutation/erase_repository/__init__.py b/graphql_api/types/mutation/erase_repository/__init__.py index 2115f6d981..1ad6b3e491 100644 --- a/graphql_api/types/mutation/erase_repository/__init__.py +++ b/graphql_api/types/mutation/erase_repository/__init__.py @@ -3,3 +3,6 @@ from .erase_repository import error_erase_repository, resolve_erase_repository gql_erase_repository = ariadne_load_local_graphql(__file__, "erase_repository.graphql") + + +__all__ = ["error_erase_repository", "resolve_erase_repository"] diff --git a/graphql_api/types/mutation/erase_repository/erase_repository.py
b/graphql_api/types/mutation/erase_repository/erase_repository.py index 7a57215a98..2e6bafbae0 100644 --- a/graphql_api/types/mutation/erase_repository/erase_repository.py +++ b/graphql_api/types/mutation/erase_repository/erase_repository.py @@ -12,7 +12,7 @@ async def resolve_erase_repository(_, info, input) -> None: command = info.context["executor"].get_command("repository") current_owner = info.context["request"].current_owner - repo_name = input.get("repoName") + repo_name = input.get("repo_name") await command.erase_repository(repo_name=repo_name, owner=current_owner) return None diff --git a/graphql_api/types/mutation/mutation.graphql b/graphql_api/types/mutation/mutation.graphql index 80ef130371..ab519fe399 100644 --- a/graphql_api/types/mutation/mutation.graphql +++ b/graphql_api/types/mutation/mutation.graphql @@ -24,9 +24,18 @@ type Mutation { deleteFlag(input: DeleteFlagInput!): DeleteFlagPayload saveSentryState(input: SaveSentryStateInput!): SaveSentryStatePayload saveTermsAgreement(input: SaveTermsAgreementInput!): SaveTermsAgreementPayload - deleteComponentMeasurements(input: DeleteComponentMeasurementsInput!): DeleteComponentMeasurementsPayload + deleteComponentMeasurements( + input: DeleteComponentMeasurementsInput! + ): DeleteComponentMeasurementsPayload eraseRepository(input: EraseRepositoryInput!): EraseRepositoryPayload updateRepository(input: UpdateRepositoryInput!): UpdateRepositoryPayload - updateSelfHostedSettings(input: UpdateSelfHostedSettingsInput!): UpdateSelfHostedSettingsPayload - regenerateRepositoryUploadToken(input: RegenerateRepositoryUploadTokenInput!): RegenerateRepositoryUploadTokenPayload + updateSelfHostedSettings( + input: UpdateSelfHostedSettingsInput! + ): UpdateSelfHostedSettingsPayload + regenerateRepositoryUploadToken( + input: RegenerateRepositoryUploadTokenInput! 
+ ): RegenerateRepositoryUploadTokenPayload + encodeSecretString(input: EncodeSecretStringInput!): EncodeSecretStringPayload + storeEventMetric(input: StoreEventMetricsInput!): StoreEventMetricsPayload + saveOktaConfig(input: SaveOktaConfigInput!): SaveOktaConfigPayload } diff --git a/graphql_api/types/mutation/mutation.py b/graphql_api/types/mutation/mutation.py index e9d71f9132..a036a51077 100644 --- a/graphql_api/types/mutation/mutation.py +++ b/graphql_api/types/mutation/mutation.py @@ -13,6 +13,10 @@ ) from .delete_flag import error_delete_flag, resolve_delete_flag from .delete_session import error_delete_session, resolve_delete_session +from .encode_secret_string import ( + error_encode_secret_string, + resolve_encode_secret_string, +) from .erase_repository import error_erase_repository, resolve_erase_repository from .onboard_user import error_onboard_user, resolve_onboard_user from .regenerate_org_upload_token import ( @@ -28,6 +32,7 @@ resolve_regenerate_repository_upload_token, ) from .revoke_user_token import error_revoke_user_token, resolve_revoke_user_token +from .save_okta_config import error_save_okta_config, resolve_save_okta_config from .save_sentry_state import error_save_sentry_state, resolve_save_sentry_state from .save_terms_agreement import ( error_save_terms_agreement, @@ -35,6 +40,7 @@ ) from .set_yaml_on_owner import error_set_yaml_error, resolve_set_yaml_on_owner from .start_trial import error_start_trial, resolve_start_trial +from .store_event_metrics import error_store_event_metrics, resolve_store_event_metrics from .sync_with_git_provider import ( error_sync_with_git_provider, resolve_sync_with_git_provider, @@ -83,7 +89,11 @@ mutation_bindable.field("regenerateRepositoryUploadToken")( resolve_regenerate_repository_upload_token ) +mutation_bindable.field("encodeSecretString")(resolve_encode_secret_string) + +mutation_bindable.field("storeEventMetric")(resolve_store_event_metrics) 
+mutation_bindable.field("saveOktaConfig")(resolve_save_okta_config) mutation_resolvers = [ mutation_bindable, @@ -109,4 +119,7 @@ error_update_repository, error_update_self_hosted_settings, error_regenerate_repository_upload_token, + error_encode_secret_string, + error_store_event_metrics, + error_save_okta_config, ] diff --git a/graphql_api/types/mutation/onboard_user/__init__.py b/graphql_api/types/mutation/onboard_user/__init__.py index ef028f36ca..cd6db9f451 100644 --- a/graphql_api/types/mutation/onboard_user/__init__.py +++ b/graphql_api/types/mutation/onboard_user/__init__.py @@ -3,3 +3,6 @@ from .onboard_user import error_onboard_user, resolve_onboard_user gql_onboard_user = ariadne_load_local_graphql(__file__, "onboard_user.graphql") + + +__all__ = ["error_onboard_user", "resolve_onboard_user"] diff --git a/graphql_api/types/mutation/regenerate_org_upload_token/__init__.py b/graphql_api/types/mutation/regenerate_org_upload_token/__init__.py index 7e23fb65db..111d89fc80 100644 --- a/graphql_api/types/mutation/regenerate_org_upload_token/__init__.py +++ b/graphql_api/types/mutation/regenerate_org_upload_token/__init__.py @@ -8,3 +8,5 @@ gql_regenerate_org_upload_token = ariadne_load_local_graphql( __file__, "regenerate_org_upload_token.graphql" ) + +__all__ = ["error_generate_org_upload_token", "resolve_regenerate_org_upload_token"] diff --git a/graphql_api/types/mutation/regenerate_repository_token/__init__.py b/graphql_api/types/mutation/regenerate_repository_token/__init__.py index 485790c6ec..10cb2852e8 100644 --- a/graphql_api/types/mutation/regenerate_repository_token/__init__.py +++ b/graphql_api/types/mutation/regenerate_repository_token/__init__.py @@ -8,3 +8,6 @@ gql_regenerate_repository_token = ariadne_load_local_graphql( __file__, "regenerate_repository_token.graphql" ) + + +__all__ = ["error_regenerate_repository_token", "resolve_regenerate_repository_token"] diff --git 
a/graphql_api/types/mutation/regenerate_repository_upload_token/__init__.py b/graphql_api/types/mutation/regenerate_repository_upload_token/__init__.py index 1e2cf7cd45..93f0511986 100644 --- a/graphql_api/types/mutation/regenerate_repository_upload_token/__init__.py +++ b/graphql_api/types/mutation/regenerate_repository_upload_token/__init__.py @@ -8,3 +8,9 @@ gql_regenerate_repository_upload_token = ariadne_load_local_graphql( __file__, "regenerate_repository_upload_token.graphql" ) + + +__all__ = [ + "error_regenerate_repository_upload_token", + "resolve_regenerate_repository_upload_token", +] diff --git a/graphql_api/types/mutation/regenerate_repository_upload_token/regenerate_repository_upload_token.py b/graphql_api/types/mutation/regenerate_repository_upload_token/regenerate_repository_upload_token.py index d47571248d..60e51f8922 100644 --- a/graphql_api/types/mutation/regenerate_repository_upload_token/regenerate_repository_upload_token.py +++ b/graphql_api/types/mutation/regenerate_repository_upload_token/regenerate_repository_upload_token.py @@ -13,7 +13,7 @@ async def resolve_regenerate_repository_upload_token(_, info, input): command: RepositoryCommands = info.context["executor"].get_command("repository") token = await command.regenerate_repository_upload_token( - repo_name=input.get("repoName"), + repo_name=input.get("repo_name"), owner_username=input.get("owner"), ) diff --git a/graphql_api/types/mutation/revoke_user_token/__init__.py b/graphql_api/types/mutation/revoke_user_token/__init__.py index e99a6e29e5..cf6f8fd43c 100644 --- a/graphql_api/types/mutation/revoke_user_token/__init__.py +++ b/graphql_api/types/mutation/revoke_user_token/__init__.py @@ -5,3 +5,6 @@ gql_revoke_user_token = ariadne_load_local_graphql( __file__, "revoke_user_token.graphql" ) + + +__all__ = ["error_revoke_user_token", "resolve_revoke_user_token"] diff --git a/graphql_api/types/mutation/save_okta_config/__init__.py b/graphql_api/types/mutation/save_okta_config/__init__.py 
new file mode 100644 index 0000000000..15fd54e5f2 --- /dev/null +++ b/graphql_api/types/mutation/save_okta_config/__init__.py @@ -0,0 +1,8 @@ +from graphql_api.helpers.ariadne import ariadne_load_local_graphql + +from .save_okta_config import error_save_okta_config, resolve_save_okta_config + +gql_save_okta_config = ariadne_load_local_graphql(__file__, "save_okta_config.graphql") + + +__all__ = ["error_save_okta_config", "resolve_save_okta_config"] diff --git a/graphql_api/types/mutation/save_okta_config/save_okta_config.graphql b/graphql_api/types/mutation/save_okta_config/save_okta_config.graphql new file mode 100644 index 0000000000..2e57fedf59 --- /dev/null +++ b/graphql_api/types/mutation/save_okta_config/save_okta_config.graphql @@ -0,0 +1,17 @@ +union SaveOktaConfigError = + UnauthenticatedError + | UnauthorizedError + | ValidationError + +type SaveOktaConfigPayload { + error: SaveOktaConfigError +} + +input SaveOktaConfigInput { + clientId: String + clientSecret: String + url: String + enabled: Boolean + enforced: Boolean + orgUsername: String +} diff --git a/graphql_api/types/mutation/save_okta_config/save_okta_config.py b/graphql_api/types/mutation/save_okta_config/save_okta_config.py new file mode 100644 index 0000000000..ad1152c4b9 --- /dev/null +++ b/graphql_api/types/mutation/save_okta_config/save_okta_config.py @@ -0,0 +1,19 @@ +from ariadne import UnionType, convert_kwargs_to_snake_case + +from graphql_api.helpers.mutation import ( + require_authenticated, + resolve_union_error_type, + wrap_error_handling_mutation, +) + + +@wrap_error_handling_mutation +@require_authenticated +@convert_kwargs_to_snake_case +async def resolve_save_okta_config(_, info, input): + command = info.context["executor"].get_command("owner") + return await command.save_okta_config(input) + + +error_save_okta_config = UnionType("SaveOktaConfigError") +error_save_okta_config.type_resolver(resolve_union_error_type) diff --git 
a/graphql_api/types/mutation/save_sentry_state/__init__.py b/graphql_api/types/mutation/save_sentry_state/__init__.py index 58285bf788..0d90e90e4f 100644 --- a/graphql_api/types/mutation/save_sentry_state/__init__.py +++ b/graphql_api/types/mutation/save_sentry_state/__init__.py @@ -5,3 +5,5 @@ gql_save_sentry_state = ariadne_load_local_graphql( __file__, "save_sentry_state.graphql" ) + +__all__ = ["error_save_sentry_state", "resolve_save_sentry_state"] diff --git a/graphql_api/types/mutation/save_terms_agreement/__init__.py b/graphql_api/types/mutation/save_terms_agreement/__init__.py index 19b0f1a349..d65cf7c938 100644 --- a/graphql_api/types/mutation/save_terms_agreement/__init__.py +++ b/graphql_api/types/mutation/save_terms_agreement/__init__.py @@ -8,3 +8,5 @@ gql_save_terms_agreement = ariadne_load_local_graphql( __file__, "save_terms_agreement.graphql" ) + +__all__ = ["error_save_terms_agreement", "resolve_save_terms_agreement"] diff --git a/graphql_api/types/mutation/set_yaml_on_owner/__init__.py b/graphql_api/types/mutation/set_yaml_on_owner/__init__.py index 719fbe1c6e..1d5a8aaa10 100644 --- a/graphql_api/types/mutation/set_yaml_on_owner/__init__.py +++ b/graphql_api/types/mutation/set_yaml_on_owner/__init__.py @@ -5,3 +5,5 @@ gql_set_yaml_on_owner = ariadne_load_local_graphql( __file__, "set_yaml_on_owner.graphql" ) + +__all__ = ["error_set_yaml_error", "resolve_set_yaml_on_owner"] diff --git a/graphql_api/types/mutation/start_trial/__init__.py b/graphql_api/types/mutation/start_trial/__init__.py index d0c354bf33..46192ad9a4 100644 --- a/graphql_api/types/mutation/start_trial/__init__.py +++ b/graphql_api/types/mutation/start_trial/__init__.py @@ -3,3 +3,6 @@ from .start_trial import error_start_trial, resolve_start_trial gql_start_trial = ariadne_load_local_graphql(__file__, "start_trial.graphql") + + +__all__ = ["error_start_trial", "resolve_start_trial"] diff --git a/graphql_api/types/mutation/start_trial/start_trial.py 
b/graphql_api/types/mutation/start_trial/start_trial.py index d212ffe998..034aedfd37 100644 --- a/graphql_api/types/mutation/start_trial/start_trial.py +++ b/graphql_api/types/mutation/start_trial/start_trial.py @@ -12,7 +12,7 @@ @require_authenticated async def resolve_start_trial(_, info, input) -> None: command: OwnerCommands = info.context["executor"].get_command("owner") - await command.start_trial(input.get("orgUsername")) + await command.start_trial(input.get("org_username")) return None diff --git a/graphql_api/types/mutation/store_event_metrics/__init__.py b/graphql_api/types/mutation/store_event_metrics/__init__.py new file mode 100644 index 0000000000..74abfb4f84 --- /dev/null +++ b/graphql_api/types/mutation/store_event_metrics/__init__.py @@ -0,0 +1,10 @@ +from graphql_api.helpers.ariadne import ariadne_load_local_graphql + +from .store_event_metrics import error_store_event_metrics, resolve_store_event_metrics + +gql_store_event_metrics = ariadne_load_local_graphql( + __file__, "store_event_metrics.graphql" +) + + +__all__ = ["error_store_event_metrics", "resolve_store_event_metrics"] diff --git a/graphql_api/types/mutation/store_event_metrics/store_event_metrics.graphql b/graphql_api/types/mutation/store_event_metrics/store_event_metrics.graphql new file mode 100644 index 0000000000..4a0faccc72 --- /dev/null +++ b/graphql_api/types/mutation/store_event_metrics/store_event_metrics.graphql @@ -0,0 +1,11 @@ +union StoreEventMetricsError = UnauthenticatedError | ValidationError + +type StoreEventMetricsPayload { + error: StoreEventMetricsError +} + +input StoreEventMetricsInput { + orgUsername: String! + eventName: String! 
+ jsonPayload: String # The input expects a serialized json string +} diff --git a/graphql_api/types/mutation/store_event_metrics/store_event_metrics.py b/graphql_api/types/mutation/store_event_metrics/store_event_metrics.py new file mode 100644 index 0000000000..e7c3d322da --- /dev/null +++ b/graphql_api/types/mutation/store_event_metrics/store_event_metrics.py @@ -0,0 +1,22 @@ +from ariadne import UnionType + +from codecov_auth.commands.owner import OwnerCommands +from graphql_api.helpers.mutation import ( + require_authenticated, + resolve_union_error_type, + wrap_error_handling_mutation, +) + + +@wrap_error_handling_mutation +@require_authenticated +async def resolve_store_event_metrics(_, info, input) -> None: + command: OwnerCommands = info.context["executor"].get_command("owner") + await command.store_codecov_metric( + input.get("org_username"), input.get("event_name"), input.get("json_payload") + ) + return None + + +error_store_event_metrics = UnionType("StoreEventMetricsError") +error_store_event_metrics.type_resolver(resolve_union_error_type) diff --git a/graphql_api/types/mutation/sync_with_git_provider/__init__.py b/graphql_api/types/mutation/sync_with_git_provider/__init__.py index f583e0cce5..8f5f43c266 100644 --- a/graphql_api/types/mutation/sync_with_git_provider/__init__.py +++ b/graphql_api/types/mutation/sync_with_git_provider/__init__.py @@ -8,3 +8,6 @@ gql_sync_with_git_provider = ariadne_load_local_graphql( __file__, "sync_with_git_provider.graphql" ) + + +__all__ = ["error_sync_with_git_provider", "resolve_sync_with_git_provider"] diff --git a/graphql_api/types/mutation/update_default_organization/__init__.py b/graphql_api/types/mutation/update_default_organization/__init__.py index 019716ff07..b0f7faccf4 100644 --- a/graphql_api/types/mutation/update_default_organization/__init__.py +++ b/graphql_api/types/mutation/update_default_organization/__init__.py @@ -8,3 +8,6 @@ gql_update_default_organization = ariadne_load_local_graphql( __file__, 
"update_default_organization.graphql" ) + + +__all__ = ["error_update_default_organization", "resolve_update_default_organization"] diff --git a/graphql_api/types/mutation/update_profile/__init__.py b/graphql_api/types/mutation/update_profile/__init__.py index 3f8443eea5..86add287c4 100644 --- a/graphql_api/types/mutation/update_profile/__init__.py +++ b/graphql_api/types/mutation/update_profile/__init__.py @@ -3,3 +3,6 @@ from .update_profile import error_update_profile, resolve_update_profile gql_update_profile = ariadne_load_local_graphql(__file__, "update_profile.graphql") + + +__all__ = ["error_update_profile", "resolve_update_profile"] diff --git a/graphql_api/types/mutation/update_repository/__init__.py b/graphql_api/types/mutation/update_repository/__init__.py index 200542ae16..c9f3549acf 100644 --- a/graphql_api/types/mutation/update_repository/__init__.py +++ b/graphql_api/types/mutation/update_repository/__init__.py @@ -5,3 +5,6 @@ gql_update_repository = ariadne_load_local_graphql( __file__, "update_repository.graphql" ) + + +__all__ = ["error_update_repository", "resolve_update_repository"] diff --git a/graphql_api/types/mutation/update_repository/update_repository.py b/graphql_api/types/mutation/update_repository/update_repository.py index c8d9f4df85..75aa943a3f 100644 --- a/graphql_api/types/mutation/update_repository/update_repository.py +++ b/graphql_api/types/mutation/update_repository/update_repository.py @@ -10,7 +10,7 @@ async def resolve_update_repository(_, info, input): command = info.context["executor"].get_command("repository") owner = info.context["request"].current_owner - repo_name = input.get("repoName") + repo_name = input.get("repo_name") default_branch = input.get("branch") activated = input.get("activated") await command.update_repository( diff --git a/graphql_api/types/mutation/update_self_hosted_settings/__init__.py b/graphql_api/types/mutation/update_self_hosted_settings/__init__.py index c0cd5030f9..dbb5299a6e 100644 --- 
a/graphql_api/types/mutation/update_self_hosted_settings/__init__.py +++ b/graphql_api/types/mutation/update_self_hosted_settings/__init__.py @@ -8,3 +8,9 @@ gql_update_self_hosted_settings = ariadne_load_local_graphql( __file__, "update_self_hosted_settings.graphql" ) + +__all__ = [ + "gql_update_self_hosted_settings", + "error_update_self_hosted_settings", + "resolve_update_self_hosted_settings", +] diff --git a/graphql_api/types/okta_config/__init__.py b/graphql_api/types/okta_config/__init__.py new file mode 100644 index 0000000000..46b28c6ba6 --- /dev/null +++ b/graphql_api/types/okta_config/__init__.py @@ -0,0 +1,10 @@ +from shared.license import get_current_license + +from graphql_api.helpers.ariadne import ariadne_load_local_graphql + +from .okta_config import okta_config_bindable + +okta_config = ariadne_load_local_graphql(__file__, "okta_config.graphql") + + +__all__ = ["get_current_license", "okta_config_bindable"] diff --git a/graphql_api/types/okta_config/okta_config.graphql b/graphql_api/types/okta_config/okta_config.graphql new file mode 100644 index 0000000000..5cb6f4a4bf --- /dev/null +++ b/graphql_api/types/okta_config/okta_config.graphql @@ -0,0 +1,7 @@ +type OktaConfig { + clientId: String! + clientSecret: String! + url: String! + enabled: Boolean! + enforced: Boolean! 
+} diff --git a/graphql_api/types/okta_config/okta_config.py b/graphql_api/types/okta_config/okta_config.py new file mode 100644 index 0000000000..a9b07ca289 --- /dev/null +++ b/graphql_api/types/okta_config/okta_config.py @@ -0,0 +1,30 @@ +from ariadne import ObjectType + +from codecov_auth.models import OktaSettings + +okta_config_bindable = ObjectType("OktaConfig") + + +@okta_config_bindable.field("clientId") +def resolve_client_id(okta_config: OktaSettings, info) -> str: + return okta_config.client_id + + +@okta_config_bindable.field("clientSecret") +def resolve_client_secret(okta_config: OktaSettings, info) -> str: + return okta_config.client_secret + + +@okta_config_bindable.field("url") +def resolve_url(okta_config: OktaSettings, info) -> str: + return okta_config.url + + +@okta_config_bindable.field("enabled") +def resolve_enabled(okta_config: OktaSettings, info) -> bool: + return okta_config.enabled + + +@okta_config_bindable.field("enforced") +def resolve_enforced(okta_config: OktaSettings, info) -> bool: + return okta_config.enforced diff --git a/graphql_api/types/owner/__init__.py b/graphql_api/types/owner/__init__.py index 92a7c70530..a484f87370 100644 --- a/graphql_api/types/owner/__init__.py +++ b/graphql_api/types/owner/__init__.py @@ -1 +1,6 @@ from .owner import owner, owner_bindable + +__all__ = [ + "owner", + "owner_bindable", +] diff --git a/graphql_api/types/owner/owner.graphql b/graphql_api/types/owner/owner.graphql index 562429346b..fceebf0f0f 100644 --- a/graphql_api/types/owner/owner.graphql +++ b/graphql_api/types/owner/owner.graphql @@ -1,29 +1,18 @@ type Owner { + account: Account + availablePlans: [PlanRepresentation!] avatarUrl: String! - username: String - isCurrentUserPartOfOrg: Boolean! - yaml: String - repositories( - filters: RepositorySetFilters - ordering: RepositoryOrdering - orderingDirection: OrderingDirection - first: Int - after: String - last: Int - before: String - ): RepositoryConnection! 
@cost(complexity: 25, multipliers: ["first", "last"]) - repository(name: String!): RepositoryResult! - numberOfUploads: Int + defaultOrgUsername: String + delinquent: Boolean + hashOwnerid: String hasPrivateRepos: Boolean + invoice(invoiceId: String!): Invoice + invoices: [Invoice] @cost(complexity: 100) isAdmin: Boolean - hashOwnerid: String - ownerid: Int - plan: Plan - pretrialPlan: PlanRepresentation - availablePlans: [PlanRepresentation!] - orgUploadToken: String - defaultOrgUsername: String isCurrentUserActivated: Boolean + isCurrentUserPartOfOrg: Boolean! + isGithubRateLimited: Boolean + isUserOktaAuthenticated: Boolean measurements( interval: MeasurementInterval! after: DateTime @@ -31,6 +20,21 @@ type Owner { repos: [String!] isPublic: Boolean ): [Measurement!] - invoices: [Invoice] @cost(complexity: 100) - invoice(invoiceId: String!): Invoice + numberOfUploads: Int + orgUploadToken: String + ownerid: Int + plan: Plan + pretrialPlan: PlanRepresentation + repository(name: String!): RepositoryResult! + repositories( + filters: RepositorySetFilters + ordering: RepositoryOrdering + orderingDirection: OrderingDirection + first: Int + after: String + last: Int + before: String + ): RepositoryConnection! 
@cost(complexity: 25, multipliers: ["first", "last"]) + username: String + yaml: String } diff --git a/graphql_api/types/owner/owner.py b/graphql_api/types/owner/owner.py index 91c289ba92..5b38492fe9 100644 --- a/graphql_api/types/owner/owner.py +++ b/graphql_api/types/owner/owner.py @@ -2,6 +2,7 @@ from hashlib import sha1 from typing import Iterable, List, Optional +import shared.rate_limits as rate_limits import stripe import yaml from ariadne import ObjectType, convert_kwargs_to_snake_case @@ -10,7 +11,13 @@ import timeseries.helpers as timeseries_helpers from codecov.db import sync_to_async from codecov_auth.helpers import current_user_part_of_org -from codecov_auth.models import Owner +from codecov_auth.models import ( + SERVICE_GITHUB, + SERVICE_GITHUB_ENTERPRISE, + Account, + Owner, +) +from codecov_auth.views.okta_cloud import OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY from core.models import Repository from graphql_api.actions.repository import list_repository_for_owner from graphql_api.helpers.ariadne import ariadne_load_local_graphql @@ -25,6 +32,7 @@ from plan.service import PlanService from services.billing import BillingService from services.profiling import ProfilingSummary +from services.redis_configuration import get_redis_connection from timeseries.helpers import fill_sparse_measurements from timeseries.models import Interval, MeasurementSummary @@ -36,7 +44,7 @@ @owner_bindable.field("repositories") @convert_kwargs_to_snake_case def resolve_repositories( - owner, + owner: Owner, info, filters=None, ordering=RepositoryOrdering.ID, @@ -44,7 +52,18 @@ def resolve_repositories( **kwargs, ): current_owner = info.context["request"].current_owner - queryset = list_repository_for_owner(current_owner, owner, filters) + okta_account_auths: list[int] = info.context["request"].session.get( + OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY, [] + ) + + is_impersonation = info.context["request"].impersonation + # If the user is impersonating another user, we want to show all the 
Okta repos. + # This means we do not want to filter out the Okta enforced repos + exclude_okta_enforced_repos = not is_impersonation + + queryset = list_repository_for_owner( + current_owner, owner, filters, okta_account_auths, exclude_okta_enforced_repos + ) return queryset_to_connection( queryset, ordering=(ordering, RepositoryOrdering.ID), @@ -61,7 +80,7 @@ def resolve_is_current_user_part_of_org(owner, info): @owner_bindable.field("yaml") -def resolve_yaml(owner, info): +def resolve_yaml(owner: Owner, info): if owner.yaml is None: return current_owner = info.context["request"].current_owner @@ -103,14 +122,28 @@ def resolve_has_private_repos(owner: Owner, info) -> List[PlanData]: @owner_bindable.field("ownerid") @require_part_of_org -def resolve_ownerid(owner, info) -> int: +def resolve_ownerid(owner: Owner, info) -> int: return owner.ownerid @owner_bindable.field("repository") -async def resolve_repository(owner, info, name): +async def resolve_repository(owner: Owner, info, name): command = info.context["executor"].get_command("repository") - repository: Optional[Repository] = await command.fetch_repository(owner, name) + okta_authenticated_accounts: list[int] = info.context["request"].session.get( + OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY, [] + ) + + is_impersonation = info.context["request"].impersonation + # If the user is impersonating another user, we want to show all the Okta repos. 
+ # This means we do not want to filter out the Okta enforced repos + exclude_okta_enforced_repos = not is_impersonation + + repository: Optional[Repository] = await command.fetch_repository( + owner, + name, + okta_authenticated_accounts, + exclude_okta_enforced_repos=exclude_okta_enforced_repos, + ) if repository is None: return NotFoundError() @@ -134,14 +167,14 @@ async def resolve_repository(owner, info, name): @owner_bindable.field("numberOfUploads") @require_part_of_org -async def resolve_number_of_uploads(owner, info, **kwargs): +async def resolve_number_of_uploads(owner: Owner, info, **kwargs): command = info.context["executor"].get_command("owner") return await command.get_uploads_number_per_user(owner) @owner_bindable.field("isAdmin") @require_part_of_org -def resolve_is_current_user_an_admin(owner, info): +def resolve_is_current_user_an_admin(owner: Owner, info): current_owner = info.context["request"].current_owner command = info.context["executor"].get_command("owner") return command.get_is_current_user_an_admin(owner, current_owner) @@ -149,14 +182,14 @@ def resolve_is_current_user_an_admin(owner, info): @owner_bindable.field("hashOwnerid") @require_part_of_org -def resolve_hash_ownerid(owner, info): +def resolve_hash_ownerid(owner: Owner, info): hash_ownerid = sha1(str(owner.ownerid).encode()) return hash_ownerid.hexdigest() @owner_bindable.field("orgUploadToken") @require_part_of_org -def resolve_org_upload_token(owner, info, **kwargs): +def resolve_org_upload_token(owner: Owner, info, **kwargs): command = info.context["executor"].get_command("owner") return command.get_org_upload_token(owner) @@ -182,7 +215,15 @@ def resolve_measurements( ) -> Iterable[MeasurementSummary]: current_owner = info.context["request"].current_owner - queryset = Repository.objects.filter(author=owner).viewable_repos(current_owner) + okta_authenticated_accounts: list[int] = info.context["request"].session.get( + OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY, [] + ) + + queryset = ( 
+ Repository.objects.filter(author=owner) + .viewable_repos(current_owner) + .exclude_accounts_enforced_okta(okta_authenticated_accounts) + ) if is_public is not None: queryset = queryset.filter(private=not is_public) @@ -208,7 +249,7 @@ def resolve_measurements( @owner_bindable.field("isCurrentUserActivated") @sync_to_async -def resolve_is_current_user_activated(owner, info): +def resolve_is_current_user_activated(owner: Owner, info): current_user = info.context["request"].user if not current_user.is_authenticated: return False @@ -234,6 +275,20 @@ def resolve_owner_invoices(owner: Owner, info) -> list | None: return BillingService(requesting_user=owner).list_filtered_invoices(owner, 100) +@owner_bindable.field("isGithubRateLimited") +@sync_to_async +def resolve_is_github_rate_limited(owner: Owner, info) -> bool | None: + if owner.service != SERVICE_GITHUB and owner.service != SERVICE_GITHUB_ENTERPRISE: + return False + redis_connection = get_redis_connection() + rate_limit_redis_key = rate_limits.determine_entity_redis_key( + owner=owner, repository=None + ) + return rate_limits.determine_if_entity_is_rate_limited( + redis_connection, rate_limit_redis_key + ) + + @owner_bindable.field("invoice") @require_part_of_org @convert_kwargs_to_snake_case @@ -243,3 +298,30 @@ def resolve_owner_invoice( invoice_id: str, ) -> stripe.Invoice | None: return BillingService(requesting_user=owner).get_invoice(owner, invoice_id) + + +@owner_bindable.field("account") +@require_part_of_org +@sync_to_async +def resolve_owner_account(owner: Owner, info) -> dict: + account_id = owner.account_id + return Account.objects.filter(pk=account_id).first() + + +@owner_bindable.field("isUserOktaAuthenticated") +@sync_to_async +@require_part_of_org +def resolve_is_user_okta_authenticated(owner: Owner, info) -> bool: + okta_signed_in_accounts = info.context["request"].session.get( + OKTA_SIGNED_IN_ACCOUNTS_SESSION_KEY, + [], + ) + if owner.account_id: + return owner.account_id in 
okta_signed_in_accounts + return False + + +@owner_bindable.field("delinquent") +@require_part_of_org +def resolve_delinquent(owner: Owner, info) -> bool | None: + return owner.delinquent diff --git a/graphql_api/types/path_contents/__init__.py b/graphql_api/types/path_contents/__init__.py index b8a457bf10..14233ff3f6 100644 --- a/graphql_api/types/path_contents/__init__.py +++ b/graphql_api/types/path_contents/__init__.py @@ -7,3 +7,11 @@ ) path_content = ariadne_load_local_graphql(__file__, "path_content.graphql") + + +__all__ = [ + "path_content", + "path_content_bindable", + "path_content_file_bindable", + "path_contents_result_bindable", +] diff --git a/graphql_api/types/plan/__init__.py b/graphql_api/types/plan/__init__.py index fde8387cd1..b60ae134c6 100644 --- a/graphql_api/types/plan/__init__.py +++ b/graphql_api/types/plan/__init__.py @@ -1 +1,6 @@ from .plan import plan, plan_bindable + +__all__ = [ + "plan", + "plan_bindable", +] diff --git a/graphql_api/types/plan/plan.py b/graphql_api/types/plan/plan.py index c071896f18..9fe306680f 100644 --- a/graphql_api/types/plan/plan.py +++ b/graphql_api/types/plan/plan.py @@ -3,14 +3,9 @@ from ariadne import ObjectType, convert_kwargs_to_snake_case +from codecov.db import sync_to_async from graphql_api.helpers.ariadne import ariadne_load_local_graphql from plan.constants import ( - MonthlyUploadLimits, - PlanBillingRate, - PlanMarketingName, - PlanName, - PlanPrice, - TierName, TrialStatus, ) from plan.service import PlanService @@ -25,6 +20,13 @@ def resolve_trial_start_date(plan_service: PlanService, info) -> Optional[dateti return plan_service.trial_start_date +@plan_bindable.field("trialTotalDays") +@convert_kwargs_to_snake_case +@sync_to_async +def resolve_trial_total_days(plan_service: PlanService, info) -> Optional[int]: + return plan_service.trial_total_days + + @plan_bindable.field("trialEndDate") @convert_kwargs_to_snake_case def resolve_trial_end_date(plan_service: PlanService, info) -> 
Optional[datetime]: @@ -39,6 +41,7 @@ def resolve_trial_status(plan_service: PlanService, info) -> TrialStatus: @plan_bindable.field("marketingName") @convert_kwargs_to_snake_case +@sync_to_async def resolve_marketing_name(plan_service: PlanService, info) -> str: return plan_service.marketing_name @@ -56,23 +59,27 @@ def resolve_plan_name_as_value(plan_service: PlanService, info) -> str: @plan_bindable.field("tierName") @convert_kwargs_to_snake_case +@sync_to_async def resolve_tier_name(plan_service: PlanService, info) -> str: return plan_service.tier_name @plan_bindable.field("billingRate") @convert_kwargs_to_snake_case +@sync_to_async def resolve_billing_rate(plan_service: PlanService, info) -> Optional[str]: return plan_service.billing_rate @plan_bindable.field("baseUnitPrice") @convert_kwargs_to_snake_case +@sync_to_async def resolve_base_unit_price(plan_service: PlanService, info) -> int: return plan_service.base_unit_price @plan_bindable.field("benefits") +@sync_to_async def resolve_benefits(plan_service: PlanService, info) -> List[str]: return plan_service.benefits @@ -87,17 +94,20 @@ def resolve_pretrial_users_count(plan_service: PlanService, info) -> Optional[in @plan_bindable.field("monthlyUploadLimit") @convert_kwargs_to_snake_case +@sync_to_async def resolve_monthly_uploads_limit(plan_service: PlanService, info) -> Optional[int]: return plan_service.monthly_uploads_limit @plan_bindable.field("planUserCount") @convert_kwargs_to_snake_case +@sync_to_async def resolve_plan_user_count(plan_service: PlanService, info) -> int: return plan_service.plan_user_count @plan_bindable.field("hasSeatsLeft") @convert_kwargs_to_snake_case +@sync_to_async def resolve_has_seats_left(plan_service: PlanService, info) -> bool: return plan_service.has_seats_left diff --git a/graphql_api/types/plan_representation/__init__.py b/graphql_api/types/plan_representation/__init__.py index 99ae2869ae..f033ede1a7 100644 --- a/graphql_api/types/plan_representation/__init__.py +++ 
b/graphql_api/types/plan_representation/__init__.py @@ -1 +1,6 @@ from .plan_representation import plan_representation, plan_representation_bindable + +__all__ = [ + "plan_representation", + "plan_representation_bindable", +] diff --git a/graphql_api/types/profile/__init__.py b/graphql_api/types/profile/__init__.py index 26ef27c755..ebf29f3210 100644 --- a/graphql_api/types/profile/__init__.py +++ b/graphql_api/types/profile/__init__.py @@ -3,3 +3,9 @@ from .profile import profile_bindable profile = ariadne_load_local_graphql(__file__, "profile.graphql") + + +__all__ = [ + "profile", + "profile_bindable", +] diff --git a/graphql_api/types/pull/__init__.py b/graphql_api/types/pull/__init__.py index 66c956f527..6eb770cb08 100644 --- a/graphql_api/types/pull/__init__.py +++ b/graphql_api/types/pull/__init__.py @@ -3,3 +3,9 @@ from .pull import pull_bindable pull = ariadne_load_local_graphql(__file__, "pull.graphql") + + +__all__ = [ + "pull", + "pull_bindable", +] diff --git a/graphql_api/types/pull/pull.py b/graphql_api/types/pull/pull.py index ab3cce43bd..6bbb568840 100644 --- a/graphql_api/types/pull/pull.py +++ b/graphql_api/types/pull/pull.py @@ -1,7 +1,11 @@ +from typing import Any, Optional, Union + from ariadne import ObjectType -from asgiref.sync import async_to_sync +from graphql import GraphQLResolveInfo from codecov.db import sync_to_async +from codecov_auth.models import Owner +from compare.models import CommitComparison from core.models import Commit, Pull from graphql_api.actions.commits import pull_commits from graphql_api.actions.comparison import validate_commit_comparison @@ -9,7 +13,7 @@ from graphql_api.dataloader.commit import CommitLoader from graphql_api.dataloader.comparison import ComparisonLoader from graphql_api.dataloader.owner import OwnerLoader -from graphql_api.helpers.connection import queryset_to_connection_sync +from graphql_api.helpers.connection import Connection, queryset_to_connection_sync from 
graphql_api.types.comparison.comparison import ( FirstPullRequest, MissingBaseCommit, @@ -25,37 +29,39 @@ @pull_bindable.field("state") -def resolve_state(pull, info) -> PullRequestState: +def resolve_state(pull: Pull, info: GraphQLResolveInfo) -> PullRequestState: return PullRequestState(pull.state) @pull_bindable.field("author") -def resolve_author(pull, info): +def resolve_author(pull: Pull, info: GraphQLResolveInfo) -> Optional[Owner]: if pull.author_id: return OwnerLoader.loader(info).load(pull.author_id) @pull_bindable.field("head") -def resolve_head(pull, info): - if pull.head == None: +def resolve_head(pull: Pull, info: GraphQLResolveInfo) -> Optional[Commit]: + if pull.head is None: return None return CommitLoader.loader(info, pull.repository_id).load(pull.head) @pull_bindable.field("comparedTo") -def resolve_base(pull, info): - if pull.compared_to == None: +def resolve_base(pull: Pull, info: GraphQLResolveInfo) -> Optional[Commit]: + if pull.compared_to is None: return None return CommitLoader.loader(info, pull.repository_id).load(pull.compared_to) @sync_to_async -def is_first_pull_request(pull: Pull): +def is_first_pull_request(pull: Pull) -> bool: return pull.repository.pull_requests.order_by("id").first() == pull @pull_bindable.field("compareWithBase") -async def resolve_compare_with_base(pull, info, **kwargs): +async def resolve_compare_with_base( + pull: Pull, info: GraphQLResolveInfo, **kwargs: Any +) -> Union[CommitComparison, Any]: if not pull.compared_to: if await is_first_pull_request(pull): return FirstPullRequest() @@ -84,18 +90,28 @@ async def resolve_compare_with_base(pull, info, **kwargs): @pull_bindable.field("bundleAnalysisCompareWithBase") @sync_to_async -def resolve_bundle_analysis_compare_with_base(pull, info, **kwargs): +def resolve_bundle_analysis_compare_with_base( + pull: Pull, info: GraphQLResolveInfo, **kwargs: Any +) -> Union[BundleAnalysisComparison, Any]: if not pull.compared_to: if 
pull.repository.pull_requests.order_by("id").first() == pull: return FirstPullRequest() else: return MissingBaseCommit() - if not pull.head: - return MissingHeadCommit() + + # Handles a case where the PR was created without any uploads because all bundles + # from the build are cached. Instead of showing a "no commit error" we will instead + # show the parent bundle report as it implies everything was cached and carried + # over to the head commit + head_commit_sha = pull.head if pull.head else pull.compared_to bundle_analysis_comparison = load_bundle_analysis_comparison( - Commit.objects.filter(commitid=pull.compared_to).first(), - Commit.objects.filter(commitid=pull.head).first(), + Commit.objects.filter( + commitid=pull.compared_to, repository=pull.repository + ).first(), + Commit.objects.filter( + commitid=head_commit_sha, repository=pull.repository + ).first(), ) # Store the created SQLite DB path in info.context @@ -117,7 +133,7 @@ def resolve_bundle_analysis_compare_with_base(pull, info, **kwargs): @pull_bindable.field("commits") @sync_to_async -def resolve_commits(pull: Pull, info, **kwargs): +def resolve_commits(pull: Pull, info: GraphQLResolveInfo, **kwargs: Any) -> Connection: queryset = pull_commits(pull) return queryset_to_connection_sync( @@ -129,17 +145,19 @@ def resolve_commits(pull: Pull, info, **kwargs): @pull_bindable.field("behindBy") -def resolve_behind_by(pull: Pull, info, **kwargs) -> int: +def resolve_behind_by(pull: Pull, info: GraphQLResolveInfo, **kwargs: Any) -> int: return pull.behind_by @pull_bindable.field("behindByCommit") -def resolve_behind_by_commit(pull: Pull, info, **kwargs) -> str: +def resolve_behind_by_commit( + pull: Pull, info: GraphQLResolveInfo, **kwargs: Any +) -> str: return pull.behind_by_commit @pull_bindable.field("firstPull") @sync_to_async -def resolve_first_pull(pull: Pull, info) -> bool: +def resolve_first_pull(pull: Pull, info: GraphQLResolveInfo) -> bool: # returns true if this pull is/was the 1st for a repo 
return pull.repository.pull_requests.order_by("id").first() == pull diff --git a/graphql_api/types/query/__init__.py b/graphql_api/types/query/__init__.py index 09203079ba..66521fb059 100644 --- a/graphql_api/types/query/__init__.py +++ b/graphql_api/types/query/__init__.py @@ -1 +1,6 @@ from .query import query, query_bindable + +__all__ = [ + "query", + "query_bindable", +] diff --git a/graphql_api/types/query/query.py b/graphql_api/types/query/query.py index 14a52c6fc6..22bc13e3fa 100644 --- a/graphql_api/types/query/query.py +++ b/graphql_api/types/query/query.py @@ -2,15 +2,14 @@ from ariadne import ObjectType from django.conf import settings -from graphql import GraphQLError, GraphQLResolveInfo -from sentry_sdk import configure_scope +from graphql import GraphQLResolveInfo +from sentry_sdk import Scope from codecov.commands.exceptions import UnauthorizedGuestAccess from codecov.db import sync_to_async from codecov_auth.models import Owner from graphql_api.actions.owner import get_owner from graphql_api.helpers.ariadne import ariadne_load_local_graphql -from utils.services import get_long_service_name query = ariadne_load_local_graphql(__file__, "query.graphql") query_bindable = ObjectType("Query") @@ -27,9 +26,9 @@ def configure_sentry_scope(query_name: str): # we're configuring this here since it's the main entrypoint into GraphQL resolvers # https://docs.sentry.io/platforms/python/enriching-events/transaction-name/ - with configure_scope() as scope: - if scope.transaction: - scope.transaction.name = f"GraphQL [{query_name}]" + scope = Scope.get_current_scope() + if scope.transaction: + scope.transaction.name = f"GraphQL [{query_name}]" @query_bindable.field("me") diff --git a/graphql_api/types/repository/__init__.py b/graphql_api/types/repository/__init__.py index ec9a4df6db..0f42fc5137 100644 --- a/graphql_api/types/repository/__init__.py +++ b/graphql_api/types/repository/__init__.py @@ -3,3 +3,10 @@ from .repository import repository_bindable, 
repository_result_bindable repository = ariadne_load_local_graphql(__file__, "repository.graphql") + + +__all__ = [ + "repository", + "repository_bindable", + "repository_result_bindable", +] diff --git a/graphql_api/types/repository/repository.graphql b/graphql_api/types/repository/repository.graphql index 6089918351..45e06f95fb 100644 --- a/graphql_api/types/repository/repository.graphql +++ b/graphql_api/types/repository/repository.graphql @@ -81,8 +81,27 @@ type Repository { orderingDirection: OrderingDirection ): [ComponentMeasurements!]! componentsYaml(termId: String): [ComponentsYaml]! - encodedSecretString(value: String!): EncodedSecretString! testAnalyticsEnabled: Boolean + isGithubRateLimited: Boolean + testResults( + filters: TestResultsFilters + ordering: TestResultsOrdering + first: Int + after: String + last: Int + before: String + ): TestResultConnection! @cost(complexity: 10, multipliers: ["first", "last"]) +} + +type TestResultConnection { + edges: [TestResultEdge]! + totalCount: Int! + pageInfo: PageInfo! +} + +type TestResultEdge { + cursor: String! + node: TestResult! } type PullConnection { @@ -118,8 +137,4 @@ type BranchEdge { node: Branch! } -type EncodedSecretString { - value: String! 
-} - union RepositoryResult = Repository | NotFoundError | OwnerNotActivatedError diff --git a/graphql_api/types/repository/repository.py b/graphql_api/types/repository/repository.py index 5b81aa0935..2620e7d90a 100644 --- a/graphql_api/types/repository/repository.py +++ b/graphql_api/types/repository/repository.py @@ -1,6 +1,8 @@ +import logging from datetime import datetime from typing import Iterable, List, Mapping, Optional +import shared.rate_limits as rate_limits import yaml from ariadne import ObjectType, UnionType, convert_kwargs_to_snake_case from django.conf import settings @@ -10,8 +12,9 @@ import timeseries.helpers as timeseries_helpers from codecov.db import sync_to_async +from codecov_auth.models import SERVICE_GITHUB, SERVICE_GITHUB_ENTERPRISE from core.models import Branch, Repository -from graphql_api.actions.commits import commit_status, repo_commits +from graphql_api.actions.commits import repo_commits from graphql_api.actions.components import ( component_measurements, component_measurements_last_uploaded, @@ -28,8 +31,12 @@ from graphql_api.types.errors.errors import NotFoundError, OwnerNotActivatedError from services.components import ComponentMeasurements from services.profiling import CriticalFile, ProfilingSummary +from services.redis_configuration import get_redis_connection from timeseries.helpers import fill_sparse_measurements from timeseries.models import Dataset, Interval, MeasurementName, MeasurementSummary +from utils.test_results import aggregate_test_results + +log = logging.getLogger(__name__) repository_bindable = ObjectType("Repository") @@ -538,11 +545,51 @@ def resolve_is_first_pull_request(repository: Repository, info) -> bool: return False -@repository_bindable.field("encodedSecretString") +@repository_bindable.field("isGithubRateLimited") @sync_to_async -def resolve_encoded_secret_string( - repository: Repository, info: GraphQLResolveInfo, value: str -) -> dict[str, str]: - command = 
info.context["executor"].get_command("repository") - owner = info.context["request"].current_owner - return {"value": command.encode_secret_string(owner, repository, value)} +def resolve_is_github_rate_limited(repository: Repository, info) -> bool | None: + if ( + repository.service != SERVICE_GITHUB + and repository.service != SERVICE_GITHUB_ENTERPRISE + ): + return False + repo_owner = repository.author + try: + redis_connection = get_redis_connection() + rate_limit_redis_key = rate_limits.determine_entity_redis_key( + owner=repo_owner, repository=repository + ) + return rate_limits.determine_if_entity_is_rate_limited( + redis_connection, rate_limit_redis_key + ) + except Exception: + log.warning( + "Error when checking rate limit", + extra=dict(repo_id=repository.repoid, has_owner=bool(repo_owner)), + ) + return None + + +@repository_bindable.field("testResults") +@convert_kwargs_to_snake_case +async def resolve_test_results( + repository: Repository, + info: GraphQLResolveInfo, + ordering=None, + filters=None, + **kwargs, +): + queryset = await sync_to_async(aggregate_test_results)( + repoid=repository.repoid, branch=filters.get("branch") if filters else None + ) + + return await queryset_to_connection( + queryset, + ordering=(ordering.get("parameter"), "name") + if ordering + else ("avg_duration", "name"), + ordering_direction=ordering.get("direction") + if ordering + else OrderingDirection.DESC, + **kwargs, + ) diff --git a/graphql_api/types/repository_config/__init__.py b/graphql_api/types/repository_config/__init__.py index c8fc743b38..0520f63a3b 100644 --- a/graphql_api/types/repository_config/__init__.py +++ b/graphql_api/types/repository_config/__init__.py @@ -3,3 +3,9 @@ from .repository_config import indication_range_bindable, repository_config_bindable repository_config = ariadne_load_local_graphql(__file__, "repository_config.graphql") + +__all__ = [ + "repository_config", + "repository_config_bindable", + "indication_range_bindable", +] diff --git 
a/graphql_api/types/repository_config/repository_config.graphql b/graphql_api/types/repository_config/repository_config.graphql index 2792271240..6a32ad076c 100644 --- a/graphql_api/types/repository_config/repository_config.graphql +++ b/graphql_api/types/repository_config/repository_config.graphql @@ -3,6 +3,6 @@ type RepositoryConfig { } type IndicationRange { - upperRange: Int! - lowerRange: Int! + upperRange: Float! + lowerRange: Float! } diff --git a/graphql_api/types/repository_config/repository_config.py b/graphql_api/types/repository_config/repository_config.py index 8b0902e076..46f06d2340 100644 --- a/graphql_api/types/repository_config/repository_config.py +++ b/graphql_api/types/repository_config/repository_config.py @@ -12,28 +12,30 @@ class IndicationRange(TypedDict): - lowerRange: str - upperRange: str + lowerRange: float + upperRange: float @repository_config_bindable.field("indicationRange") -async def resolve_indication_range(repository: Repository, info) -> dict[str, int]: +async def resolve_indication_range(repository: Repository, info) -> dict[str, float]: owner = await OwnerLoader.loader(info).load(repository.author_id) yaml = await sync_to_async(UserYaml.get_final_yaml)( owner_yaml=owner.yaml, repo_yaml=repository.yaml ) - range: list[int] = yaml.get("coverage", {"range": [60, 80]}).get("range", [60, 80]) + range: list[float] = yaml.get("coverage", {"range": [60, 80]}).get( + "range", [60, 80] + ) return {"lowerRange": range[0], "upperRange": range[1]} @indication_range_bindable.field("upperRange") -def resolve_upper_range(indicationRange: IndicationRange, info) -> int: +def resolve_upper_range(indicationRange: IndicationRange, info) -> float: upperRange = indicationRange.get("upperRange") return upperRange @indication_range_bindable.field("lowerRange") -def resolve_lower_range(indicationRange: IndicationRange, info) -> int: +def resolve_lower_range(indicationRange: IndicationRange, info) -> float: lowerRange = 
indicationRange.get("lowerRange") return lowerRange diff --git a/graphql_api/types/segment_comparison/__init__.py b/graphql_api/types/segment_comparison/__init__.py index c0b16f181f..20f87a3b48 100644 --- a/graphql_api/types/segment_comparison/__init__.py +++ b/graphql_api/types/segment_comparison/__init__.py @@ -3,3 +3,10 @@ from .segment_comparison import segment_comparison_bindable, segments_result_bindable segment_comparison = ariadne_load_local_graphql(__file__, "segment_comparison.graphql") + + +__all__ = [ + "segment_comparison", + "segment_comparison_bindable", + "segments_result_bindable", +] diff --git a/graphql_api/types/segment_comparison/segment_comparison.py b/graphql_api/types/segment_comparison/segment_comparison.py index bea2ca9a53..12356678e2 100644 --- a/graphql_api/types/segment_comparison/segment_comparison.py +++ b/graphql_api/types/segment_comparison/segment_comparison.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import List from ariadne import ObjectType, UnionType diff --git a/graphql_api/types/self_hosted_license/__init__.py b/graphql_api/types/self_hosted_license/__init__.py index 18fe89c91a..777259ea74 100644 --- a/graphql_api/types/self_hosted_license/__init__.py +++ b/graphql_api/types/self_hosted_license/__init__.py @@ -5,3 +5,5 @@ self_hosted_license = ariadne_load_local_graphql( __file__, "self_hosted_license.graphql" ) + +__all__ = ["self_hosted_license", "self_hosted_license_bindable"] diff --git a/graphql_api/types/session/__init__.py b/graphql_api/types/session/__init__.py index e00a6ec2ba..49f9ff9cac 100644 --- a/graphql_api/types/session/__init__.py +++ b/graphql_api/types/session/__init__.py @@ -3,3 +3,5 @@ from .session import session_bindable session = ariadne_load_local_graphql(__file__, "session.graphql") + +__all__ = ["session", "session_bindable"] diff --git a/graphql_api/types/test_results/__init__.py b/graphql_api/types/test_results/__init__.py new file mode 
100644 index 0000000000..be36dc2993 --- /dev/null +++ b/graphql_api/types/test_results/__init__.py @@ -0,0 +1,9 @@ +from shared.license import get_current_license + +from graphql_api.helpers.ariadne import ariadne_load_local_graphql + +from .test_results import test_result_bindable + +test_results = ariadne_load_local_graphql(__file__, "test_results.graphql") + +__all__ = ["get_current_license", "test_result_bindable"] diff --git a/graphql_api/types/test_results/test_results.graphql b/graphql_api/types/test_results/test_results.graphql new file mode 100644 index 0000000000..1f47b22391 --- /dev/null +++ b/graphql_api/types/test_results/test_results.graphql @@ -0,0 +1,7 @@ +type TestResult { + updatedAt: DateTime! + name: String! + commitsFailed: Int + failureRate: Float + avgDuration: Float +} diff --git a/graphql_api/types/test_results/test_results.py b/graphql_api/types/test_results/test_results.py new file mode 100644 index 0000000000..d90a309d68 --- /dev/null +++ b/graphql_api/types/test_results/test_results.py @@ -0,0 +1,30 @@ +from datetime import datetime + +from ariadne import ObjectType + +test_result_bindable = ObjectType("TestResult") + + +@test_result_bindable.field("name") +def resolve_name(test, info) -> str: + return test["name"].replace("\x1f", " ") + + +@test_result_bindable.field("updatedAt") +def resolve_updated_at(test, info) -> datetime: + return test["updated_at"] + + +@test_result_bindable.field("commitsFailed") +def resolve_commits_failed(test, info) -> int | None: + return test["commits_where_fail"] + + +@test_result_bindable.field("failureRate") +def resolve_failure_rate(test, info) -> float | None: + return test["failure_rate"] + + +@test_result_bindable.field("avgDuration") +def resolve_last_duration(test, info) -> float | None: + return test["avg_duration"] diff --git a/graphql_api/types/upload/__init__.py b/graphql_api/types/upload/__init__.py index 77b01e47f2..59f00b7129 100644 --- a/graphql_api/types/upload/__init__.py +++ 
b/graphql_api/types/upload/__init__.py @@ -3,3 +3,5 @@ from .upload import upload_bindable, upload_error_bindable upload = ariadne_load_local_graphql(__file__, "upload.graphql") + +__all__ = ["upload", "upload_bindable", "upload_error_bindable"] diff --git a/graphql_api/types/upload/upload.py b/graphql_api/types/upload/upload.py index 4a6f46915c..1a8997ffc0 100644 --- a/graphql_api/types/upload/upload.py +++ b/graphql_api/types/upload/upload.py @@ -27,6 +27,8 @@ @upload_bindable.field("state") def resolve_state(upload, info): + if not upload.state: + return UploadState.ERROR return UploadState(upload.state) diff --git a/graphql_api/types/user/__init__.py b/graphql_api/types/user/__init__.py index 6d1a4feab9..94580c9b04 100644 --- a/graphql_api/types/user/__init__.py +++ b/graphql_api/types/user/__init__.py @@ -1 +1,3 @@ from .user import user, user_bindable + +__all__ = ["user", "user_bindable"] diff --git a/graphql_api/types/user_token/__init__.py b/graphql_api/types/user_token/__init__.py index c52d7dc2c6..445e55d62f 100644 --- a/graphql_api/types/user_token/__init__.py +++ b/graphql_api/types/user_token/__init__.py @@ -3,3 +3,6 @@ from .user_token import user_token_bindable user_token = ariadne_load_local_graphql(__file__, "user_token.graphql") + + +__all__ = ["user_token", "user_token_bindable"] diff --git a/graphql_api/views.py b/graphql_api/views.py index 419984b447..2533ce7bb2 100644 --- a/graphql_api/views.py +++ b/graphql_api/views.py @@ -12,7 +12,11 @@ from ariadne.validation import cost_validator from ariadne_django.views import GraphQLAsyncView from django.conf import settings -from django.http import HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse +from django.http import ( + HttpResponseBadRequest, + HttpResponseNotAllowed, + JsonResponse, +) from graphql import DocumentNode from sentry_sdk import capture_exception from sentry_sdk import metrics as sentry_metrics @@ -22,6 +26,7 @@ from codecov.commands.executor import 
get_executor_from_request from codecov.db import sync_to_async from services import ServiceException +from services.redis_configuration import get_redis_connection from .schema import schema @@ -223,6 +228,16 @@ async def post(self, request, *args, **kwargs): log.info("GraphQL Request", extra=log_data) sentry_metrics.incr("graphql.info.request_made", tags={"path": req_path}) + if self._check_ratelimit(request=request): + sentry_metrics.incr("graphql.error.rate_limit", tags={"path": req_path}) + return JsonResponse( + data={ + "status": 429, + "detail": "It looks like you've hit the rate limit of 300 req/min. Try again later.", + }, + status=429, + ) + with RequestFinalizer(request): response = await super().post(request, *args, **kwargs) @@ -263,7 +278,7 @@ def context_value(self, request, *_): } def error_formatter(self, error, debug=False): - # the only way to check for a malformatted query + # the only way to check for a malformed query is_bad_query = "Cannot query field" in error.formatted["message"] if debug or is_bad_query: return format_error(error, debug) @@ -290,6 +305,53 @@ def _get_user(self, request): if request.user: request.user.pk + def _check_ratelimit(self, request): + redis = get_redis_connection() + + try: + # eagerly try to get user_id from request object + user_id = request.user.pk + except AttributeError: + user_id = None + + if user_id: + key = f"rl-user:{user_id}" + else: + user_ip = self.get_client_ip(request) + key = f"rl-ip:{user_ip}" + + limit = 300 + window = 60 # in seconds + + current_count = redis.get(key) + if current_count is None: + log.info( + "[GQL Rate Limit] - Setting new key", + extra=dict(key=key, user_id=user_id), + ) + redis.setex(key, window, 1) + elif int(current_count) >= limit: + log.warning( + "[GQL Rate Limit] - Rate limit reached for key", + extra=dict(key=key, limit=limit, count=current_count, user_id=user_id), + ) + return True + else: + log.warning( + "[GQL Rate Limit] - Incrementing rate limit for key", + 
extra=dict(key=key, limit=limit, count=current_count, user_id=user_id), + ) + redis.incr(key) + return False + + def get_client_ip(self, request): + x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR") + if x_forwarded_for: + ip = x_forwarded_for.split(",")[0] + else: + ip = request.META.get("REMOTE_ADDR") + return ip + BaseAriadneView = AsyncGraphqlView.as_view() diff --git a/graphs/helpers/badge.py b/graphs/helpers/badge.py index 849d50889c..7280634be5 100644 --- a/graphs/helpers/badge.py +++ b/graphs/helpers/badge.py @@ -1,11 +1,9 @@ -from math import floor - from shared.helpers.color import coverage_to_color from graphs.badges.badges import large_badge, medium_badge, small_badge, unknown_badge -def get_badge(coverage, coverage_range, precision): +def get_badge(coverage: str | None, coverage_range: list[int], precision: str): """ Returns and SVG string containing coverage badge diff --git a/graphs/helpers/graph_utils.py b/graphs/helpers/graph_utils.py index b69e4a6873..ee548eaf87 100644 --- a/graphs/helpers/graph_utils.py +++ b/graphs/helpers/graph_utils.py @@ -1,7 +1,5 @@ from math import cos, pi, sin -from shared.helpers.color import coverage_to_color - style_n_defs = """ diff --git a/graphs/helpers/graphs.py b/graphs/helpers/graphs.py index e9162808c1..2eabd3a2d2 100644 --- a/graphs/helpers/graphs.py +++ b/graphs/helpers/graphs.py @@ -3,14 +3,11 @@ from graphs.settings import settings from .graph_utils import ( - _layout, _make_svg, - _max_aspect_ratio, _squarify, _svg_polar_rect, _svg_rect, _tree_height, - _worst_ratio, ) diff --git a/graphs/tests/test_badge_handler.py b/graphs/tests/test_badge_handler.py index 3f8026861b..156639235e 100644 --- a/graphs/tests/test_badge_handler.py +++ b/graphs/tests/test_badge_handler.py @@ -250,9 +250,7 @@ def test_unknown_bagde_incorrect_repo(self): def test_unknown_bagde_no_branch(self): gh_owner = OwnerFactory(service="github") - repo = RepositoryFactory( - author=gh_owner, active=True, private=False, name="repo1" 
- ) + RepositoryFactory(author=gh_owner, active=True, private=False, name="repo1") response = self._get( kwargs={ "service": "gh", @@ -296,7 +294,7 @@ def test_unknown_bagde_no_commit(self): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - branch = BranchFactory(repository=repo, name="master") + BranchFactory(repository=repo, name="master") response = self._get( kwargs={ "service": "gh", @@ -340,7 +338,7 @@ def test_unknown_bagde_no_totals(self): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitFactory(repository=repo, author=gh_owner, totals=None) + CommitFactory(repository=repo, author=gh_owner, totals=None) response = self._get( kwargs={ "service": "gh", @@ -384,7 +382,7 @@ def test_text_badge(self): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) # test default precision response = self._get( @@ -435,7 +433,7 @@ def test_svg_badge(self): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) # test default precision response = self._get( @@ -566,7 +564,7 @@ def test_private_badge_no_token(self): name="repo1", image_token="12345678", ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) # test default precision response = self._get( @@ -616,7 +614,7 @@ def test_private_badge(self): name="repo1", image_token="12345678", ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) # test default precision response = self._get( @@ -667,7 +665,7 @@ def test_branch_badge(self): image_token="12345678", branch="branch1", ) - commit = CommitFactory(repository=repo, author=gh_owner) + 
CommitFactory(repository=repo, author=gh_owner) commit_2_totals = { "C": 0, "M": 0, @@ -683,21 +681,6 @@ def test_branch_badge(self): "p": 0, "s": 1, } - commit_3_totals = { - "C": 0, - "M": 0, - "N": 0, - "b": 0, - "c": "80.00000", - "d": 0, - "diff": [1, 2, 1, 1, 0, "50.00000", 0, 0, 0, 0, 0, 0, 0], - "f": 3, - "h": 17, - "m": 3, - "n": 20, - "p": 0, - "s": 1, - } commit_2 = CommitFactory( commitid="b1c2b4fa3ae9ef615c8f740c5cba95d9851f9ae8", repository=repo, @@ -760,7 +743,7 @@ def test_badge_with_100_coverage(self): image_token="12345678", branch="branch1", ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) commit_2_totals = { "C": 0, "M": 0, @@ -779,9 +762,7 @@ def test_badge_with_100_coverage(self): commit_2 = CommitFactory( repository=repo, author=gh_owner, totals=commit_2_totals ) - branch_2 = BranchFactory( - repository=repo, name="branch1", head=commit_2.commitid - ) + BranchFactory(repository=repo, name="branch1", head=commit_2.commitid) # test default precision response = self._get_branch( kwargs={ @@ -832,7 +813,7 @@ def test_branch_badge_with_slash(self): image_token="12345678", branch="branch1", ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) commit_2_totals = { "C": 0, "M": 0, @@ -851,9 +832,7 @@ def test_branch_badge_with_slash(self): commit_2 = CommitFactory( repository=repo, author=gh_owner, totals=commit_2_totals ) - branch_2 = BranchFactory( - repository=repo, name="test/branch1", head=commit_2.commitid - ) + BranchFactory(repository=repo, name="test/branch1", head=commit_2.commitid) # test default precision response = self._get_branch( kwargs={ @@ -900,7 +879,7 @@ def test_flag_badge(self, full_report_mock): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) 
full_report_mock.return_value = sample_report() # test default precision @@ -953,7 +932,7 @@ def test_none_branch_flag_badge(self): name="repo1", branch="not-master", ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) # test default precision response = self._get( @@ -1003,7 +982,7 @@ def test_unknown_flag_badge(self, full_report_mock): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) full_report_mock.return_value = sample_report() # test default precision @@ -1054,7 +1033,7 @@ def test_unknown_sessions_flag_badge(self, full_report_mock): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) full_report_mock.return_value = sample_report() # test default precision response = self._get( @@ -1104,7 +1083,7 @@ def test_unknown_report_flag_badge(self, full_report_mock): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) full_report_mock.return_value = sample_report() # test default precision @@ -1158,7 +1137,7 @@ def test_yaml_range(self): name="repo1", yaml={"coverage": {"range": [0.0, 0.8]}}, ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) # test default precision response = self._get( @@ -1209,7 +1188,7 @@ def test_yaml_empty_range(self): name="repo1", yaml={"coverage": {}}, ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) # test default precision response = self._get( @@ -1257,7 +1236,7 @@ def test_commit_report_null(self, full_report_mock): repo = 
RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) full_report_mock.return_value = None # test default precision @@ -1308,7 +1287,7 @@ def test_commit_report_no_flags(self, full_report_mock): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitFactory(repository=repo, author=gh_owner) + CommitFactory(repository=repo, author=gh_owner) full_report_mock.return_value = sample_report_no_flags() # test default precision diff --git a/graphs/tests/test_graph_handler.py b/graphs/tests/test_graph_handler.py index 3101cd2835..97923a3cd7 100644 --- a/graphs/tests/test_graph_handler.py +++ b/graphs/tests/test_graph_handler.py @@ -85,7 +85,7 @@ def test_tree_graph(self): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitWithReportFactory(repository=repo, author=gh_owner) + CommitWithReportFactory(repository=repo, author=gh_owner) # test default precision response = self._get( @@ -128,7 +128,7 @@ def test_icicle_graph(self): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitWithReportFactory(repository=repo, author=gh_owner) + CommitWithReportFactory(repository=repo, author=gh_owner) # test default precision response = self._get( @@ -174,7 +174,7 @@ def test_sunburst_graph(self): repo = RepositoryFactory( author=gh_owner, active=True, private=False, name="repo1" ) - commit = CommitWithReportFactory(repository=repo, author=gh_owner) + CommitWithReportFactory(repository=repo, author=gh_owner) # test default precision response = self._get( @@ -252,7 +252,7 @@ def test_unkown_repo(self): def test_private_repo_no_token(self): gh_owner = OwnerFactory(service="github") - repo = RepositoryFactory( + RepositoryFactory( author=gh_owner, active=True, private=True, @@ -285,7 +285,7 @@ def 
test_private_repo(self): name="repo1", image_token="12345678", ) - commit = CommitWithReportFactory(repository=repo, author=gh_owner) + CommitWithReportFactory(repository=repo, author=gh_owner) response = self._get( "sunburst", @@ -328,9 +328,7 @@ def test_private_repo(self): def test_unkown_branch(self): gh_owner = OwnerFactory(service="github") - repo = RepositoryFactory( - author=gh_owner, active=True, private=False, name="repo1" - ) + RepositoryFactory(author=gh_owner, active=True, private=False, name="repo1") response = self._get( "sunburst", @@ -358,7 +356,7 @@ def test_branch_graph(self): image_token="12345678", branch="branch1", ) - commit = CommitWithReportFactory(repository=repo, author=gh_owner) + CommitWithReportFactory(repository=repo, author=gh_owner) commit_2_totals = { "C": 0, "M": 0, @@ -377,9 +375,7 @@ def test_branch_graph(self): commit_2 = CommitWithReportFactory( repository=repo, author=gh_owner, totals=commit_2_totals ) - branch_2 = BranchFactory( - repository=repo, name="branch1", head=commit_2.commitid - ) + BranchFactory(repository=repo, name="branch1", head=commit_2.commitid) # test default precision response = self._get_branch( "tree", @@ -429,7 +425,7 @@ def test_commit_graph(self): # make sure commit 2 report is different than commit 1 and # assert that the expected graph below still pertains to commit_1 - commit_2 = CommitFactory( + CommitFactory( repository=repo, author=gh_owner, parent_commit_id=commit_1.commitid, @@ -505,7 +501,7 @@ def test_pull_graph(self): image_token="12345678", branch="branch1", ) - pull = PullFactory( + PullFactory( pullid=10, repository_id=repo.repoid, _flare=[ @@ -562,7 +558,7 @@ def test_pull_graph(self): def test_no_pull_graph(self): gh_owner = OwnerFactory(service="github") - repo = RepositoryFactory( + RepositoryFactory( author=gh_owner, active=True, private=True, @@ -598,8 +594,8 @@ def test_pull_no_flare_graph(self): image_token="12345678", branch="master", ) - commit = 
CommitWithReportFactory(repository=repo, author=gh_owner) - pull = PullFactory(pullid=10, repository_id=repo.repoid, _flare=None) + CommitWithReportFactory(repository=repo, author=gh_owner) + PullFactory(pullid=10, repository_id=repo.repoid, _flare=None) # test default precision response = self._get_pull( diff --git a/graphs/urls.py b/graphs/urls.py index e68b99bfe8..bc0e0d45c8 100644 --- a/graphs/urls.py +++ b/graphs/urls.py @@ -1,4 +1,4 @@ -from django.urls import path, re_path +from django.urls import re_path from .views import BadgeHandler, GraphHandler diff --git a/graphs/views.py b/graphs/views.py index f09d7d00d3..62a2f8bf92 100644 --- a/graphs/views.py +++ b/graphs/views.py @@ -1,14 +1,12 @@ import logging from django.core.exceptions import ObjectDoesNotExist -from django.db import connection from django.http import Http404 from rest_framework import exceptions from rest_framework.exceptions import NotFound from rest_framework.negotiation import DefaultContentNegotiation from rest_framework.permissions import AllowAny from rest_framework.views import APIView -from shared.reports.resources import Report import services.report as report_service from api.shared.mixins import RepoPropertyMixin @@ -123,7 +121,7 @@ def get_coverage(self): def flag_coverage(self, flag_name, commit): """ - Looks into a commit's report sessions and returns the coverage for a perticular flag + Looks into a commit's report sessions and returns the coverage for a particular flag Parameters flag_name (string): name of flag @@ -158,36 +156,36 @@ def get_object(self, request, *args, **kwargs): if graph == "tree": options["width"] = int( self.request.query_params.get( - "width", settings["sunburst"]["options"]["width"] + "width", settings["sunburst"]["options"]["width"] or 100 ) ) options["height"] = int( self.request.query_params.get( - "height", settings["sunburst"]["options"]["height"] + "height", settings["sunburst"]["options"]["height"] or 100 ) ) return tree(flare, None, None, 
**options) elif graph == "icicle": options["width"] = int( self.request.query_params.get( - "width", settings["icicle"]["options"]["width"] + "width", settings["icicle"]["options"]["width"] or 100 ) ) options["height"] = int( self.request.query_params.get( - "height", settings["icicle"]["options"]["height"] + "height", settings["icicle"]["options"]["height"] or 100 ) ) return icicle(flare, **options) elif graph == "sunburst": options["width"] = int( self.request.query_params.get( - "width", settings["sunburst"]["options"]["width"] + "width", settings["sunburst"]["options"]["width"] or 100 ) ) options["height"] = int( self.request.query_params.get( - "height", settings["sunburst"]["options"]["height"] + "height", settings["sunburst"]["options"]["height"] or 100 ) ) return sunburst(flare, **options) diff --git a/labelanalysis/admin.py b/labelanalysis/admin.py index 8c38f3f3da..846f6b4061 100644 --- a/labelanalysis/admin.py +++ b/labelanalysis/admin.py @@ -1,3 +1 @@ -from django.contrib import admin - # Register your models here. 
diff --git a/labelanalysis/tests/integration/test_views.py b/labelanalysis/tests/integration/test_views.py index 0946c10a5d..cbc82d45c5 100644 --- a/labelanalysis/tests/integration/test_views.py +++ b/labelanalysis/tests/integration/test_views.py @@ -441,14 +441,6 @@ def test_simple_label_analysis_put_labels_wrong_base_return_404(db, mocker): produced_object = LabelAnalysisRequest.objects.get(head_commit=commit) assert produced_object == label_analysis assert produced_object.requested_labels is None - expected_response_json = { - "base_commit": base_commit.commitid, - "head_commit": commit.commitid, - "requested_labels": ["label_1", "label_2", "label_3"], - "result": None, - "state": "created", - "external_id": str(produced_object.external_id), - } patch_url = reverse( "view_label_analysis", kwargs=dict(external_id=produced_object.external_id) ) diff --git a/legacy_migrations/migrations/0001_initial.py b/legacy_migrations/migrations/0001_initial.py index c22b39f55b..ba2e60e085 100644 --- a/legacy_migrations/migrations/0001_initial.py +++ b/legacy_migrations/migrations/0001_initial.py @@ -1,6 +1,6 @@ # Generated by Django 3.1.6 on 2021-03-15 20:15 -from django.conf import settings +from django.conf import settings # noqa: F401 from django.db import migrations from .legacy_sql.main.main import run_sql as main_run_sql diff --git a/open_telemetry.py b/open_telemetry.py index 6aa6678d6d..1e12ba4c46 100644 --- a/open_telemetry.py +++ b/open_telemetry.py @@ -1,7 +1,6 @@ import json import logging import os -import re from datetime import datetime import requests @@ -132,13 +131,11 @@ def export(self, spans): api = self.attributes["api"] try: - to_send = [] + to_send = [self._format_span(span) for span in spans] headers = { "content-type": "application/json", "Authorization": self.attributes["token"], } - for span in spans: - to_send.append(self._format_span(span)) requests.post(api + "/api/ingest", headers=headers, json=to_send) except ConnectionError: 
logging.exception("failed to export all spans") diff --git a/plan/service.py b/plan/service.py index 672aff3f83..83d3b19bae 100644 --- a/plan/service.py +++ b/plan/service.py @@ -21,6 +21,8 @@ TrialStatus, ) from services import sentry +from services.self_hosted import enterprise_has_seats_left, license_seats +from utils.config import get_config log = logging.getLogger(__name__) @@ -41,8 +43,7 @@ def __init__(self, current_org: Owner): self.current_org = current_org if self.current_org.plan not in USER_PLAN_REPRESENTATIONS: raise ValueError("Unsupported plan") - else: - self.plan_data = USER_PLAN_REPRESENTATIONS[self.current_org.plan] + self._plan_data = None def update_plan(self, name, user_count: int | None) -> None: if name not in USER_PLAN_REPRESENTATIONS: @@ -51,7 +52,7 @@ def update_plan(self, name, user_count: int | None) -> None: raise ValueError("Quantity Needed") self.current_org.plan = name self.current_org.plan_user_count = user_count - self.plan_data = USER_PLAN_REPRESENTATIONS[self.current_org.plan] + self._plan_data = USER_PLAN_REPRESENTATIONS[self.current_org.plan] self.current_org.save() def current_org(self) -> Owner: @@ -65,12 +66,35 @@ def set_default_plan_data(self) -> None: self.current_org.stripe_subscription_id = None self.current_org.save() + @property + def has_account(self) -> bool: + return False if self.current_org.account is None else True + + @property + def plan_data(self) -> PlanData: + if self._plan_data is not None: + return self._plan_data + + if self.has_account: + self._plan_data = USER_PLAN_REPRESENTATIONS[self.current_org.account.plan] + else: + self._plan_data = USER_PLAN_REPRESENTATIONS[self.current_org.plan] + return self._plan_data + + @plan_data.setter + def set_plan_data(self, plan_data: PlanData | None) -> None: + self._plan_data = plan_data + @property def plan_name(self) -> str: return self.plan_data.value @property def plan_user_count(self) -> int: + if get_config("setup", "enterprise_license"): + return 
license_seats() + if self.has_account: + return self.current_org.account.total_seat_count return self.current_org.plan_user_count @property @@ -147,7 +171,7 @@ def _start_trial_helper( end_date: Optional[datetime] = None, is_extension: bool = False, ) -> None: - start_date = datetime.utcnow() + start_date = datetime.now() # When they are not extending a trial, have to setup all the default values if not is_extension: @@ -207,7 +231,7 @@ def start_trial_manually(self, current_owner: Owner, end_date: datetime) -> None def cancel_trial(self) -> None: if not self.is_org_trialing: raise ValidationError("Cannot cancel a trial that is not ongoing") - now = datetime.utcnow() + now = datetime.now() self.current_org.trial_status = TrialStatus.EXPIRED.value self.current_org.trial_end_date = now self.set_default_plan_data() @@ -230,7 +254,7 @@ def expire_trial_when_upgrading(self) -> None: self.current_org.plan_user_count = ( self.current_org.pretrial_users_count or 1 ) - self.current_org.trial_end_date = datetime.utcnow() + self.current_org.trial_end_date = datetime.now() self.current_org.save() @@ -263,6 +287,14 @@ def has_trial_dates(self) -> bool: @property def has_seats_left(self) -> bool: + if get_config("setup", "enterprise_license"): + return enterprise_has_seats_left() + if self.has_account: + # edge case: IF the User is already a plan_activated_user on any of the Orgs in the Account, + # AND their Account is at capacity, + # AND they try to become a plan_activated_user on another Org in the Account, + # has_seats_left will evaluate as False even though the User should be allowed to activate on the Org. 
+ return self.current_org.account.can_activate_user() return ( self.plan_activated_users is None or len(self.plan_activated_users) < self.plan_user_count diff --git a/plan/tests/test_plan.py b/plan/tests/test_plan.py index e4d75b73ab..3f4109db5e 100644 --- a/plan/tests/test_plan.py +++ b/plan/tests/test_plan.py @@ -4,9 +4,10 @@ from django.test import TestCase from freezegun import freeze_time from pytest import raises +from shared.django_apps.codecov_auth.tests.factories import AccountsUsersFactory from codecov.commands.exceptions import ValidationError -from codecov_auth.tests.factories import OwnerFactory +from codecov_auth.tests.factories import AccountFactory, OwnerFactory from plan.constants import ( BASIC_PLAN, FREE_PLAN, @@ -32,7 +33,7 @@ def test_plan_service_trial_status_not_started(self): assert plan_service.trial_status == TrialStatus.NOT_STARTED.value def test_plan_service_trial_status_expired(self): - trial_start_date = datetime.utcnow() + trial_start_date = datetime.now() trial_end_date_expired = trial_start_date - timedelta(days=1) current_org = OwnerFactory( plan=PlanName.BASIC_PLAN_NAME.value, @@ -45,7 +46,7 @@ def test_plan_service_trial_status_expired(self): assert plan_service.trial_status == TrialStatus.EXPIRED.value def test_plan_service_trial_status_ongoing(self): - trial_start_date = datetime.utcnow() + trial_start_date = datetime.now() trial_end_date_ongoing = trial_start_date + timedelta(days=5) current_org = OwnerFactory( plan=PlanName.TRIAL_PLAN_NAME.value, @@ -70,14 +71,14 @@ def test_plan_service_expire_trial_when_upgrading_successful_if_trial_is_not_sta plan_service = PlanService(current_org=current_org_with_ongoing_trial) plan_service.expire_trial_when_upgrading() assert current_org_with_ongoing_trial.trial_status == TrialStatus.EXPIRED.value - assert current_org_with_ongoing_trial.plan_activated_users == None + assert current_org_with_ongoing_trial.plan_activated_users is None assert current_org_with_ongoing_trial.plan_user_count 
== 1 - assert current_org_with_ongoing_trial.trial_end_date == datetime.utcnow() + assert current_org_with_ongoing_trial.trial_end_date == datetime.now() def test_plan_service_expire_trial_when_upgrading_successful_if_trial_is_ongoing( self, ): - trial_start_date = datetime.utcnow() + trial_start_date = datetime.now() trial_end_date_ongoing = trial_start_date + timedelta(days=5) current_org_with_ongoing_trial = OwnerFactory( plan=PlanName.BASIC_PLAN_NAME.value, @@ -88,14 +89,14 @@ def test_plan_service_expire_trial_when_upgrading_successful_if_trial_is_ongoing plan_service = PlanService(current_org=current_org_with_ongoing_trial) plan_service.expire_trial_when_upgrading() assert current_org_with_ongoing_trial.trial_status == TrialStatus.EXPIRED.value - assert current_org_with_ongoing_trial.plan_activated_users == None + assert current_org_with_ongoing_trial.plan_activated_users is None assert current_org_with_ongoing_trial.plan_user_count == 1 - assert current_org_with_ongoing_trial.trial_end_date == datetime.utcnow() + assert current_org_with_ongoing_trial.trial_end_date == datetime.now() def test_plan_service_expire_trial_users_pretrial_users_count_if_existing( self, ): - trial_start_date = datetime.utcnow() + trial_start_date = datetime.now() trial_end_date_ongoing = trial_start_date + timedelta(days=5) pretrial_users_count = 5 current_org_with_ongoing_trial = OwnerFactory( @@ -108,12 +109,12 @@ def test_plan_service_expire_trial_users_pretrial_users_count_if_existing( plan_service = PlanService(current_org=current_org_with_ongoing_trial) plan_service.expire_trial_when_upgrading() assert current_org_with_ongoing_trial.trial_status == TrialStatus.EXPIRED.value - assert current_org_with_ongoing_trial.plan_activated_users == None + assert current_org_with_ongoing_trial.plan_activated_users is None assert current_org_with_ongoing_trial.plan_user_count == pretrial_users_count - assert current_org_with_ongoing_trial.trial_end_date == datetime.utcnow() + assert 
current_org_with_ongoing_trial.trial_end_date == datetime.now() def test_plan_service_start_trial_errors_if_status_is_ongoing(self): - trial_start_date = datetime.utcnow() + trial_start_date = datetime.now() trial_end_date = trial_start_date + timedelta( days=TrialDaysAmount.CODECOV_SENTRY.value ) @@ -126,11 +127,11 @@ def test_plan_service_start_trial_errors_if_status_is_ongoing(self): plan_service = PlanService(current_org=current_org) current_owner = OwnerFactory() - with self.assertRaises(ValidationError) as e: + with self.assertRaises(ValidationError): plan_service.start_trial(current_owner=current_owner) def test_plan_service_start_trial_errors_if_status_is_expired(self): - trial_start_date = datetime.utcnow() + trial_start_date = datetime.now() trial_end_date = trial_start_date + timedelta(days=-1) current_org = OwnerFactory( plan=PlanName.BASIC_PLAN_NAME.value, @@ -141,7 +142,7 @@ def test_plan_service_start_trial_errors_if_status_is_expired(self): plan_service = PlanService(current_org=current_org) current_owner = OwnerFactory() - with self.assertRaises(ValidationError) as e: + with self.assertRaises(ValidationError): plan_service.start_trial(current_owner=current_owner) def test_plan_service_start_trial_errors_if_status_is_cannot_trial(self): @@ -154,7 +155,7 @@ def test_plan_service_start_trial_errors_if_status_is_cannot_trial(self): plan_service = PlanService(current_org=current_org) current_owner = OwnerFactory() - with self.assertRaises(ValidationError) as e: + with self.assertRaises(ValidationError): plan_service.start_trial(current_owner=current_owner) def test_plan_service_start_trial_errors_owners_plan_is_not_a_free_plan(self): @@ -167,7 +168,7 @@ def test_plan_service_start_trial_errors_owners_plan_is_not_a_free_plan(self): plan_service = PlanService(current_org=current_org) current_owner = OwnerFactory() - with self.assertRaises(ValidationError) as e: + with self.assertRaises(ValidationError): 
plan_service.start_trial(current_owner=current_owner) def test_plan_service_start_trial_succeeds_if_trial_has_not_started(self): @@ -185,8 +186,8 @@ def test_plan_service_start_trial_succeeds_if_trial_has_not_started(self): current_owner = OwnerFactory() plan_service.start_trial(current_owner=current_owner) - assert current_org.trial_start_date == datetime.utcnow() - assert current_org.trial_end_date == datetime.utcnow() + timedelta( + assert current_org.trial_start_date == datetime.now() + assert current_org.trial_end_date == datetime.now() + timedelta( days=TrialDaysAmount.CODECOV_SENTRY.value ) assert current_org.trial_status == TrialStatus.ONGOING.value @@ -213,7 +214,7 @@ def test_plan_service_start_trial_manually(self): plan_service.start_trial_manually( current_owner=current_owner, end_date="2024-01-01 00:00:00" ) - assert current_org.trial_start_date == datetime.utcnow() + assert current_org.trial_start_date == datetime.now() assert current_org.trial_end_date == "2024-01-01 00:00:00" assert current_org.trial_status == TrialStatus.ONGOING.value assert current_org.plan == PlanName.TRIAL_PLAN_NAME.value @@ -232,7 +233,7 @@ def test_plan_service_start_trial_manually_already_on_paid_plan(self): plan_service = PlanService(current_org=current_org) current_owner = OwnerFactory() - with self.assertRaises(ValidationError) as e: + with self.assertRaises(ValidationError): plan_service.start_trial_manually( current_owner=current_owner, end_date="2024-01-01 00:00:00" ) @@ -265,8 +266,8 @@ def test_plan_service_returns_plan_data_for_non_trial_basic_plan(self): assert plan_service.trial_total_days == basic_plan.trial_days def test_plan_service_returns_plan_data_for_trialing_user_trial_plan(self): - trial_start_date = datetime.utcnow() - trial_end_date = datetime.utcnow() + timedelta( + trial_start_date = datetime.now() + trial_end_date = datetime.now() + timedelta( days=TrialDaysAmount.CODECOV_SENTRY.value ) current_org = OwnerFactory( @@ -285,7 +286,7 @@ def 
test_plan_service_returns_plan_data_for_trialing_user_trial_plan(self): assert plan_service.billing_rate == trial_plan.billing_rate assert plan_service.base_unit_price == trial_plan.base_unit_price assert plan_service.benefits == trial_plan.benefits - assert plan_service.monthly_uploads_limit == None # Not 250 since it's trialing + assert plan_service.monthly_uploads_limit is None # Not 250 since it's trialing assert plan_service.trial_total_days == trial_plan.trial_days def test_plan_service_sets_default_plan_data_values_correctly(self): @@ -303,14 +304,14 @@ def test_plan_service_sets_default_plan_data_values_correctly(self): assert current_org.plan == PlanName.BASIC_PLAN_NAME.value assert current_org.plan_user_count == 1 - assert current_org.plan_activated_users == None - assert current_org.stripe_subscription_id == None + assert current_org.plan_activated_users is None + assert current_org.stripe_subscription_id is None def test_plan_service_returns_if_owner_has_trial_dates(self): current_org = OwnerFactory( plan=PlanName.CODECOV_PRO_MONTHLY.value, - trial_start_date=datetime.utcnow(), - trial_end_date=datetime.utcnow() + timedelta(days=14), + trial_start_date=datetime.now(), + trial_end_date=datetime.now() + timedelta(days=14), ) current_org.save() @@ -363,6 +364,60 @@ def test_plan_service_update_plan_succeeds(self): assert current_org.plan == PlanName.TEAM_MONTHLY.value assert current_org.plan_user_count == 8 + def test_has_account(self): + current_org = OwnerFactory() + plan_service = PlanService(current_org=current_org) + self.assertFalse(plan_service.has_account) + + current_org.account = AccountFactory() + current_org.save() + plan_service = PlanService(current_org=current_org) + self.assertTrue(plan_service.has_account) + + def test_plan_data_has_account(self): + current_org = OwnerFactory(plan=PlanName.BASIC_PLAN_NAME.value) + plan_service = PlanService(current_org=current_org) + self.assertEqual(plan_service.plan_name, PlanName.BASIC_PLAN_NAME.value) 
+ + current_org.account = AccountFactory(plan=PlanName.CODECOV_PRO_YEARLY.value) + current_org.save() + plan_service = PlanService(current_org=current_org) + self.assertEqual(plan_service.plan_name, PlanName.CODECOV_PRO_YEARLY.value) + + def test_plan_user_count_has_account(self): + org = OwnerFactory(plan=PlanName.BASIC_PLAN_NAME.value, plan_user_count=5) + account = AccountFactory( + plan=PlanName.BASIC_PLAN_NAME.value, plan_seat_count=50, free_seat_count=3 + ) + + plan_service = PlanService(current_org=org) + self.assertEqual(plan_service.plan_user_count, 5) + + org.account = account + org.save() + plan_service = PlanService(current_org=org) + self.assertEqual(plan_service.plan_user_count, 53) + + def test_has_seats_left_has_account(self): + org = OwnerFactory( + plan=PlanName.BASIC_PLAN_NAME.value, + plan_user_count=5, + plan_activated_users=[1, 2, 3], + ) + account = AccountFactory( + plan=PlanName.BASIC_PLAN_NAME.value, plan_seat_count=5, free_seat_count=3 + ) + for i in range(8): + AccountsUsersFactory(account=account) + + plan_service = PlanService(current_org=org) + self.assertEqual(plan_service.has_seats_left, True) + + org.account = account + org.save() + plan_service = PlanService(current_org=org) + self.assertEqual(plan_service.has_seats_left, False) + class AvailablePlansBeforeTrial(TestCase): """ @@ -508,8 +563,8 @@ class AvailablePlansExpiredTrialLessThanTenUsers(TestCase): def setUp(self): self.current_org = OwnerFactory( - trial_start_date=datetime.utcnow() + timedelta(days=-10), - trial_end_date=datetime.utcnow() + timedelta(days=-3), + trial_start_date=datetime.now() + timedelta(days=-10), + trial_end_date=datetime.now() + timedelta(days=-3), trial_status=TrialStatus.EXPIRED.value, plan_user_count=3, ) @@ -623,8 +678,8 @@ class AvailablePlansExpiredTrialMoreThanTenActivatedUsers(TestCase): def setUp(self): self.current_org = OwnerFactory( - trial_start_date=datetime.utcnow() + timedelta(days=-10), - trial_end_date=datetime.utcnow() + 
timedelta(days=-3), + trial_start_date=datetime.now() + timedelta(days=-10), + trial_end_date=datetime.now() + timedelta(days=-3), trial_status=TrialStatus.EXPIRED.value, plan_user_count=1, plan_activated_users=[i for i in range(13)], @@ -708,8 +763,8 @@ def test_trial_expired(self): plan_user_count=100, plan_activated_users=[i for i in range(10)], trial_status=TrialStatus.EXPIRED.value, - trial_start_date=datetime.utcnow() + timedelta(days=-10), - trial_end_date=datetime.utcnow() + timedelta(days=-3), + trial_start_date=datetime.now() + timedelta(days=-10), + trial_end_date=datetime.now() + timedelta(days=-3), ) self.owner = OwnerFactory() self.plan_service = PlanService(current_org=self.current_org) @@ -723,8 +778,8 @@ def test_trial_ongoing(self): plan_user_count=100, plan_activated_users=[i for i in range(10)], trial_status=TrialStatus.ONGOING.value, - trial_start_date=datetime.utcnow() + timedelta(days=-10), - trial_end_date=datetime.utcnow() + timedelta(days=3), + trial_start_date=datetime.now() + timedelta(days=-10), + trial_end_date=datetime.now() + timedelta(days=3), ) self.owner = OwnerFactory() self.plan_service = PlanService(current_org=self.current_org) @@ -765,8 +820,8 @@ class AvailablePlansOngoingTrial(TestCase): def setUp(self): self.current_org = OwnerFactory( plan=PlanName.TRIAL_PLAN_NAME.value, - trial_start_date=datetime.utcnow(), - trial_end_date=datetime.utcnow() + timedelta(days=14), + trial_start_date=datetime.now(), + trial_end_date=datetime.now() + timedelta(days=14), trial_status=TrialStatus.ONGOING.value, plan_user_count=1000, plan_activated_users=None, diff --git a/reports/tests/factories.py b/reports/tests/factories.py index 055aa9aa96..7dac99a159 100644 --- a/reports/tests/factories.py +++ b/reports/tests/factories.py @@ -1,12 +1,10 @@ -from datetime import datetime - import factory from factory.django import DjangoModelFactory from core.tests.factories import CommitFactory, RepositoryFactory from graphql_api.types.enums import 
UploadErrorEnum from reports import models -from reports.models import ReportResults +from reports.models import ReportResults, TestInstance class CommitReportFactory(DjangoModelFactory): @@ -97,3 +95,27 @@ class Meta: ReportResults.ReportResultsStates.COMPLETED, ] ) + + +class TestFactory(factory.django.DjangoModelFactory): + class Meta: + model = models.Test + + id = factory.Faker("word") + name = factory.Faker("word") + repository = factory.SubFactory(RepositoryFactory) + commits_where_fail = [] + + +class TestInstanceFactory(factory.django.DjangoModelFactory): + class Meta: + model = models.TestInstance + + test = factory.SubFactory(TestFactory) + duration_seconds = 1.0 + outcome = TestInstance.Outcome.FAILURE.value + failure_message = "Test failed" + branch = "master" + repoid = factory.SelfAttribute("test.repository.repoid") + commitid = "123456" + upload = factory.SubFactory(UploadFactory) diff --git a/requirements.in b/requirements.in index 9a9923db9e..3305da32ed 100644 --- a/requirements.in +++ b/requirements.in @@ -4,7 +4,7 @@ ariadne_django celery>=5.3.6 cerberus ddtrace -Django +Django>=4.2.15 django-cors-headers django-csp django-dynamic-fixture @@ -12,7 +12,7 @@ django-filter django-model-utils django-postgres-extra>=2.0.8 django-prometheus -djangorestframework +djangorestframework==3.15.2 drf-spectacular drf-spectacular-sidecar elastic-apm @@ -20,7 +20,7 @@ factory-boy fakeredis freezegun https://github.com/codecov/opentelem-python/archive/refs/tags/v0.0.4a1.tar.gz#egg=codecovopentelem -https://github.com/codecov/shared/archive/c71c6ae56ea219d40d9316576716148e836e02c2.tar.gz#egg=shared +https://github.com/codecov/shared/archive/d4ceb0eeb71eaa855c793d11edd6db34d54bc883.tar.gz#egg=shared google-cloud-pubsub gunicorn>=22.0.0 https://github.com/photocrowd/django-cursor-pagination/archive/f560902696b0c8509e4d95c10ba0d62700181d84.tar.gz @@ -32,6 +32,7 @@ opentracing pre-commit psycopg2 PyJWT +pydantic pytest>=7.2.0 pytest-cov pytest-django @@ -44,13 +45,14 
@@ pytz redis regex requests -sentry-sdk>=1.40.0 +sentry-sdk>=2.13.0 sentry-sdk[celery] setproctitle simplejson stripe>=9.6.0 -urllib3>=1.26.17 +urllib3>=1.26.19 vcrpy whitenoise django-autocomplete-light django-better-admin-arrayfield +certifi>=2024.07.04 diff --git a/requirements.txt b/requirements.txt index 05b7f12a63..623dedfb1a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,8 @@ amqp==5.2.0 # via kombu analytics-python==1.3.0b1 # via shared +annotated-types==0.7.0 + # via pydantic anyio==3.6.1 # via # httpcore @@ -53,12 +55,13 @@ celery==5.3.6 # via # -r requirements.in # sentry-sdk -cerberus==1.3.2 +cerberus==1.3.5 # via # -r requirements.in # shared -certifi==2023.7.22 +certifi==2024.7.4 # via + # -r requirements.in # elastic-apm # httpcore # httpx @@ -96,7 +99,7 @@ coverage[toml]==7.5.1 # via # codecovopentelem # pytest-cov -cryptography==42.0.5 +cryptography==43.0.1 # via shared ddsketch==3.0.1 # via ddtrace @@ -106,7 +109,7 @@ deprecated==1.2.12 # via opentelemetry-api distlib==0.3.1 # via virtualenv -django==4.2.11 +django==4.2.15 # via # -r requirements.in # ariadne-django @@ -136,7 +139,7 @@ django-dynamic-fixture==3.1.1 # via -r requirements.in django-filter==2.4.0 # via -r requirements.in -django-model-utils==4.3.1 +django-model-utils==4.5.1 # via # -r requirements.in # shared @@ -148,7 +151,7 @@ django-prometheus==2.3.1 # via # -r requirements.in # shared -djangorestframework==3.14.0 +djangorestframework==3.15.2 # via # -r requirements.in # drf-spectacular @@ -296,7 +299,7 @@ opentelemetry-util-http==0.45b0 # opentelemetry-instrumentation-wsgi opentracing==2.4.0 # via -r requirements.in -packaging==20.9 +packaging==24.1 # via # gunicorn # pytest @@ -332,6 +335,10 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.20 # via cffi +pydantic==2.8.2 + # via -r requirements.in +pydantic-core==2.20.1 + # via pydantic pyjwt==2.8.0 # via # -r requirements.in @@ -339,7 +346,7 @@ pyjwt==2.8.0 pyparsing==2.4.7 # via # httplib2 - # 
packaging + # shared pyrsistent==0.18.1 # via jsonschema pytest==8.1.1 @@ -375,7 +382,6 @@ python-redis-lock==4.0.0 pytz==2022.1 # via # -r requirements.in - # djangorestframework # shared pyyaml==6.0.1 # via @@ -390,7 +396,7 @@ redis==4.4.4 # shared regex==2023.12.25 # via -r requirements.in -requests==2.31.0 +requests==2.32.3 # via # -r requirements.in # analytics-python @@ -405,13 +411,13 @@ rsa==4.7.2 # via google-auth s3transfer==0.5.0 # via boto3 -sentry-sdk[celery]==1.44.1 +sentry-sdk[celery]==2.13.0 # via # -r requirements.in # shared setproctitle==1.1.10 # via -r requirements.in -shared @ https://github.com/codecov/shared/archive/c71c6ae56ea219d40d9316576716148e836e02c2.tar.gz +shared @ https://github.com/codecov/shared/archive/d4ceb0eeb71eaa855c793d11edd6db34d54bc883.tar.gz # via -r requirements.in simplejson==3.17.2 # via -r requirements.in @@ -459,13 +465,15 @@ typing-extensions==4.6.2 # ariadne # ddtrace # opentelemetry-sdk + # pydantic + # pydantic-core # shared # stripe tzdata==2024.1 # via celery uritemplate==4.1.1 # via drf-spectacular -urllib3==1.26.18 +urllib3==1.26.19 # via # -r requirements.in # botocore @@ -498,7 +506,7 @@ xmltodict==0.13.0 # via ddtrace yarl==1.9.4 # via vcrpy -zipp==3.17.0 +zipp==3.19.2 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/rollouts/__init__.py b/rollouts/__init__.py index 16e9d53cb9..7f46f49555 100644 --- a/rollouts/__init__.py +++ b/rollouts/__init__.py @@ -7,5 +7,7 @@ def owner_slug(owner: Owner) -> str: return f"{owner.service}/{owner.username}" +__all__ = ["Feature"] + # By default, features have one variant: # { "enabled": FeatureVariant(True, 1.0) } diff --git a/ruff.toml b/ruff.toml index a824e1ef3a..4c03a80c73 100644 --- a/ruff.toml +++ b/ruff.toml @@ -36,10 +36,10 @@ indent-width = 4 target-version = "py312" [lint] -# Currently only enabled for F (Pyflakes), I (isort), E,W (pycodestyle:Error/Warning), PLC (Pylint:Convention) -# and PLE 
(Pylint:Error) rules: https://docs.astral.sh/ruff/rules/ -select = ["F", "I", "E", "W", "PLC", "PLE"] -ignore = ["F841", "F401", "F405", "F403", "E501", "E712", "E711"] +# Currently only enabled for F (Pyflakes), I (isort), E,W (pycodestyle:Error/Warning), PLC/PLE (Pylint:Convention/Error) +# and PERF (Perflint) rules: https://docs.astral.sh/ruff/rules/ +select = ["F", "I", "E", "W", "PLC", "PLE", "PERF"] +ignore = ["F405", "F403", "E501", "E712"] # Allow fix for all enabled rules (when `--fix`) is provided. # The preferred method (for now) w.r.t. fixable rules is to manually update the makefile diff --git a/services/archive.py b/services/archive.py index 292c6532f2..283ca82fa7 100644 --- a/services/archive.py +++ b/services/archive.py @@ -1,14 +1,10 @@ -import json import logging from base64 import b16encode from enum import Enum from hashlib import md5 from uuid import uuid4 -from django.conf import settings from django.utils import timezone -from minio import Minio -from shared.utils.ReportEncoder import ReportEncoder from services.storage import StorageService from utils.config import get_config @@ -18,11 +14,7 @@ class MinioEndpoints(Enum): chunks = "{version}/repos/{repo_hash}/commits/{commitid}/chunks.txt" - json_data = "{version}/repos/{repo_hash}/commits/{commitid}/json_data/{table}/{field}/{external_id}.json" - json_data_no_commit = ( - "{version}/repos/{repo_hash}/json_data/{table}/{field}/{external_id}.json" - ) - raw = "v4/raw/{date}/{repo_hash}/{commit_sha}/{reportid}.txt" + raw_with_upload_id = ( "v4/raw/{date}/{repo_hash}/{commit_sha}/{reportid}/{uploadid}.txt" ) @@ -32,7 +24,6 @@ class MinioEndpoints(Enum): static_analysis_single_file = ( "{version}/repos/{repo_hash}/static_analysis/files/{location}" ) - test_results = "test_results/v1/raw/{date}/{repo_hash}/{commit_sha}/{uploadid}.txt" def get_path(self, **kwaargs): @@ -72,21 +63,6 @@ def __init__(self, repository, ttl=None): self.storage = StorageService() self.storage_hash = 
self.get_archive_hash(repository) - """ - Accessor for underlying StorageService. You typically shouldn't need - this for anything. - """ - - def storage_client(self): - return self.storage - - """ - Getter. Returns true if the current configuration is enterprise. - """ - - def is_enterprise(self): - return settings.IS_ENTERPRISE - """ Generates a hash key from repo specific information. Provides slight obfuscation of data in minio storage @@ -110,141 +86,23 @@ def get_archive_hash(cls, repository): _hash.update(val) return b16encode(_hash.digest()).decode() - def write_json_data_to_storage( - self, - commit_id, - table: str, - field: str, - external_id: str, - data: dict, - *, - encoder=ReportEncoder, - ): - if commit_id is None: - # Some classes don't have a commit associated with them - # For example Pull belongs to multiple commits. - path = MinioEndpoints.json_data_no_commit.get_path( - version="v4", - repo_hash=self.storage_hash, - table=table, - field=field, - external_id=external_id, - ) - else: - path = MinioEndpoints.json_data.get_path( - version="v4", - repo_hash=self.storage_hash, - commitid=commit_id, - table=table, - field=field, - external_id=external_id, - ) - stringified_data = json.dumps(data, cls=encoder) - self.write_file(path, stringified_data) - return path - - """ - Grabs path from storage, adds data to path object - writes back to path, overwriting the original contents - """ - - def update_archive(self, path, data): - self.storage.append_to_file(self.root, path, data) - - """ - Writes a generic file to the archive -- it's typically recommended to - not use this in lieu of the convenience methods write_raw_upload and - write_chunks - """ - - def write_file(self, path, data, reduced_redundancy=False, gzipped=False): - self.storage.write_file( - self.root, - path, - data, - reduced_redundancy=reduced_redundancy, - gzipped=gzipped, - ) - - """ - Convenience write method, writes a raw upload to a destination. - Returns the path it writes. 
- """ - - def write_raw_upload(self, commit_sha, report_id, data, gzipped=False): - # create a custom report path for a raw upload. - # write the file. - path = "/".join( - ( - "v4/raw", - timezone.now().strftime("%Y-%m-%d"), - self.storage_hash, - commit_sha, - "%s.txt" % report_id, - ) - ) - - self.write_file(path, data, gzipped=gzipped) - - return path - - """ - Convenience method to write a chunks.txt file to storage. - """ - - def write_chunks(self, commit_sha, data): - path = MinioEndpoints.chunks.get_path( - version="v4", repo_hash=self.storage_hash, commitid=commit_sha - ) - - self.write_file(path, data) - return path - - """ - Generic method to read a file from the archive - """ - def read_file(self, path): + """ + Generic method to read a file from the archive + """ contents = self.storage.read_file(self.root, path) return contents.decode() - """ - Generic method to delete a file from the archive. - """ - - def delete_file(self, path): - self.storage.delete_file(self.root, path) - - """ - Deletes an entire repository's contents - """ - - def delete_repo_files(self): - path = "v4/repos/{}".format(self.storage_hash) - objects = self.storage.list_folder_contents(self.root, path) - for obj in objects: - self.storage.delete_file(self.root, obj.object_name) - - """ - Convenience method to read a chunks file from the archive. - """ - def read_chunks(self, commit_sha): + """ + Convenience method to read a chunks file from the archive. 
+ """ path = MinioEndpoints.chunks.get_path( version="v4", repo_hash=self.storage_hash, commitid=commit_sha ) log.info("Downloading chunks from path %s for commit %s", path, commit_sha) return self.read_file(path) - """ - Delete a chunk file from the archive - """ - - def delete_chunk_from_archive(self, commit_sha): - path = "v4/repos/{}/commits/{}/chunks.txt".format(self.storage_hash, commit_sha) - - self.delete_file(path) - def create_presigned_put(self, path): return self.storage.create_presigned_put(self.root, path, self.ttl) diff --git a/services/billing.py b/services/billing.py index d2313cbc1f..29402423c8 100644 --- a/services/billing.py +++ b/services/billing.py @@ -69,6 +69,10 @@ def update_payment_method(self, owner, payment_method): def update_email_address(self, owner, email_address): pass + @abstractmethod + def update_billing_address(self, owner, name, billing_address): + pass + @abstractmethod def get_schedule(self, owner): pass @@ -108,7 +112,7 @@ def get_invoice(self, owner, invoice_id): ) try: invoice = stripe.Invoice.retrieve(invoice_id) - except stripe.InvalidRequestError as e: + except stripe.InvalidRequestError: log.info(f"invoice {invoice_id} not found for owner {owner.ownerid}") return None if invoice["customer"] != owner.stripe_customer_id: @@ -174,6 +178,7 @@ def get_subscription(self, owner: Owner): "latest_invoice", "customer", "customer.invoice_settings.default_payment_method", + "customer.tax_ids", ], ) @@ -396,21 +401,43 @@ def create_checkout_session(self, owner: Owner, desired_plan): return session["id"] @_log_stripe_error - def update_payment_method(self, owner, payment_method): - log.info(f"Stripe update payment method for owner {owner.ownerid}") - if owner.stripe_subscription_id is None: - log.info( - f"stripe_subscription_id is None, no updating card for owner {owner.ownerid}" + def update_payment_method(self, owner: Owner, payment_method): + log.info( + "Stripe update payment method for owner", + extra=dict( + 
owner_id=owner.ownerid, + user_id=self.requesting_user.ownerid, + subscription_id=owner.stripe_subscription_id, + customer_id=owner.stripe_customer_id, + ), + ) + if owner.stripe_subscription_id is None or owner.stripe_customer_id is None: + log.warn( + "Missing subscription or customer id, returning early", + extra=dict( + owner_id=owner.ownerid, + subscription_id=owner.stripe_subscription_id, + customer_id=owner.stripe_customer_id, + ), ) return None - # attach the payment method + set ass default on the invoice and subscription + # attach the payment method + set as default on the invoice and subscription stripe.PaymentMethod.attach(payment_method, customer=owner.stripe_customer_id) stripe.Customer.modify( owner.stripe_customer_id, invoice_settings={"default_payment_method": payment_method}, ) + stripe.Subscription.modify( + owner.stripe_subscription_id, default_payment_method=payment_method + ) log.info( - f"Stripe success update payment method for owner {owner.ownerid} by user #{self.requesting_user.ownerid}" + "Successfully updated payment method for owner {owner.ownerid} by user #{self.requesting_user.ownerid}", + extra=dict( + owner_id=owner.ownerid, + user_id=self.requesting_user.ownerid, + subscription_id=owner.stripe_subscription_id, + customer_id=owner.stripe_customer_id, + ), ) @_log_stripe_error @@ -430,17 +457,36 @@ def update_email_address(self, owner: Owner, email_address: str): ) @_log_stripe_error - def update_billing_address(self, owner: Owner, billing_address): + def update_billing_address(self, owner: Owner, name, billing_address): log.info(f"Stripe update billing address for owner {owner.ownerid}") - if owner.stripe_subscription_id is None: + if owner.stripe_customer_id is None: log.info( - f"stripe_subscription_id is None, cannot update billing address for owner {owner.ownerid}" + f"stripe_customer_id is None, cannot update default billing address for owner {owner.ownerid}" ) return None - stripe.Customer.modify(owner.stripe_customer_id, 
address=billing_address) - log.info( - f"Stripe successfully updated billing address for owner {owner.ownerid} by user #{self.requesting_user.ownerid}" - ) + + try: + default_payment_method = stripe.Customer.retrieve( + owner.stripe_customer_id + ).invoice_settings.default_payment_method + + stripe.PaymentMethod.modify( + default_payment_method, + billing_details={"name": name, "address": billing_address}, + ) + + stripe.Customer.modify(owner.stripe_customer_id, address=billing_address) + log.info( + f"Stripe successfully updated billing address for owner {owner.ownerid} by user #{self.requesting_user.ownerid}" + ) + except Exception: + log.error( + "Unable to update billing address for customer", + extra=dict( + customer_id=owner.stripe_customer_id, + subscription_id=owner.stripe_subscription_id, + ), + ) @_log_stripe_error def apply_cancellation_discount(self, owner: Owner): @@ -507,7 +553,7 @@ def update_payment_method(self, owner, payment_method): def update_email_address(self, owner, email_address): pass - def update_billing_address(self, owner, billing_address): + def update_billing_address(self, owner, name, billing_address): pass def get_schedule(self, owner): @@ -586,13 +632,13 @@ def update_email_address(self, owner: Owner, email_address: str): """ return self.payment_service.update_email_address(owner, email_address) - def update_billing_address(self, owner: Owner, billing_address): + def update_billing_address(self, owner: Owner, name: str, billing_address): """ Takes an owner and a billing address. Try to update the owner's billing address to the address passed in. Address should be validated via stripe component prior to hitting this service method. Return None if invalid. 
""" - return self.payment_service.update_billing_address(owner, billing_address) + return self.payment_service.update_billing_address(owner, name, billing_address) def apply_cancellation_discount(self, owner: Owner): return self.payment_service.apply_cancellation_discount(owner) diff --git a/services/bundle_analysis.py b/services/bundle_analysis.py index 413df12f02..032f06dc3e 100644 --- a/services/bundle_analysis.py +++ b/services/bundle_analysis.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass from datetime import datetime +from decimal import Decimal from typing import Any, Dict, Iterable, List, Optional from django.utils.functional import cached_property @@ -18,7 +19,10 @@ from shared.storage import get_appropriate_storage_service from core.models import Commit, Repository -from graphql_api.actions.measurements import measurements_by_ids +from graphql_api.actions.measurements import ( + measurements_by_ids, + measurements_last_uploaded_before_start_date, +) from reports.models import CommitReport from services.archive import ArchiveService from timeseries.helpers import fill_sparse_measurements @@ -44,14 +48,6 @@ def load_report( return loader.load(commit_report.external_id) -# TODO: depreacted with Issue 1199 -def load_time_conversion(size: int) -> float: - """ - Converts total size in bytes to approximate time (in seconds) to download using a 3G internet (3 Mbps) - """ - return round((8 * size) / (1024 * 1024 * 3), 1) - - def get_extension(filename: str) -> str: """ Gets the file extension of the file without the dot @@ -88,7 +84,7 @@ def __init__( asset_type: BundleAnalysisMeasurementsAssetType, asset_name: Optional[str], interval: Interval, - after: datetime, + after: Optional[datetime], before: datetime, ): self.raw_measurements = raw_measurements @@ -161,14 +157,20 @@ class BundleSize: @dataclass class BundleData: - def __init__(self, size_in_bytes: int): + def __init__(self, size_in_bytes: int, gzip_size_in_bytes: Optional[int] = None): 
self.size_in_bytes = size_in_bytes self.size_in_bits = size_in_bytes * 8 + self.gzip_size_in_bytes = gzip_size_in_bytes @cached_property def size(self) -> BundleSize: + gzip_size = ( + self.gzip_size_in_bytes + if self.gzip_size_in_bytes is not None + else int(float(self.size_in_bytes) * BundleSize.GZIP) + ) return BundleSize( - gzip=int(float(self.size_in_bytes) * BundleSize.GZIP), + gzip=gzip_size, uncompress=int(float(self.size_in_bytes) * BundleSize.UNCOMPRESS), ) @@ -204,6 +206,10 @@ def __init__(self, asset: SharedAssetReport): self.asset = asset self.all_modules = None + @cached_property + def id(self) -> int: + return self.asset.id + @cached_property def name(self) -> str: return self.asset.hashed_name @@ -220,6 +226,10 @@ def extension(self) -> str: def size_total(self) -> int: return self.asset.size + @cached_property + def gzip_size_total(self) -> int: + return self.asset.gzip_size + @cached_property def modules(self) -> List[ModuleReport]: return [ModuleReport(module) for module in self.asset.modules()] @@ -231,8 +241,9 @@ def module_extensions(self) -> List[str]: @dataclass class BundleReport(object): - def __init__(self, report: SharedBundleReport): + def __init__(self, report: SharedBundleReport, filters: Dict[str, Any] = {}): self.report = report + self.filters = filters @cached_property def name(self) -> str: @@ -242,13 +253,18 @@ def name(self) -> str: def all_assets(self) -> List[AssetReport]: return [AssetReport(asset) for asset in self.report.asset_reports()] - def assets(self, extensions: Optional[List[str]] = None) -> List[AssetReport]: - all_assets = self.all_assets - - # TODO: Unimplemented #1192 - Filter by extensions - filtered_assets = all_assets - - return filtered_assets + def assets( + self, ordering: Optional[str] = None, ordering_desc: Optional[bool] = None + ) -> List[AssetReport]: + ordering_dict: Dict[str, Any] = {} + if ordering: + ordering_dict["ordering_column"] = ordering + if ordering_desc is not None: + 
ordering_dict["ordering_desc"] = ordering_desc + return [ + AssetReport(asset) + for asset in self.report.asset_reports(**{**ordering_dict, **self.filters}) + ] def asset(self, name: str) -> Optional[AssetReport]: for asset_report in self.all_assets: @@ -257,12 +273,11 @@ def asset(self, name: str) -> Optional[AssetReport]: @cached_property def size_total(self) -> int: - return self.report.total_size() + return self.report.total_size(**self.filters) - # To be deprecated after FE uses BundleData @cached_property - def load_time_total(self) -> float: - return load_time_conversion(self.report.total_size()) + def gzip_size_total(self) -> int: + return self.report.total_gzip_size(**self.filters) @cached_property def module_extensions(self) -> List[str]: @@ -273,23 +288,24 @@ def module_extensions(self) -> List[str]: @cached_property def module_count(self) -> int: - return len(self.module_extensions) + return sum([len(asset.modules) for asset in self.assets()]) + + @cached_property + def is_cached(self) -> bool: + return self.report.is_cached() @dataclass class BundleAnalysisReport(object): def __init__(self, report: SharedBundleAnalysisReport): self.report = report - self.cleanup() - - def cleanup(self) -> None: - if self.report and self.report.db_session: - self.report.db_session.close() - def bundle(self, name: str) -> Optional[BundleReport]: + def bundle( + self, name: str, filters: Dict[str, List[str]] + ) -> Optional[BundleReport]: bundle_report = self.report.bundle_report(name) if bundle_report: - return BundleReport(bundle_report) + return BundleReport(bundle_report, filters) @cached_property def bundles(self) -> List[BundleReport]: @@ -300,8 +316,8 @@ def size_total(self) -> int: return sum([bundle.size_total for bundle in self.bundles]) @cached_property - def load_time_total(self) -> float: - return load_time_conversion(self.size_total) + def is_cached(self) -> bool: + return self.report.is_cached() @dataclass @@ -311,20 +327,15 @@ def __init__( loader: 
BundleAnalysisReportLoader, base_report_key: str, head_report_key: str, + repository: Repository, ): self.comparison = SharedBundleAnalysisComparison( loader, base_report_key, head_report_key, + repository, ) self.head_report = self.comparison.head_report - self.cleanup() - - def cleanup(self) -> None: - if self.comparison.head_report and self.comparison.head_report.db_session: - self.comparison.head_report.db_session.close() - if self.comparison.base_report and self.comparison.base_report.db_session: - self.comparison.base_report.db_session.close() @cached_property def bundles(self) -> List["BundleComparison"]: @@ -341,18 +352,10 @@ def bundles(self) -> List["BundleComparison"]: def size_delta(self) -> int: return sum([change.size_delta for change in self.comparison.bundle_changes()]) - @cached_property - def load_time_delta(self) -> float: - return load_time_conversion(self.size_delta) - @cached_property def size_total(self) -> int: return BundleAnalysisReport(self.head_report).size_total - @cached_property - def load_time_total(self) -> float: - return load_time_conversion(self.size_total) - @dataclass class BundleComparison(object): @@ -376,22 +379,14 @@ def size_delta(self) -> int: def size_total(self) -> int: return self.head_bundle_report_size - @cached_property - def load_time_delta(self) -> float: - return load_time_conversion(self.bundle_change.size_delta) - - @cached_property - def load_time_total(self) -> float: - return load_time_conversion(self.head_bundle_report_size) - class BundleAnalysisMeasurementsService(object): def __init__( self, repository: Repository, interval: Interval, - after: datetime, before: datetime, + after: Optional[datetime] = None, branch: Optional[str] = None, ) -> None: self.repository = repository @@ -400,6 +395,46 @@ def __init__( self.before = before self.branch = branch + def _compute_measurements( + self, measurable_name: str, measurable_ids: List[str] + ) -> Dict[int, List[Dict[str, Any]]]: + all_measurements = 
measurements_by_ids( + repository=self.repository, + measurable_name=measurable_name, + measurable_ids=measurable_ids, + interval=self.interval, + after=self.after, + before=self.before, + branch=self.branch, + ) + + # Carry over previous available value for start date if its value is null + for measurable_id, measurements in all_measurements.items(): + if self.after is not None and measurements[0]["timestamp_bin"] > self.after: + carryover_measurement = measurements_last_uploaded_before_start_date( + owner_id=self.repository.author.ownerid, + repo_id=self.repository.repoid, + measurable_name=measurable_name, + measurable_id=measurable_id, + start_date=self.after, + branch=self.branch, + ) + + # Create a new datapoint in the measurements and prepend it to the existing list + # If there isn't any measurements before the start date range, measurements will be untouched + if carryover_measurement: + value = Decimal(carryover_measurement[0]["value"]) + carryover = dict(measurements[0]) + carryover["timestamp_bin"] = self.after + carryover["min"] = value + carryover["max"] = value + carryover["avg"] = value + all_measurements[measurable_id] = [carryover] + all_measurements[ + measurable_id + ] + + return all_measurements + def compute_asset( self, asset_report: AssetReport ) -> Optional[BundleAnalysisMeasurementData]: @@ -407,15 +442,11 @@ def compute_asset( if asset.asset_type != AssetType.JAVASCRIPT: return None - measurements = measurements_by_ids( - repository=self.repository, + measurements = self._compute_measurements( measurable_name=MeasurementName.BUNDLE_ANALYSIS_ASSET_SIZE.value, measurable_ids=[asset.uuid], - interval=self.interval, - after=self.after, - before=self.before, - branch=self.branch, ) + return BundleAnalysisMeasurementData( raw_measurements=list(measurements.get(asset.uuid, [])), asset_type=BundleAnalysisMeasurementsAssetType.JAVASCRIPT_SIZE, @@ -441,26 +472,19 @@ def compute_report( else: measurable_ids = [bundle_report.name] - measurements = 
measurements_by_ids( - repository=self.repository, + measurements = self._compute_measurements( measurable_name=asset_type.value.value, measurable_ids=measurable_ids, - interval=self.interval, - after=self.after, - before=self.before, - branch=self.branch, ) - results = [] - for measurable_id in measurable_ids: - results.append( - BundleAnalysisMeasurementData( - raw_measurements=list(measurements.get(measurable_id, [])), - asset_type=asset_type, - asset_name=asset_uuid_to_name_mapping.get(measurable_id, None), - interval=self.interval, - after=self.after, - before=self.before, - ) + return [ + BundleAnalysisMeasurementData( + raw_measurements=list(measurements.get(measurable_id, [])), + asset_type=asset_type, + asset_name=asset_uuid_to_name_mapping.get(measurable_id, None), + interval=self.interval, + after=self.after, + before=self.before, ) - return results + for measurable_id in measurable_ids + ] diff --git a/services/comparison.py b/services/comparison.py index 726629e75b..de1be75038 100644 --- a/services/comparison.py +++ b/services/comparison.py @@ -14,7 +14,6 @@ from django.db.models import Prefetch, QuerySet from django.utils.functional import cached_property from shared.helpers.yaml import walk -from shared.reports.readonly import ReadOnlyReport from shared.reports.types import ReportTotals from shared.utils.merge import LineType, line_type @@ -691,7 +690,7 @@ def get_file_comparison(self, file_name, with_src=False, bypass_max_diff=False): @property def git_comparison(self): - return self._fetch_comparison_and_reverse_comparison[0] + return self._fetch_comparison[0] @cached_property def base_report(self): @@ -713,14 +712,21 @@ def head_report(self): else: raise e - report.apply_diff(self.git_comparison["diff"]) + # Return the old report if the github API call fails for any reason + try: + report.apply_diff(self.git_comparison["diff"]) + except Exception: + pass return report @cached_property def has_different_number_of_head_and_base_sessions(self): - 
self.validate() + log.info("has_different_number_of_head_and_base_sessions - Start") head_sessions = self.head_report.sessions base_sessions = self.base_report.sessions + log.info( + f"has_different_number_of_head_and_base_sessions - Retrieved sessions - head {len(head_sessions)} / base {len(base_sessions)}" + ) # We're treating this case as false since considering CFF's complicates the logic if self._has_cff_sessions(head_sessions) or self._has_cff_sessions( base_sessions @@ -731,10 +737,12 @@ def has_different_number_of_head_and_base_sessions(self): # I feel this method should belong to the API Report class, but we're thinking of getting rid of that class soon # In truth, this should be in the shared.Report class def _has_cff_sessions(self, sessions) -> bool: + log.info(f"_has_cff_sessions - sessions count {len(sessions)}") for session in sessions.values(): if session.session_type.value == "carriedforward": + log.info("_has_cff_sessions - Found carriedforward") return True - + log.info("_has_cff_sessions - No carriedforward") return False @property @@ -763,10 +771,9 @@ def upload_commits(self): return commits_queryset @cached_property - def _fetch_comparison_and_reverse_comparison(self): + def _fetch_comparison(self): """ - Fetches comparison and reverse comparison concurrently, then - caches the result. Returns (comparison, reverse_comparison). + Fetches comparison, and caches the result. 
""" adapter = RepoProviderService().get_adapter( self.user, self.base_commit.repository @@ -775,12 +782,8 @@ def _fetch_comparison_and_reverse_comparison(self): self.base_commit.commitid, self.head_commit.commitid ) - reverse_comparison_coro = adapter.get_compare( - self.head_commit.commitid, self.base_commit.commitid - ) - async def runnable(): - return await asyncio.gather(comparison_coro, reverse_comparison_coro) + return await asyncio.gather(comparison_coro) return async_to_sync(runnable)() @@ -792,18 +795,6 @@ def non_carried_forward_flags(self): flags_dict = self.head_report.flags return [flag for flag, vals in flags_dict.items() if not vals.carriedforward] - @cached_property - def has_unmerged_base_commits(self): - """ - We use reverse comparison to detect if any commits exist in the - base reference but not in the head reference. We use this information - to show a message in the UI urging the user to integrate the changes - in the base reference in order to see accurate coverage information. - We compare with 1 because torngit injects the base commit into the commits - array because reasons. 
- """ - return len(self._fetch_comparison_and_reverse_comparison[1]["commits"]) > 1 - class FlagComparison(object): def __init__(self, comparison, flag_name): @@ -870,7 +861,7 @@ def has_diff(self) -> bool: """ Returns `True` if the file has any additions or removals in the diff """ - return ( + return bool( self.added_diff_coverage and len(self.added_diff_coverage) > 0 or self.removed_diff_coverage @@ -954,7 +945,10 @@ def change_coverage(self) -> Optional[float]: and self.head_coverage and self.head_coverage.coverage ): - return float(self.head_coverage.coverage - self.base_coverage.coverage) + return float( + float(self.head_coverage.coverage or 0) + - float(self.base_coverage.coverage or 0) + ) @cached_property def file_name(self) -> Optional[str]: diff --git a/services/components.py b/services/components.py index 97c7f2b7e7..34460588e8 100644 --- a/services/components.py +++ b/services/components.py @@ -15,7 +15,7 @@ from timeseries.models import Interval -def commit_components(commit: Commit, owner: Owner) -> List[Component]: +def commit_components(commit: Commit, owner: Owner | None) -> List[Component]: """ Get the list of components for a commit. 
A request is made to the provider on behalf of the given `owner` diff --git a/services/path.py b/services/path.py index b343a73fa3..9ce429aef3 100644 --- a/services/path.py +++ b/services/path.py @@ -1,4 +1,3 @@ -import re from collections import defaultdict from dataclasses import dataclass from functools import cached_property diff --git a/services/refresh.py b/services/refresh.py index 99e1542ed3..fcb332dc34 100644 --- a/services/refresh.py +++ b/services/refresh.py @@ -1,4 +1,3 @@ -from contextlib import suppress from json import dumps, loads from celery.result import result_from_tuple diff --git a/services/report.py b/services/report.py index 1d2e150e01..7f8aa62f80 100644 --- a/services/report.py +++ b/services/report.py @@ -13,7 +13,7 @@ from shared.utils.sessions import Session, SessionType from core.models import Commit -from reports.models import AbstractTotals, CommitReport, ReportDetails, ReportSession +from reports.models import AbstractTotals, CommitReport, ReportSession from services.archive import ArchiveService from utils.config import RUN_ENV @@ -29,7 +29,7 @@ def file_reports(self): def flags(self): """returns dict(:name=)""" flags_dict = {} - for sid, session in self.sessions.items(): + for session in self.sessions.values(): if session.flags is not None: carriedforward = session.session_type.value == "carriedforward" carriedforward_from = session.session_extras.get("carriedforward_from") @@ -226,7 +226,6 @@ def build_files(commit_report: CommitReport) -> dict[str, ReportFileSummary]: file["filename"]: ReportFileSummary( file_index=file["file_index"], file_totals=ReportTotals(*file["file_totals"]), - session_totals=file["session_totals"], diff_totals=file["diff_totals"], ) for file in report_details.files_array diff --git a/services/self_hosted.py b/services/self_hosted.py index 1db2c12b63..e2b5273873 100644 --- a/services/self_hosted.py +++ b/services/self_hosted.py @@ -44,7 +44,7 @@ def admin_owners() -> QuerySet: def is_admin_owner(owner: 
Optional[Owner]) -> bool: """ - Returns true iff the given owner is an admin. + Returns true if the given owner is an admin. """ return owner is not None and admin_owners().filter(pk=owner.pk).exists() @@ -64,13 +64,12 @@ def activated_owners() -> QuerySet: .values_list("plan_activated_owner_ids", flat=True) .distinct() ) - return Owner.objects.filter(pk__in=owner_ids) def is_activated_owner(owner: Owner) -> bool: """ - Returns true iff the given owner is activated in this instance. + Returns true if the given owner is activated in this instance. """ return activated_owners().filter(pk=owner.pk).exists() @@ -79,13 +78,27 @@ def license_seats() -> int: """ Max number of seats allowed by the current license. """ - license = get_current_license() - return license.number_allowed_users or 0 + enterprise_license = get_current_license() + if not enterprise_license.is_valid: + return 0 + return enterprise_license.number_allowed_users or 0 + + +def enterprise_has_seats_left() -> bool: + """ + The activated_owner_query is heavy, so check the license first, only proceed if they have a valid license. + """ + license_seat_count = license_seats() + if license_seat_count == 0: + return False + owners = activated_owners() + count = owners.count() + return count < license_seat_count def can_activate_owner(owner: Owner) -> bool: """ - Returns true iff there are available seats left for activation. + Returns true if there are available seats left for activation. 
""" if is_activated_owner(owner): # user is already activated in at least 1 org diff --git a/services/task/task.py b/services/task/task.py index 972712b0b5..d842664422 100644 --- a/services/task/task.py +++ b/services/task/task.py @@ -1,11 +1,9 @@ import logging -import os from datetime import datetime, timedelta from typing import Iterable, List, Optional, Tuple import celery -import sentry_sdk -from celery import Celery, chain, group, signals, signature +from celery import Celery, chain, group, signature from celery.canvas import Signature from django.conf import settings from sentry_sdk import set_tag diff --git a/services/tests/samples/bundle_report.sqlite b/services/tests/samples/bundle_report.sqlite index 59910f61cb..dfaf104b00 100644 Binary files a/services/tests/samples/bundle_report.sqlite and b/services/tests/samples/bundle_report.sqlite differ diff --git a/services/tests/samples/head_bundle_report_with_compare_sha_6ca727b0142bf5625bb82af2555d308862063222.sqlite b/services/tests/samples/head_bundle_report_with_compare_sha_6ca727b0142bf5625bb82af2555d308862063222.sqlite new file mode 100644 index 0000000000..d17d6789bf Binary files /dev/null and b/services/tests/samples/head_bundle_report_with_compare_sha_6ca727b0142bf5625bb82af2555d308862063222.sqlite differ diff --git a/services/tests/samples/head_bundle_report_with_gzip_size.sqlite b/services/tests/samples/head_bundle_report_with_gzip_size.sqlite new file mode 100644 index 0000000000..06b12b5cdf Binary files /dev/null and b/services/tests/samples/head_bundle_report_with_gzip_size.sqlite differ diff --git a/services/tests/test_analytics.py b/services/tests/test_analytics.py index 0586767bef..522620082a 100644 --- a/services/tests/test_analytics.py +++ b/services/tests/test_analytics.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from unittest.mock import call, patch +from unittest.mock import patch import pytest from django.test import TestCase diff --git a/services/tests/test_archive.py 
b/services/tests/test_archive.py index e65d523c36..85cd39737e 100644 --- a/services/tests/test_archive.py +++ b/services/tests/test_archive.py @@ -1,10 +1,7 @@ -import json from pathlib import Path -from time import time from unittest.mock import patch from django.test import TestCase -from shared.storage import MinioStorageService from core.tests.factories import RepositoryFactory from services.archive import ArchiveService @@ -19,101 +16,3 @@ def test_create_raw_upload_presigned_put(self, create_presigned_put_mock): repo = RepositoryFactory.create() service = ArchiveService(repo) assert service.create_raw_upload_presigned_put("ABCD") == "presigned url" - - -class TestWriteData(object): - def test_write_report_details_to_storage(self, mocker, db): - repo = RepositoryFactory() - mock_write_file = mocker.patch.object(MinioStorageService, "write_file") - - data = [ - { - "filename": "file_1.go", - "file_index": 0, - "file_totals": [0, 8, 5, 3, 0, "62.50000", 0, 0, 0, 0, 10, 2, 0], - "session_totals": { - "0": [0, 8, 5, 3, 0, "62.50000", 0, 0, 0, 0, 10, 2], - "meta": {"session_count": 1}, - }, - "diff_totals": None, - }, - { - "filename": "file_2.py", - "file_index": 1, - "file_totals": [0, 2, 1, 0, 1, "50.00000", 1, 0, 0, 0, 0, 0, 0], - "session_totals": { - "0": [0, 2, 1, 0, 1, "50.00000", 1], - "meta": {"session_count": 1}, - }, - "diff_totals": None, - }, - ] - archive_service = ArchiveService(repository=repo) - commitid = "some-commit-sha" - external_id = "some-uuid4-id" - path = archive_service.write_json_data_to_storage( - commit_id=commitid, - table="reports_reportsdetails", - field="files_array", - external_id=external_id, - data=data, - ) - assert ( - path - == f"v4/repos/{archive_service.storage_hash}/commits/{commitid}/json_data/reports_reportsdetails/files_array/{external_id}.json" - ) - mock_write_file.assert_called_with( - archive_service.root, - path, - json.dumps(data), - gzipped=False, - reduced_redundancy=False, - ) - - def 
test_write_report_details_to_storage_no_commitid(self, mocker, db): - repo = RepositoryFactory() - mock_write_file = mocker.patch.object(MinioStorageService, "write_file") - - data = [ - { - "filename": "file_1.go", - "file_index": 0, - "file_totals": [0, 8, 5, 3, 0, "62.50000", 0, 0, 0, 0, 10, 2, 0], - "session_totals": { - "0": [0, 8, 5, 3, 0, "62.50000", 0, 0, 0, 0, 10, 2], - "meta": {"session_count": 1}, - }, - "diff_totals": None, - }, - { - "filename": "file_2.py", - "file_index": 1, - "file_totals": [0, 2, 1, 0, 1, "50.00000", 1, 0, 0, 0, 0, 0, 0], - "session_totals": { - "0": [0, 2, 1, 0, 1, "50.00000", 1], - "meta": {"session_count": 1}, - }, - "diff_totals": None, - }, - ] - archive_service = ArchiveService(repository=repo) - commitid = None - external_id = "some-uuid4-id" - path = archive_service.write_json_data_to_storage( - commit_id=commitid, - table="reports_reportsdetails", - field="files_array", - external_id=external_id, - data=data, - ) - assert ( - path - == f"v4/repos/{archive_service.storage_hash}/json_data/reports_reportsdetails/files_array/{external_id}.json" - ) - mock_write_file.assert_called_with( - archive_service.root, - path, - json.dumps(data), - gzipped=False, - reduced_redundancy=False, - ) diff --git a/services/tests/test_billing.py b/services/tests/test_billing.py index 5250c3a50c..0f44d22d17 100644 --- a/services/tests/test_billing.py +++ b/services/tests/test_billing.py @@ -893,9 +893,11 @@ def test_get_proration_params(self): plan=PlanName.CODECOV_PRO_MONTHLY.value, plan_user_count=20 ) desired_plan = {"value": PlanName.SENTRY_MONTHLY.value, "quantity": 19} - self.stripe._get_proration_params(owner, desired_plan) == "none" + assert self.stripe._get_proration_params(owner, desired_plan) == "none" desired_plan = {"value": PlanName.SENTRY_MONTHLY.value, "quantity": 20} - self.stripe._get_proration_params(owner, desired_plan) == "always_invoice" + assert ( + self.stripe._get_proration_params(owner, desired_plan) == "always_invoice" 
+ ) desired_plan = {"value": PlanName.SENTRY_MONTHLY.value, "quantity": 21} assert ( self.stripe._get_proration_params(owner, desired_plan) == "always_invoice" @@ -1081,14 +1083,13 @@ def test_create_checkout_session_with_stripe_customer_id( def test_get_subscription_when_no_subscription(self): owner = OwnerFactory(stripe_subscription_id=None) - assert self.stripe.get_subscription(owner) == None + assert self.stripe.get_subscription(owner) is None @patch("services.billing.stripe.Subscription.retrieve") def test_get_subscription_returns_stripe_data(self, subscription_retrieve_mock): owner = OwnerFactory(stripe_subscription_id="abc") # only including fields relevant to implementation stripe_data_subscription = {"doesnt": "matter"} - payment_method_id = "pm_something_something" subscription_retrieve_mock.return_value = stripe_data_subscription assert self.stripe.get_subscription(owner) == stripe_data_subscription subscription_retrieve_mock.assert_called_once_with( @@ -1097,16 +1098,20 @@ def test_get_subscription_returns_stripe_data(self, subscription_retrieve_mock): "latest_invoice", "customer", "customer.invoice_settings.default_payment_method", + "customer.tax_ids", ], ) def test_update_payment_method_when_no_subscription(self): owner = OwnerFactory(stripe_subscription_id=None) - assert self.stripe.update_payment_method(owner, "abc") == None + assert self.stripe.update_payment_method(owner, "abc") is None @patch("services.billing.stripe.PaymentMethod.attach") @patch("services.billing.stripe.Customer.modify") - def test_update_payment_method(self, modify_customer_mock, attach_payment_mock): + @patch("services.billing.stripe.Subscription.modify") + def test_update_payment_method( + self, modify_sub_mock, modify_customer_mock, attach_payment_mock + ): payment_method_id = "pm_1234567" subscription_id = "sub_abc" customer_id = "cus_abc" @@ -1121,13 +1126,17 @@ def test_update_payment_method(self, modify_customer_mock, attach_payment_mock): customer_id, 
invoice_settings={"default_payment_method": payment_method_id} ) + modify_sub_mock.assert_called_once_with( + subscription_id, default_payment_method=payment_method_id + ) + def test_update_email_address_with_invalid_email(self): owner = OwnerFactory(stripe_subscription_id=None) - assert self.stripe.update_email_address(owner, "not-an-email") == None + assert self.stripe.update_email_address(owner, "not-an-email") is None def test_update_email_address_when_no_subscription(self): owner = OwnerFactory(stripe_subscription_id=None) - assert self.stripe.update_email_address(owner, "test@gmail.com") == None + assert self.stripe.update_email_address(owner, "test@gmail.com") is None @patch("services.billing.stripe.Customer.modify") def test_update_email_address(self, modify_customer_mock): @@ -1140,37 +1149,50 @@ def test_update_email_address(self, modify_customer_mock): self.stripe.update_email_address(owner, "test@gmail.com") modify_customer_mock.assert_called_once_with(customer_id, email=email) - def test_update_billing_address_with_invalid_email(self): - owner = OwnerFactory(stripe_subscription_id=None) - assert self.stripe.update_billing_address(owner, "gabagool") == None + @patch("logging.Logger.error") + def test_update_billing_address_with_invalid_address(self, log_error_mock): + owner = OwnerFactory(stripe_customer_id="123", stripe_subscription_id="123") + assert self.stripe.update_billing_address(owner, "John Doe", "gabagool") is None + log_error_mock.assert_called_with( + "Unable to update billing address for customer", + extra={ + "customer_id": "123", + "subscription_id": "123", + }, + ) - def test_update_billing_address_when_no_subscription(self): - owner = OwnerFactory(stripe_subscription_id=None) + def test_update_billing_address_when_no_customer_id(self): + owner = OwnerFactory(stripe_customer_id=None) assert ( self.stripe.update_billing_address( owner, + name="John Doe", billing_address={ - "line_1": "45 Fremont St.", - "line_2": "", + "line1": "45 
Fremont St.", + "line2": "", "city": "San Francisco", "state": "CA", "country": "US", "postal_code": "94105", }, ) - == None + is None ) + @patch("services.billing.stripe.Customer.retrieve") + @patch("services.billing.stripe.PaymentMethod.modify") @patch("services.billing.stripe.Customer.modify") - def test_update_billing_address(self, modify_customer_mock): + def test_update_billing_address( + self, modify_customer_mock, modify_payment_mock, retrieve_customer_mock + ): subscription_id = "sub_abc" customer_id = "cus_abc" owner = OwnerFactory( stripe_subscription_id=subscription_id, stripe_customer_id=customer_id ) billing_address = { - "line_1": "45 Fremont St.", - "line_2": "", + "line1": "45 Fremont St.", + "line2": "", "city": "San Francisco", "state": "CA", "country": "US", @@ -1178,8 +1200,12 @@ def test_update_billing_address(self, modify_customer_mock): } self.stripe.update_billing_address( owner, + name="John Doe", billing_address=billing_address, ) + + retrieve_customer_mock.assert_called_once() + modify_payment_mock.assert_called_once() modify_customer_mock.assert_called_once_with( customer_id, address=billing_address ) @@ -1190,7 +1216,7 @@ def test_get_invoice_not_found(self, retrieve_invoice_mock): retrieve_invoice_mock.side_effect = InvalidRequestError( message="not found", param=invoice_id ) - assert self.stripe.get_invoice(OwnerFactory(), invoice_id) == None + assert self.stripe.get_invoice(OwnerFactory(), invoice_id) is None retrieve_invoice_mock.assert_called_once_with(invoice_id) @patch("services.billing.stripe.Invoice.retrieve") @@ -1199,7 +1225,7 @@ def test_get_invoice_customer_dont_match(self, retrieve_invoice_mock): invoice_id = "abc" invoice = {"invoice_id": "abc", "customer": "cus_abc"} retrieve_invoice_mock.return_value = invoice - assert self.stripe.get_invoice(owner, invoice_id) == None + assert self.stripe.get_invoice(owner, invoice_id) is None retrieve_invoice_mock.assert_called_once_with(invoice_id) 
@patch("services.billing.stripe.Invoice.retrieve") @@ -1261,7 +1287,7 @@ def test_apply_cancellation_discount_yearly( assert not customer_modify_mock.called assert not coupon_create_mock.called - assert owner.stripe_coupon_id == None + assert owner.stripe_coupon_id is None @patch("services.billing.stripe.Coupon.create") @patch("services.billing.stripe.Customer.modify") @@ -1276,7 +1302,7 @@ def test_apply_cancellation_discount_no_subscription( assert not customer_modify_mock.called assert not coupon_create_mock.called - assert owner.stripe_coupon_id == None + assert owner.stripe_coupon_id is None @patch("services.billing.stripe.Coupon.create") @patch("services.billing.stripe.Customer.modify") @@ -1319,6 +1345,9 @@ def update_payment_method(self, owner, plan): def update_email_address(self, owner, email_address): pass + def update_billing_address(self, owner, name, billing_address): + pass + def get_schedule(self, owner): pass diff --git a/services/tests/test_bundle_analysis.py b/services/tests/test_bundle_analysis.py index 04a026c0c5..4aa2d90b70 100644 --- a/services/tests/test_bundle_analysis.py +++ b/services/tests/test_bundle_analysis.py @@ -74,8 +74,6 @@ def test_bundle_comparison(self, mock_shared_bundle_change): assert bundle_comparison.change_type == "added" assert bundle_comparison.size_delta == 1000000 assert bundle_comparison.size_total == 7654321 - assert bundle_comparison.load_time_delta == 2.5 - assert bundle_comparison.load_time_total == 19.5 class TestBundleAnalysisComparison(TestCase): @@ -120,13 +118,12 @@ def test_bundle_analysis_comparison(self, get_storage_service): loader, self.base_commit_report.external_id, self.head_commit_report.external_id, + self.repo, ) assert len(bac.bundles) == 5 assert bac.size_delta == 36555 assert bac.size_total == 201720 - assert bac.load_time_delta == 0.1 - assert bac.load_time_total == 0.5 class TestBundleReport(TestCase): @@ -146,7 +143,6 @@ def name(self): assert bundle_comparison.name == "bundle1" assert 
bundle_comparison.size_total == 7654321 - assert bundle_comparison.load_time_total == 19.5 class TestBundleAnalysisReport(TestCase): @@ -179,4 +175,3 @@ def test_bundle_analysis_report(self, get_storage_service): assert len(bar.bundles) == 4 assert bar.size_total == 201720 - assert bar.load_time_total == 0.5 diff --git a/services/tests/test_comparison.py b/services/tests/test_comparison.py index 3a5193318f..f7bef14f92 100644 --- a/services/tests/test_comparison.py +++ b/services/tests/test_comparison.py @@ -141,7 +141,7 @@ def test_diff_with_1_line_removed_file_adjusts_lines(self): def test_pop_line_returns_none_if_no_diff_or_src(self): manager = FileComparisonTraverseManager() - assert manager.pop_line() == None + assert manager.pop_line() is None def test_pop_line_pops_first_line_in_segment_if_traversing_that_segment(self): expected_line_value = "+this is a line!" @@ -244,7 +244,7 @@ def test_can_traverse_diff_with_line_numbers_greater_than_file_eof(self): manager.apply([visitor]) assert visitor.line_numbers == [(1, 1), (2, 2), (3, None), (None, 3)] - def test_can_traverse_diff_with_difflike_lines(self): + def test_can_traverse_diff_with_diff_like_lines(self): src = [ "- line 1", # not part of diff "- line 2", # not part of diff @@ -371,7 +371,7 @@ def test_number_shows_none_for_base_if_plus_not_part_of_diff(self): lc = LineComparison(None, [0, "", [], 0, 0], base_ln, head_ln, "+", False) assert lc.number == {"base": base_ln, "head": head_ln} - def test_number_shows_none_for_base_if_minux_not_part_of_diff(self): + def test_number_shows_none_for_base_if_minus_not_part_of_diff(self): base_ln = 3 head_ln = 4 lc = LineComparison(None, [0, "", [], 0, 0], base_ln, head_ln, "-", False) @@ -451,16 +451,16 @@ def test_hit_session_ids(self): def test_hit_session_ids_no_coverage(self): lc = LineComparison(None, [0, "", [[0, 0, 0, 0, 0]], 0, 0], 0, 0, "", False) - assert lc.hit_session_ids == None + assert lc.hit_session_ids is None def 
test_hit_session_ids_no_head_line(self): lc = LineComparison(None, None, 0, 0, "", False) - assert lc.hit_session_ids == None + assert lc.hit_session_ids is None class FileComparisonConstructorTests(TestCase): def test_constructor_no_keyError_if_diff_data_segements_is_missing(self): - file_comp = FileComparison( + FileComparison( head_file=ReportFile("file1"), base_file=ReportFile("file1"), diff_data={} ) @@ -542,7 +542,7 @@ def test_stats_returns_diff_stats_if_diff_data(self): self.file_comparison.diff_data = {"stats": expected_stats} assert self.file_comparison.stats == expected_stats - def test_lines_returns_emptylist_if_no_diff_or_src(self): + def test_lines_returns_empty_list_if_no_diff_or_src(self): assert self.file_comparison.lines == [] # essentially a smoke/integration test @@ -988,7 +988,7 @@ def test_files_with_changes_retrieves_from_redis(self, mocked_get): @patch("redis.Redis.get") def test_files_with_changes_returns_none_if_no_files_with_changes(self, mocked_get): mocked_get.return_value = None - assert self.comparison._files_with_changes == None + assert self.comparison._files_with_changes is None @patch("redis.Redis.get") def test_files_with_changes_doesnt_crash_if_redis_connection_problem( @@ -1154,11 +1154,11 @@ def test_allow_coverage_offsets(self, get_config_mock): with self.subTest("returns app settings value if exists, True if not"): get_config_mock.return_value = True comparison = PullRequestComparison(owner, pull) - comparison.allow_coverage_offsets is True + assert comparison.allow_coverage_offsets is True get_config_mock.return_value = False comparison = PullRequestComparison(owner, pull) - comparison.allow_coverage_offsets is False + assert comparison.allow_coverage_offsets is False @patch("services.repo_providers.RepoProviderService.get_adapter") def test_pseudo_diff_returns_diff_between_base_and_compared_to( @@ -1318,22 +1318,6 @@ def setUp(self): self.comparison = Comparison(user=owner, base_commit=base, head_commit=head) 
asyncio.set_event_loop(asyncio.new_event_loop()) - def test_returns_true_if_reverse_comparison_has_commits(self, get_adapter_mock): - commits = ["a", "b"] - get_adapter_mock.return_value = ( - ComparisonHasUnmergedBaseCommitsTests.MockFetchDiffCoro(commits) - ) - assert self.comparison.has_unmerged_base_commits is True - - def test_returns_false_if_reverse_comparison_has_one_commit_or_less( - self, get_adapter_mock - ): - commits = ["a"] - get_adapter_mock.return_value = ( - ComparisonHasUnmergedBaseCommitsTests.MockFetchDiffCoro(commits) - ) - assert self.comparison.has_unmerged_base_commits is False - class SegmentTests(TestCase): def _report_lines(self, hits): diff --git a/services/tests/test_path.py b/services/tests/test_path.py index 38021e2e89..f05e579e20 100644 --- a/services/tests/test_path.py +++ b/services/tests/test_path.py @@ -298,7 +298,7 @@ def test_provider_path_not_found(self, mock_provider_adapter): @patch("services.repo_providers.RepoProviderService.get_adapter") def test_provider_path_other_error(self, mock_provider_adapter): mock_provider_adapter.side_effect = TorngitClientGeneralError(500, None, None) - assert provider_path_exists("foo/bar", self.commit, self.owner) == None + assert provider_path_exists("foo/bar", self.commit, self.owner) is None @pytest.mark.usefixtures("sample_report") diff --git a/services/tests/test_profiling.py b/services/tests/test_profiling.py index a10f840d81..edf8f7b6d2 100644 --- a/services/tests/test_profiling.py +++ b/services/tests/test_profiling.py @@ -1,5 +1,4 @@ import json -import re from datetime import datetime from unittest.mock import MagicMock, patch @@ -107,7 +106,7 @@ def test_latest_profiling_commit_with_sha(self): def test_summary_data_not_summarized(self): pc = ProfilingCommitFactory(repository=self.repo) - assert self.service.summary_data(pc) == None + assert self.service.summary_data(pc) is None @patch("services.archive.ArchiveService.read_file") def test_summary_data_not_found(self, read_file): @@ 
-127,7 +126,7 @@ def test_summary_data_not_found(self, read_file): last_summarized_at=datetime.now(), ) - assert self.service.summary_data(pc) == None + assert self.service.summary_data(pc) is None @patch("services.archive.ArchiveService.read_file") def test_summary_data(self, read_file): diff --git a/services/tests/test_refresh.py b/services/tests/test_refresh.py index 0e04e37738..bfc21a7666 100644 --- a/services/tests/test_refresh.py +++ b/services/tests/test_refresh.py @@ -1,5 +1,3 @@ -from json import loads - import pytest from services.refresh import RefreshService diff --git a/services/tests/test_repo_providers.py b/services/tests/test_repo_providers.py index 3d2b0669a2..320cbad66d 100644 --- a/services/tests/test_repo_providers.py +++ b/services/tests/test_repo_providers.py @@ -249,7 +249,7 @@ def test_get_adapter_verify_ssl_true(self, mock_get_config, mock_get_provider): user = OwnerFactory(service="github") repo = RepositoryFactory.create(author=user, bot=bot) - provider = RepoProviderService().get_adapter( + RepoProviderService().get_adapter( user, repo, use_ssl=True, token=repo.bot.oauth_token ) mock_get_provider.call_args == ( @@ -285,7 +285,7 @@ def test_get_adapter_for_uploads_verify_ssl_false( user = OwnerFactory(service="github") repo = RepositoryFactory.create(author=user, bot=bot) - provider = RepoProviderService().get_adapter( + RepoProviderService().get_adapter( user, repo, use_ssl=True, token=repo.bot.oauth_token ) mock_get_provider.call_args == ( diff --git a/services/tests/test_report.py b/services/tests/test_report.py index ce60a9e4b2..78659aee9c 100644 --- a/services/tests/test_report.py +++ b/services/tests/test_report.py @@ -144,7 +144,7 @@ def test_build_report_from_commit(self, read_chunks_mock): def test_build_report_from_commit_file_not_in_storage(self, read_chunks_mock): read_chunks_mock.side_effect = FileNotInStorageError() commit = CommitWithReportFactory.create(message="aaaaa", commitid="abf6d4d") - assert 
build_report_from_commit(commit) == None + assert build_report_from_commit(commit) is None @patch("services.archive.ArchiveService.read_chunks") def test_build_report_from_commit_cff_and_direct_uploads(self, read_chunks_mock): diff --git a/services/tests/test_self_hosted.py b/services/tests/test_self_hosted.py index ce255632cb..e9e4681389 100644 --- a/services/tests/test_self_hosted.py +++ b/services/tests/test_self_hosted.py @@ -26,7 +26,7 @@ class SelfHostedTestCase(TestCase): @patch("services.self_hosted.get_config") def test_admin_owners(self, get_config): owner1 = OwnerFactory(service="github", username="foo") - owner2 = OwnerFactory(service="github", username="bar") + OwnerFactory(service="github", username="bar") owner3 = OwnerFactory(service="gitlab", username="foo") get_config.return_value = [ @@ -40,9 +40,9 @@ def test_admin_owners(self, get_config): get_config.assert_called_once_with("setup", "admins", default=[]) def test_admin_owners_empty(self): - owner1 = OwnerFactory(service="github", username="foo") - owner2 = OwnerFactory(service="github", username="bar") - owner3 = OwnerFactory(service="gitlab", username="foo") + OwnerFactory(service="github", username="foo") + OwnerFactory(service="github", username="bar") + OwnerFactory(service="gitlab", username="foo") owners = admin_owners() assert list(owners) == [] @@ -64,9 +64,9 @@ def test_activated_owners(self): user1 = OwnerFactory() user2 = OwnerFactory() user3 = OwnerFactory() - user4 = OwnerFactory() - org1 = OwnerFactory(plan_activated_users=[user1.pk]) - org2 = OwnerFactory(plan_activated_users=[user2.pk, user3.pk]) + OwnerFactory() + OwnerFactory(plan_activated_users=[user1.pk]) + OwnerFactory(plan_activated_users=[user2.pk, user3.pk]) owners = activated_owners() assert list(owners) == [user1, user2, user3] @@ -87,12 +87,14 @@ def test_is_activated_owner(self, activated_owners): @patch("services.self_hosted.get_current_license") def test_license_seats(self, get_current_license): - 
get_current_license.return_value = LicenseInformation(number_allowed_users=123) + get_current_license.return_value = LicenseInformation( + number_allowed_users=123, is_valid=True + ) assert license_seats() == 123 @patch("services.self_hosted.get_current_license") def test_license_seats_not_specified(self, get_current_license): - get_current_license.return_value = LicenseInformation() + get_current_license.return_value = LicenseInformation(is_valid=True) assert license_seats() == 0 @patch("services.self_hosted.activated_owners") diff --git a/services/tests/test_sentry.py b/services/tests/test_sentry.py index bdb59bf5fb..9276ac0113 100644 --- a/services/tests/test_sentry.py +++ b/services/tests/test_sentry.py @@ -29,11 +29,11 @@ def test_decode_state(self): @override_settings(SENTRY_JWT_SHARED_SECRET="wrong") def test_decode_state_wrong_secret(self): res = decode_state(self.state) - assert res == None + assert res is None def test_decode_state_malformed(self): res = decode_state("malformed") - assert res == None + assert res is None class SaveSentryStateTests(TransactionTestCase): @@ -70,8 +70,8 @@ def test_save_sentry_state_invalid_state(self): save_sentry_state(self.owner, MagicMock()) self.owner.refresh_from_db() - assert self.owner.sentry_user_id == None - assert self.owner.sentry_user_data == None + assert self.owner.sentry_user_id is None + assert self.owner.sentry_user_data is None def test_save_sentry_state_duplicate_user_id(self): OwnerFactory(sentry_user_id="sentry-user-id") @@ -79,8 +79,8 @@ def test_save_sentry_state_duplicate_user_id(self): save_sentry_state(self.owner, MagicMock()) self.owner.refresh_from_db() - assert self.owner.sentry_user_id == None - assert self.owner.sentry_user_data == None + assert self.owner.sentry_user_id is None + assert self.owner.sentry_user_data is None class IsSentryUserTests(TestCase): diff --git a/services/tests/test_task.py b/services/tests/test_task.py index eb1e4f4308..4bb0b5ed83 100644 --- 
a/services/tests/test_task.py +++ b/services/tests/test_task.py @@ -1,9 +1,7 @@ from datetime import datetime -from operator import xor from unittest.mock import MagicMock import pytest -from celery import Task from django.conf import settings from freezegun import freeze_time from shared import celery_config @@ -175,9 +173,7 @@ def test_backfill_repo(mocker): @freeze_time("2023-06-13T10:01:01.000123") def test_backfill_dataset(mocker): signature_mock = mocker.patch("services.task.task.signature") - mock_route_task = mocker.patch( - "services.task.task.route_task", return_value={"queue": "celery"} - ) + mocker.patch("services.task.task.route_task", return_value={"queue": "celery"}) signature = MagicMock() signature_mock.return_value = signature diff --git a/services/tests/test_yaml.py b/services/tests/test_yaml.py index a2ea893f30..ca6651c429 100644 --- a/services/tests/test_yaml.py +++ b/services/tests/test_yaml.py @@ -1,6 +1,5 @@ from unittest.mock import patch -from django.contrib.auth.models import AnonymousUser from django.test import TransactionTestCase from shared.torngit.exceptions import TorngitObjectNotFoundError diff --git a/services/yaml.py b/services/yaml.py index 118dea7be6..edf5faab8b 100644 --- a/services/yaml.py +++ b/services/yaml.py @@ -1,6 +1,7 @@ import enum +import logging from functools import lru_cache -from typing import Dict, Optional +from typing import Dict from asgiref.sync import async_to_sync from shared.yaml import UserYaml, fetch_current_yaml_from_provider_via_reference @@ -16,7 +17,10 @@ class YamlStates(enum.Enum): DEFAULT = "default" -def fetch_commit_yaml(commit: Commit, owner: Owner) -> Optional[Dict]: +log = logging.getLogger(__name__) + + +def fetch_commit_yaml(commit: Commit, owner: Owner | None) -> Dict | None: """ Fetches the codecov.yaml file for a particular commit from the service provider. Service provider API request is made on behalf of the given `owner`. 
@@ -35,12 +39,16 @@ def fetch_commit_yaml(commit: Commit, owner: Owner) -> Optional[Dict]: # have various exceptions, which we do not care about to get the final # yaml used for a commit, as any error here, the codecov.yaml would not # be used, so we return None here + log.warning( + "Was not able to fetch yaml file for commit. Ignoring error and returning None.", + extra={"commit_id": commit.commitid}, + ) return None @lru_cache() # TODO: make this use the Redis cache logic in 'shared' once it's there -def final_commit_yaml(commit: Commit, owner: Owner) -> UserYaml: +def final_commit_yaml(commit: Commit, owner: Owner | None) -> UserYaml: return UserYaml.get_final_yaml( owner_yaml=commit.repository.author.yaml, repo_yaml=commit.repository.yaml, diff --git a/staticanalysis/admin.py b/staticanalysis/admin.py index 8c38f3f3da..846f6b4061 100644 --- a/staticanalysis/admin.py +++ b/staticanalysis/admin.py @@ -1,3 +1 @@ -from django.contrib import admin - # Register your models here. diff --git a/staticanalysis/tests/test_views.py b/staticanalysis/tests/test_views.py index 14a2625fc6..624ea642b8 100644 --- a/staticanalysis/tests/test_views.py +++ b/staticanalysis/tests/test_views.py @@ -1,4 +1,4 @@ -from uuid import UUID, uuid4 +from uuid import uuid4 from django.urls import reverse from rest_framework.test import APIClient diff --git a/staticanalysis/urls.py b/staticanalysis/urls.py index 03c5dcabf8..6a2a5510b8 100644 --- a/staticanalysis/urls.py +++ b/staticanalysis/urls.py @@ -1,5 +1,3 @@ -from django.urls import path - from staticanalysis.views import StaticAnalysisSuiteViewSet from utils.routers import OptionalTrailingSlashRouter diff --git a/staticanalysis/views.py b/staticanalysis/views.py index 49ce38cbbc..7045e63020 100644 --- a/staticanalysis/views.py +++ b/staticanalysis/views.py @@ -3,7 +3,6 @@ from django.http import HttpResponse from rest_framework import mixins, viewsets from rest_framework.decorators import action -from rest_framework.response import 
Response from shared.celery_config import static_analysis_task_name from codecov_auth.authentication.repo_auth import RepositoryTokenAuthentication diff --git a/timeseries/admin.py b/timeseries/admin.py index e0f5fd11e0..5e493dffd4 100644 --- a/timeseries/admin.py +++ b/timeseries/admin.py @@ -3,7 +3,6 @@ import django.forms as forms from django.conf import settings from django.contrib import admin, messages -from django.db import transaction from django.db.models import QuerySet from django.shortcuts import render diff --git a/timeseries/helpers.py b/timeseries/helpers.py index 1eb045c469..bf5faf721d 100644 --- a/timeseries/helpers.py +++ b/timeseries/helpers.py @@ -21,15 +21,12 @@ from django.db.models.functions import Cast from django.utils import timezone -import services.report as report_service from codecov_auth.models import Owner from core.models import Commit, Repository -from reports.models import RepositoryFlag from services.task import TaskService from timeseries.models import ( Dataset, Interval, - Measurement, MeasurementName, MeasurementSummary, ) diff --git a/timeseries/migrations/0007_auto_20220727_2011.py b/timeseries/migrations/0007_auto_20220727_2011.py index 947e8d6dc7..2697f871dd 100644 --- a/timeseries/migrations/0007_auto_20220727_2011.py +++ b/timeseries/migrations/0007_auto_20220727_2011.py @@ -1,6 +1,5 @@ # Generated by Django 3.2.12 on 2022-07-27 20:11 -import django.utils.timezone from django.db import migrations, models import core.models diff --git a/timeseries/migrations/0014_remove_measurement_timeseries_measurement_flag_unique_and_more.py b/timeseries/migrations/0014_remove_measurement_timeseries_measurement_flag_unique_and_more.py index 1aa7ffe2f2..27a4ec9dc6 100644 --- a/timeseries/migrations/0014_remove_measurement_timeseries_measurement_flag_unique_and_more.py +++ b/timeseries/migrations/0014_remove_measurement_timeseries_measurement_flag_unique_and_more.py @@ -1,7 +1,6 @@ import django.utils.timezone from django.db import 
migrations from shared.django_apps.migration_utils import ( # Generated by Django 4.1.7 on 2023-05-15 20:46 - RiskyAddIndex, RiskyRemoveConstraint, RiskyRemoveField, ) diff --git a/timeseries/tests/test_admin.py b/timeseries/tests/test_admin.py index 60f23ec711..d3a86dcef6 100644 --- a/timeseries/tests/test_admin.py +++ b/timeseries/tests/test_admin.py @@ -64,7 +64,7 @@ def test_perform_backfill(self, backfill_dataset): ) assert res.status_code == 302 - backfill_dataset.call_count == 2 + assert backfill_dataset.call_count == 2 backfill_dataset.assert_any_call( self.dataset1, start_date=timezone.datetime(2000, 1, 1, tzinfo=timezone.utc), diff --git a/timeseries/tests/test_helpers.py b/timeseries/tests/test_helpers.py index 66509cac1c..687b0646c7 100644 --- a/timeseries/tests/test_helpers.py +++ b/timeseries/tests/test_helpers.py @@ -11,7 +11,6 @@ from codecov_auth.tests.factories import OwnerFactory from core.tests.factories import CommitFactory, RepositoryFactory -from reports.tests.factories import RepositoryFlagFactory from timeseries.helpers import ( coverage_measurements, fill_sparse_measurements, @@ -19,7 +18,7 @@ refresh_measurement_summaries, repository_coverage_measurements_with_fallback, ) -from timeseries.models import Dataset, Interval, Measurement, MeasurementName +from timeseries.models import Dataset, Interval, MeasurementName from timeseries.tests.factories import DatasetFactory, MeasurementFactory diff --git a/timeseries/tests/test_models.py b/timeseries/tests/test_models.py index f1039da28a..14003b2696 100644 --- a/timeseries/tests/test_models.py +++ b/timeseries/tests/test_models.py @@ -5,7 +5,7 @@ from django.test import TransactionTestCase from freezegun import freeze_time -from timeseries.models import Dataset, Interval, MeasurementName, MeasurementSummary +from timeseries.models import Dataset, Interval, MeasurementSummary from .factories import DatasetFactory, MeasurementFactory diff --git a/upload/helpers.py b/upload/helpers.py index 
df9619f58a..77194a8595 100644 --- a/upload/helpers.py +++ b/upload/helpers.py @@ -42,6 +42,11 @@ is_pull_noted_in_branch = re.compile(r".*(pull|pr)\/(\d+).*") +# Valid values are `https://dev.azure.com/username/` or `https://username.visualstudio.com/` +# May be URL-encoded, so ':' can be '%3A' and '/' can be '%2F' +# Username is alphanumeric with '_' and '-' +_valid_azure_server_uri = r"^https?(?:://|%3A%2F%2F)(?:dev.azure.com(?:/|%2F)[a-zA-Z0-9_-]+(?:/|%2F)|[a-zA-Z0-9_-]+.visualstudio.com(?:/|%2F))$" + log = logging.getLogger(__name__) redis = get_redis_connection() @@ -207,7 +212,10 @@ def parse_params(data): "url": {"type": "string"}, # custom location where report is found "parent": {"type": "string"}, "project": {"type": "string"}, - "server_uri": {"type": "string"}, + "server_uri": { + "type": "string", + "regex": _valid_azure_server_uri, + }, "root": {"type": "string"}, # deprecated "storage_path": {"type": "string"}, } @@ -231,6 +239,8 @@ def get_repo_with_github_actions_oidc_token(token): else: service = "github_enterprise" github_enterprise_url = get_config("github_enterprise", "url") + # remove trailing slashes if present + github_enterprise_url = re.sub(r"/+$", "", github_enterprise_url) jwks_url = f"{github_enterprise_url}/_services/token/.well-known/jwks" jwks_client = PyJWKClient(jwks_url) signing_key = jwks_client.get_signing_key_from_jwt(token) @@ -441,13 +451,13 @@ def determine_upload_commit_to_use(upload_params, repository): git_commit_data = _get_git_commit_data( adapter, upload_params.get("commit"), token ) - except TorngitObjectNotFoundError as e: + except TorngitObjectNotFoundError: log.warning( "Unable to fetch commit. 
Not found", extra=dict(commit=upload_params.get("commit")), ) return upload_params.get("commit") - except TorngitClientError as e: + except TorngitClientError: log.warning( "Unable to fetch commit", extra=dict(commit=upload_params.get("commit")) ) @@ -781,13 +791,28 @@ def get_version_from_headers(headers): def generate_upload_sentry_metrics_tags( - action, request, repository, is_shelter_request, endpoint: Optional[str] = None + action, + request, + is_shelter_request, + endpoint: Optional[str] = None, + repository: Optional[Repository] = None, + position: Optional[str] = None, + upload_version: Optional[str] = None, ): - return dict( + metrics_tags = dict( agent=get_agent_from_headers(request.headers), version=get_version_from_headers(request.headers), action=action, endpoint=endpoint, - repo_visibility="private" if repository.private is True else "public", is_using_shelter="yes" if is_shelter_request else "no", ) + if repository: + metrics_tags["repo_visibility"] = ( + "private" if repository.private is True else "public" + ) + if position: + metrics_tags["position"] = position + if upload_version: + metrics_tags["upload_version"] = upload_version + + return metrics_tags diff --git a/upload/serializers.py b/upload/serializers.py index 951fbf8ac8..244c4dd537 100644 --- a/upload/serializers.py +++ b/upload/serializers.py @@ -1,5 +1,3 @@ -from typing import Dict - from django.conf import settings from rest_framework import serializers @@ -7,6 +5,7 @@ from core.models import Commit, Repository from reports.models import CommitReport, ReportResults, ReportSession, RepositoryFlag from services.archive import ArchiveService +from services.task import TaskService class FlagListField(serializers.ListField): @@ -138,9 +137,15 @@ class Meta: def create(self, validated_data): repo = validated_data.pop("repository", None) commitid = validated_data.pop("commitid", None) - commit, _ = Commit.objects.get_or_create( + commit, created = Commit.objects.get_or_create( 
repository=repo, commitid=commitid, defaults=validated_data ) + + if created: + TaskService().update_commit( + commitid=commit.commitid, repoid=commit.repository.repoid + ) + return commit @@ -156,7 +161,7 @@ class Meta: ) fields = read_only_fields + ("code",) - def create(self, validated_data): + def create(self, validated_data) -> tuple[CommitReport, bool]: report = ( CommitReport.objects.coverage_reports() .filter( @@ -169,8 +174,8 @@ def create(self, validated_data): if report.report_type is None: report.report_type = CommitReport.ReportType.COVERAGE report.save() - return report - return super().create(validated_data) + return report, False + return super().create(validated_data), True class ReportResultsSerializer(serializers.ModelSerializer): diff --git a/upload/tests/test_serializers.py b/upload/tests/test_serializers.py index bbce26d29a..4591f7441e 100644 --- a/upload/tests/test_serializers.py +++ b/upload/tests/test_serializers.py @@ -18,10 +18,10 @@ def get_fake_upload(): - user_with_no_uplaods = OwnerFactory() - user_with_uplaods = OwnerFactory() - repo = RepositoryFactory.create(author=user_with_uplaods, private=True) - public_repo = RepositoryFactory.create(author=user_with_uplaods, private=False) + OwnerFactory() + user_with_uploads = OwnerFactory() + repo = RepositoryFactory.create(author=user_with_uploads, private=True) + RepositoryFactory.create(author=user_with_uploads, private=False) commit = CommitFactory.create(repository=repo) report = CommitReportFactory.create(commit=commit) diff --git a/upload/tests/test_throttles.py b/upload/tests/test_throttles.py index 00c2448abe..94d56c295f 100644 --- a/upload/tests/test_throttles.py +++ b/upload/tests/test_throttles.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock from django.test import override_settings from rest_framework.test import APITestCase @@ -96,7 +96,7 @@ def test_check_commit_constraints_settings_disabled(self): 
self.request_should_not_throttle(unrelated_commit) @override_settings(UPLOAD_THROTTLING_ENABLED=True) - def test_check_commit_constraints_settings_enabled(self): + def test_throttle_check_commit_constraints_settings_enabled(self): author = self.owner first_commit = CommitFactory.create(repository__author=author) diff --git a/upload/tests/test_upload.py b/upload/tests/test_upload.py index 9edadcdae3..415540c6c2 100644 --- a/upload/tests/test_upload.py +++ b/upload/tests/test_upload.py @@ -334,7 +334,7 @@ def test_determine_repo_upload_tokenless(self, mock_get): repo = G(Repository, author=org) expected_response = { "id": 732059764, - "finishTime": f"{datetime.utcnow()}", + "finishTime": f"{datetime.now()}", "status": "inProgress", "sourceVersion": "3be5c52bd748c508a7e96993c02cf3518c816e84", "buildNumber": "732059764", @@ -342,7 +342,7 @@ def test_determine_repo_upload_tokenless(self, mock_get): "number": "498.1", "state": "passed", "started_at": "2020-10-01T20:02:55Z", - "finished_at": f"{datetime.utcnow()}".split(".")[0], + "finished_at": f"{datetime.now()}".split(".")[0], "project": {"visibility": "public", "repositoryType": "github"}, "triggerInfo": {"pr.sourceSha": "3be5c52bd748c508a7e96993c02cf3518c816e84"}, "build": { @@ -406,7 +406,7 @@ def test_determine_repo_upload_tokenless(self, mock_get): "using_global_token": False, "branch": None, "project": "p12", - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "_did_change_merge_commit": False, "parent": "123abc", } @@ -626,7 +626,7 @@ def test_insert_commit(self): assert commit.branch == "test" assert commit.pullid == 123 assert commit.merged == False - assert commit.parent_commit_id == None + assert commit.parent_commit_id is None with self.subTest("commit already in database"): G( @@ -654,7 +654,7 @@ def test_insert_commit(self): assert commit.state == "pending" assert commit.branch == "oranges" assert commit.pullid == 456 - assert commit.merged == None + assert commit.merged is None 
assert commit.parent_commit_id == "different_parent_commit" with self.subTest("parent provided"): @@ -674,8 +674,8 @@ def test_insert_commit(self): assert commit.repository == repo assert commit.state == "pending" assert commit.branch == "test" - assert commit.pullid == None - assert commit.merged == None + assert commit.pullid is None + assert commit.merged is None assert commit.parent_commit_id == parent.commitid def test_parse_request_headers(self): @@ -796,9 +796,7 @@ def test_validate_upload_repository_blacklisted(self): def test_validate_upload_per_repo_billing_invalid(self): redis = MockRedis() owner = G(Owner, plan="1m") - repo_already_activated = G( - Repository, author=owner, private=True, activated=True, active=True - ) + G(Repository, author=owner, private=True, activated=True, active=True) repo = G(Repository, author=owner, private=True, activated=False, active=False) commit = G(Commit) @@ -824,9 +822,7 @@ def test_validate_upload_gitlab_subgroups(self): parent_service_id=top_subgroup.service_id, service="gitlab", ) - repo_already_activated = G( - Repository, author=parent_group, private=True, activated=True, active=True - ) + G(Repository, author=parent_group, private=True, activated=True, active=True) repo = G(Repository, author=bottom_subgroup, private=True, activated=False) commit = G(Commit) @@ -1996,7 +1992,7 @@ def test_success(self, mock_get): "number": "498.1", "state": "passed", "started_at": "2020-10-01T20:02:55Z", - "finished_at": f"{datetime.utcnow()}".split(".")[0], + "finished_at": f"{datetime.now()}".split(".")[0], "build": { "@type": "build", "@href": "/build/732059763", @@ -2176,11 +2172,57 @@ def test_azure_no_server_uri(self): line.strip() for line in expected_error.split("\n") ] + def test_azure_invalid_server_uri(self): + expected_error = """Unable to locate build via Azure API. 
Please upload with the Codecov repository upload token to resolve issue.""" + + params = { + "project": "project123", + "job": 732059764, + "server_uri": "https://dev.azure.com/missing_trailing_slash", + } + with pytest.raises(NotFound) as e: + TokenlessUploadHandler("azure_pipelines", params).verify_upload() + assert [line.strip() for line in e.value.args[0].split("\n")] == [ + line.strip() for line in expected_error.split("\n") + ] + + params["server_uri"] = "https://missing_trailing_slash.visualstudio.com" + with pytest.raises(NotFound) as e: + TokenlessUploadHandler("azure_pipelines", params).verify_upload() + assert [line.strip() for line in e.value.args[0].split("\n")] == [ + line.strip() for line in expected_error.split("\n") + ] + + params["server_uri"] = "https://example.visualstudio.com.attacker.com/" + with pytest.raises(NotFound) as e: + TokenlessUploadHandler("azure_pipelines", params).verify_upload() + assert [line.strip() for line in e.value.args[0].split("\n")] == [ + line.strip() for line in expected_error.split("\n") + ] + + params["server_uri"] = "https://dev.azure.com.attacker.com/example" + with pytest.raises(NotFound) as e: + TokenlessUploadHandler("azure_pipelines", params).verify_upload() + assert [line.strip() for line in e.value.args[0].split("\n")] == [ + line.strip() for line in expected_error.split("\n") + ] + + params["server_uri"] = "https://dev.azure.attacker.com/example" + with pytest.raises(NotFound) as e: + TokenlessUploadHandler("azure_pipelines", params).verify_upload() + assert [line.strip() for line in e.value.args[0].split("\n")] == [ + line.strip() for line in expected_error.split("\n") + ] + @patch.object(requests, "get") def test_azure_http_error(self, mock_get): mock_get.side_effect = [requests.exceptions.HTTPError("Not found")] - params = {"project": "project123", "job": 732059764, "server_uri": "https://"} + params = { + "project": "project123", + "job": 732059764, + "server_uri": "https://dev.azure.com/example/", + } 
expected_error = """Unable to locate build via Azure API. Please upload with the Codecov repository upload token to resolve issue.""" @@ -2194,7 +2236,11 @@ def test_azure_http_error(self, mock_get): def test_azure_connection_error(self, mock_get): mock_get.side_effect = [requests.exceptions.ConnectionError("Not found")] - params = {"project": "project123", "job": 732059764, "server_uri": "https://"} + params = { + "project": "project123", + "job": 732059764, + "server_uri": "https://dev.azure.com/example/", + } expected_error = """Unable to locate build via Azure API. Please upload with the Codecov repository upload token to resolve issue.""" @@ -2221,7 +2267,7 @@ def test_azure_no_errors(self, mock_get): params = { "project": "project123", "job": 732059764, - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "commit": "c739768fcac68144a3a6d82305b9c4106934d31a", "build": "20190725.8", } @@ -2233,7 +2279,7 @@ def test_azure_no_errors(self, mock_get): @patch.object(requests, "get") def test_azure_wrong_build_number(self, mock_get): expected_response = { - "finishTime": f"{datetime.utcnow()}", + "finishTime": f"{datetime.now()}", "buildNumber": "BADBUILDNUM", "status": "completed", "sourceVersion": "c739768fcac68144a3a6d82305b9c4106934d31a", @@ -2246,7 +2292,7 @@ def test_azure_wrong_build_number(self, mock_get): params = { "project": "project123", "job": 732059764, - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "commit": "c739768fcac68144a3a6d82305b9c4106934d31a", "build": "20190725.8", } @@ -2262,7 +2308,7 @@ def test_azure_wrong_build_number(self, mock_get): @patch.object(requests, "get") def test_azure_expired_build(self, mock_get): expected_response = { - "finishTime": f"{datetime.utcnow() - timedelta(minutes=4)}", + "finishTime": f"{datetime.now() - timedelta(minutes=4)}", "buildNumber": "20190725.8", "status": "completed", "sourceVersion": "c739768fcac68144a3a6d82305b9c4106934d31a", @@ -2276,7 +2322,7 @@ 
def test_azure_expired_build(self, mock_get): params = { "project": "project123", "job": 732059764, - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "commit": "c739768fcac68144a3a6d82305b9c4106934d31a", "build": "20190725.8", } @@ -2292,7 +2338,7 @@ def test_azure_expired_build(self, mock_get): @patch.object(requests, "get") def test_azure_invalid_status(self, mock_get): expected_response = { - "finishTime": f"{datetime.utcnow()}", + "finishTime": f"{datetime.now()}", "buildNumber": "20190725.8", "status": "BADSTATUS", "sourceVersion": "c739768fcac68144a3a6d82305b9c4106934d31a", @@ -2306,7 +2352,7 @@ def test_azure_invalid_status(self, mock_get): params = { "project": "project123", "job": 732059764, - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "commit": "c739768fcac68144a3a6d82305b9c4106934d31a", "build": "20190725.8", } @@ -2336,7 +2382,7 @@ def test_azure_wrong_commit(self, mock_get): params = { "project": "project123", "job": 732059764, - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "commit": "c739768fcac68144a3a6d82305b9c4106934d31a", "build": "20190725.8", } @@ -2366,7 +2412,7 @@ def test_azure_not_public(self, mock_get): params = { "project": "project123", "job": 732059764, - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "commit": "c739768fcac68144a3a6d82305b9c4106934d31a", "build": "20190725.8", } @@ -2396,7 +2442,7 @@ def test_azure_wrong_service_type(self, mock_get): params = { "project": "project123", "job": 732059764, - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "commit": "c739768fcac68144a3a6d82305b9c4106934d31a", "build": "20190725.8", } @@ -2426,7 +2472,11 @@ def test_appveyor_no_job(self): def test_appveyor_http_error(self, mock_get): mock_get.side_effect = [requests.exceptions.HTTPError("Not found")] - params = {"project": "project123", "job": "732059764", "server_uri": "https://"} + params = { + 
"project": "project123", + "job": "732059764", + "server_uri": "https://dev.azure.com/example/", + } expected_error = """Unable to locate build via Appveyor API. Please upload with the Codecov repository upload token to resolve issue.""" @@ -2443,7 +2493,7 @@ def test_appveyor_connection_error(self, mock_get): params = { "project": "project123", "job": "something/else/732059764", - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", } expected_error = """Unable to locate build via Appveyor API. Please upload with the Codecov repository upload token to resolve issue.""" @@ -2472,7 +2522,7 @@ def test_appveyor_finished_build(self, mock_get): params = { "project": "project123", "job": "732059764", - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "commit": "c739768fcac68144a3a6d82305b9c4106934d31a", "build": "20190725.8", } @@ -2503,7 +2553,7 @@ def test_appveyor_no_errors(self, mock_get): params = { "project": "project123", "job": "732059764", - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "commit": "c739768fcac68144a3a6d82305b9c4106934d31a", "build": "732059764", } @@ -2530,7 +2580,7 @@ def test_appveyor_invalid_service(self, mock_get): params = { "project": "project123", "job": "732059764", - "server_uri": "https://", + "server_uri": "https://dev.azure.com/example/", "commit": "c739768fcac68144a3a6d82305b9c4106934d31a", "build": "732059764", } @@ -2677,7 +2727,7 @@ def test_underscore_replace(self, mock_get): "commit_sha": "c739768fcac68144a3a6d82305b9c4106934d31a", "slug": "owner/repo", "public": True, - "finish_time": f"{datetime.utcnow()}".split(".")[0], + "finish_time": f"{datetime.now()}".split(".")[0], } mock_get.return_value.status_code.return_value = 200 mock_get.return_value.return_value = expected_response @@ -2866,7 +2916,7 @@ def test_github_actions_no_build_status(self, mock_get): "commit_sha": "c739768fcac68144a3a6d82305b9c4106934d31a", "slug": "owner/repo", "public": 
True, - "finish_time": f"{datetime.utcnow() - timedelta(minutes=10)}".split(".")[0], + "finish_time": f"{datetime.now() - timedelta(minutes=10)}".split(".")[0], } mock_get.return_value.status_code.return_value = 200 mock_get.return_value.return_value = expected_response @@ -2895,7 +2945,7 @@ def test_github_actions(self, mock_get): "commit_sha": "c739768fcac68144a3a6d82305b9c4106934d31a", "slug": "owner/repo", "public": True, - "finish_time": f"{datetime.utcnow()}".split(".")[0], + "finish_time": f"{datetime.now()}".split(".")[0], } mock_get.return_value.status_code.return_value = 200 mock_get.return_value.return_value = expected_response diff --git a/upload/tests/test_upload_download.py b/upload/tests/test_upload_download.py index ee3e3cba73..e6643f157e 100644 --- a/upload/tests/test_upload_download.py +++ b/upload/tests/test_upload_download.py @@ -2,7 +2,6 @@ import minio from ddf import G -from rest_framework.reverse import reverse from rest_framework.test import APITransactionTestCase from codecov_auth.models import Owner diff --git a/upload/tests/views/test_bundle_analysis.py b/upload/tests/views/test_bundle_analysis.py index 1559a57521..8703e9c754 100644 --- a/upload/tests/views/test_bundle_analysis.py +++ b/upload/tests/views/test_bundle_analysis.py @@ -16,7 +16,7 @@ @pytest.mark.django_db(databases={"default", "timeseries"}) -def test_upload_bundle_analysis(db, client, mocker, mock_redis): +def test_upload_bundle_analysis_success(db, client, mocker, mock_redis): upload = mocker.patch.object(TaskService, "upload") mock_sentry_metrics = mocker.patch( "upload.views.bundle_analysis.sentry_metrics.incr" @@ -41,6 +41,7 @@ def test_upload_bundle_analysis(db, client, mocker, mock_redis): "buildURL": "test-build-url", "job": "test-job", "service": "test-service", + "compareSha": "6fd5b89357fc8cdf34d6197549ac7c6d7e5aaaaa", }, format="json", headers={"User-Agent": "codecov-cli/0.4.7"}, @@ -73,6 +74,7 @@ def test_upload_bundle_analysis(db, client, mocker, mock_redis): 
"url": f"v1/uploads/{reportid}.json", "commit": commit_sha, "report_code": None, + "bundle_analysis_compare_sha": "6fd5b89357fc8cdf34d6197549ac7c6d7e5aaaaa", } # sets latest upload timestamp @@ -96,14 +98,15 @@ def test_upload_bundle_analysis(db, client, mocker, mock_redis): "endpoint": "bundle_analysis", "repo_visibility": "private", "is_using_shelter": "no", + "position": "end", }, ) @pytest.mark.django_db(databases={"default", "timeseries"}) def test_upload_bundle_analysis_org_token(db, client, mocker, mock_redis): - upload = mocker.patch.object(TaskService, "upload") - create_presigned_put = mocker.patch( + mocker.patch.object(TaskService, "upload") + mocker.patch( "services.archive.StorageService.create_presigned_put", return_value="test-presigned-put", ) @@ -128,7 +131,7 @@ def test_upload_bundle_analysis_org_token(db, client, mocker, mock_redis): @pytest.mark.django_db(databases={"default", "timeseries"}) def test_upload_bundle_analysis_existing_commit(db, client, mocker, mock_redis): upload = mocker.patch.object(TaskService, "upload") - create_presigned_put = mocker.patch( + mocker.patch( "services.archive.StorageService.create_presigned_put", return_value="test-presigned-put", ) @@ -160,7 +163,7 @@ def test_upload_bundle_analysis_existing_commit(db, client, mocker, mock_redis): def test_upload_bundle_analysis_missing_args(db, client, mocker, mock_redis): upload = mocker.patch.object(TaskService, "upload") - create_presigned_put = mocker.patch( + mocker.patch( "services.archive.StorageService.create_presigned_put", return_value="test-presigned-put", ) @@ -196,7 +199,7 @@ def test_upload_bundle_analysis_missing_args(db, client, mocker, mock_redis): def test_upload_bundle_analysis_invalid_token(db, client, mocker, mock_redis): upload = mocker.patch.object(TaskService, "upload") - create_presigned_put = mocker.patch( + mocker.patch( "services.archive.StorageService.create_presigned_put", return_value="test-presigned-put", ) @@ -215,10 +218,7 @@ def 
test_upload_bundle_analysis_invalid_token(db, client, mocker, mock_redis): format="json", ) assert res.status_code == 401 - assert res.json() == { - "detail": "Failed token authentication, please double-check that your repository token matches in the Codecov UI, " - "or review the docs https://docs.codecov.com/docs/adding-the-codecov-token" - } + assert res.json() == {"detail": "Not valid tokenless upload"} assert not upload.called @@ -348,3 +348,205 @@ def test_upload_bundle_analysis_measurement_timeseries_disabled( name=measurement_type.value, repository_id=repository.pk, ).exists() + + +@pytest.mark.django_db(databases={"default", "timeseries"}) +def test_upload_bundle_analysis_no_repo(db, client, mocker, mock_redis): + upload = mocker.patch.object(TaskService, "upload") + mocker.patch.object(TaskService, "upload") + mocker.patch( + "services.archive.StorageService.create_presigned_put", + return_value="test-presigned-put", + ) + + repository = RepositoryFactory.create() + org_token = OrganizationLevelTokenFactory.create(owner=repository.author) + + client = APIClient() + client.credentials(HTTP_AUTHORIZATION=f"token {org_token.token}") + + res = client.post( + reverse("upload-bundle-analysis"), + { + "commit": "6fd5b89357fc8cdf34d6197549ac7c6d7e5977ef", + "slug": "FakeUser::::NonExistentName", + }, + format="json", + ) + assert res.status_code == 404 + assert res.json() == {"detail": "Repository not found."} + assert not upload.called + + +@pytest.mark.django_db(databases={"default", "timeseries"}) +def test_upload_bundle_analysis_tokenless_success(db, client, mocker, mock_redis): + upload = mocker.patch.object(TaskService, "upload") + + create_presigned_put = mocker.patch( + "services.archive.StorageService.create_presigned_put", + return_value="test-presigned-put", + ) + + repository = RepositoryFactory.create(private=False) + commit_sha = "6fd5b89357fc8cdf34d6197549ac7c6d7e5977ef" + + client = APIClient() + + res = client.post( + 
reverse("upload-bundle-analysis"), + { + "commit": commit_sha, + "slug": f"{repository.author.username}::::{repository.name}", + "build": "test-build", + "buildURL": "test-build-url", + "job": "test-job", + "service": "test-service", + "compareSha": "6fd5b89357fc8cdf34d6197549ac7c6d7e5aaaaa", + "branch": "f1:main", + "git_service": "github", + }, + format="json", + headers={"User-Agent": "codecov-cli/0.4.7"}, + ) + + assert res.status_code == 201 + + # returns presigned storage URL + assert res.json() == {"url": "test-presigned-put"} + + assert upload.called + create_presigned_put.assert_called_once_with("bundle-analysis", ANY, 30) + + +@pytest.mark.django_db(databases={"default", "timeseries"}) +def test_upload_bundle_analysis_tokenless_no_repo(db, client, mocker, mock_redis): + upload = mocker.patch.object(TaskService, "upload") + + repository = RepositoryFactory.create(private=False) + commit_sha = "6fd5b89357fc8cdf34d6197549ac7c6d7e5977ef" + + client = APIClient() + + res = client.post( + reverse("upload-bundle-analysis"), + { + "commit": commit_sha, + "slug": f"fakerepo::::{repository.name}", + "build": "test-build", + "buildURL": "test-build-url", + "job": "test-job", + "service": "test-service", + "compareSha": "6fd5b89357fc8cdf34d6197549ac7c6d7e5aaaaa", + "branch": "f1:main", + "git_service": "github", + }, + format="json", + headers={"User-Agent": "codecov-cli/0.4.7"}, + ) + + assert res.status_code == 401 + assert res.json() == {"detail": "Not valid tokenless upload"} + assert not upload.called + + +@pytest.mark.django_db(databases={"default", "timeseries"}) +def test_upload_bundle_analysis_tokenless_no_git_service( + db, client, mocker, mock_redis +): + upload = mocker.patch.object(TaskService, "upload") + + repository = RepositoryFactory.create(private=False) + commit_sha = "6fd5b89357fc8cdf34d6197549ac7c6d7e5977ef" + + client = APIClient() + + res = client.post( + reverse("upload-bundle-analysis"), + { + "commit": commit_sha, + "slug": 
f"{repository.author.username}::::{repository.name}", + "build": "test-build", + "buildURL": "test-build-url", + "job": "test-job", + "service": "test-service", + "compareSha": "6fd5b89357fc8cdf34d6197549ac7c6d7e5aaaaa", + "branch": "f1:main", + "git_service": "fakegitservice", + }, + format="json", + headers={"User-Agent": "codecov-cli/0.4.7"}, + ) + + assert res.status_code == 401 + assert res.json() == {"detail": "Not valid tokenless upload"} + assert not upload.called + + +@pytest.mark.django_db(databases={"default", "timeseries"}) +def test_upload_bundle_analysis_tokenless_bad_json(db, client, mocker, mock_redis): + upload = mocker.patch.object(TaskService, "upload") + + repository = RepositoryFactory.create(private=False) + commit_sha = "6fd5b89357fc8cdf34d6197549ac7c6d7e5977ef" + + from json import JSONDecodeError + + with patch( + "codecov_auth.authentication.repo_auth.json.loads", + side_effect=JSONDecodeError("mocked error", doc="doc", pos=0), + ): + client = APIClient() + + res = client.post( + reverse("upload-bundle-analysis"), + { + "commit": commit_sha, + "slug": f"{repository.author.username}::::{repository.name}", + "build": "test-build", + "buildURL": "test-build-url", + "job": "test-job", + "service": "test-service", + "compareSha": "6fd5b89357fc8cdf34d6197549ac7c6d7e5aaaaa", + "branch": "f1:main", + "git_service": "github", + }, + format="json", + headers={"User-Agent": "codecov-cli/0.4.7"}, + ) + + assert res.status_code == 401 + assert not upload.called + + +@pytest.mark.django_db(databases={"default", "timeseries"}) +def test_upload_bundle_analysis_tokenless_mismatched_branch( + db, client, mocker, mock_redis +): + upload = mocker.patch.object(TaskService, "upload") + + commit_sha = "6fd5b89357fc8cdf34d6197549ac7c6d7e5977ef" + repository = RepositoryFactory.create(private=False) + CommitFactory.create(repository=repository, commitid=commit_sha, branch="main") + + client = APIClient() + + res = client.post( + reverse("upload-bundle-analysis"), 
+ { + "commit": commit_sha, + "slug": f"{repository.author.username}::::{repository.name}", + "build": "test-build", + "buildURL": "test-build-url", + "job": "test-job", + "service": "test-service", + "compareSha": "6fd5b89357fc8cdf34d6197549ac7c6d7e5aaaaa", + "branch": "f1:main", + "git_service": "github", + }, + format="json", + headers={"User-Agent": "codecov-cli/0.4.7"}, + ) + + assert res.status_code == 401 + assert res.json() == {"detail": "Not valid tokenless upload"} + assert not upload.called diff --git a/upload/tests/views/test_commits.py b/upload/tests/views/test_commits.py index 840fa07187..bfb94dd717 100644 --- a/upload/tests/views/test_commits.py +++ b/upload/tests/views/test_commits.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import patch import pytest from django.urls import reverse @@ -7,7 +7,6 @@ from core.models import Commit from core.tests.factories import CommitFactory, RepositoryFactory -from services.repo_providers import RepoProviderService from services.task import TaskService from upload.views.commits import CommitViews @@ -33,7 +32,7 @@ def test_get_repo_with_invalid_service(): def test_get_repo_not_found(db): # Making sure that owner has different repos and getting none when the name of the repo isn't correct - repository = RepositoryFactory( + RepositoryFactory( name="the_repo", author__username="codecov", author__service="github" ) upload_views = CommitViews() @@ -198,7 +197,7 @@ def test_create_commit_already_exists(db, client, mocker): } assert response.status_code == 201 assert expected_response == response_json - mocked_call.assert_called_with(commitid=commit.commitid, repoid=repository.repoid) + mocked_call.assert_not_called() @pytest.mark.parametrize("branch", ["main", "someone:main", "someone/fork:main"]) @@ -317,5 +316,6 @@ def test_commit_github_oidc_auth(mock_jwks_client, mock_jwt_decode, db, mocker): "endpoint": "create_commit", "repo_visibility": "public", 
"is_using_shelter": "no", + "position": "end", }, ) diff --git a/upload/tests/views/test_empty_upload.py b/upload/tests/views/test_empty_upload.py index f7bfadcc44..6b0c9089c9 100644 --- a/upload/tests/views/test_empty_upload.py +++ b/upload/tests/views/test_empty_upload.py @@ -99,6 +99,7 @@ def test_empty_upload_with_yaml_ignored_files( "endpoint": "empty_upload", "repo_visibility": "private", "is_using_shelter": "no", + "position": "end", }, ) @@ -290,7 +291,6 @@ def test_empty_upload_with_testable_file_invalid_serializer( ], ) response = client.post(url, data={"should_force": "hello world"}) - response_json = response.json() assert response.status_code == 400 diff --git a/upload/tests/views/test_reports.py b/upload/tests/views/test_reports.py index d6478ccdcc..42a3b34f6c 100644 --- a/upload/tests/views/test_reports.py +++ b/upload/tests/views/test_reports.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import patch import pytest from django.urls import reverse @@ -8,7 +8,6 @@ from core.tests.factories import CommitFactory, RepositoryFactory from reports.models import CommitReport, ReportResults from reports.tests.factories import ReportResultsFactory -from services.repo_providers import RepoProviderService from services.task.task import TaskService from upload.views.uploads import CanDoCoverageUploadsPermission @@ -72,6 +71,7 @@ def test_reports_post(client, db, mocker): "endpoint": "create_report", "repo_visibility": "private", "is_using_shelter": "no", + "position": "end", }, ) @@ -177,7 +177,7 @@ def test_create_report_already_exists(client, db, mocker): name="the_repo", author__username="codecov", author__service="github" ) commit = CommitFactory(repository=repository) - report = CommitReport.objects.create(commit=commit, code="code") + CommitReport.objects.create(commit=commit, code="code") repository.save() client = APIClient() @@ -195,7 +195,7 @@ def test_create_report_already_exists(client, db, mocker): 
assert CommitReport.objects.filter( commit_id=commit.id, code="code", report_type=CommitReport.ReportType.COVERAGE ).exists() - mocked_call.assert_called_once() + mocked_call.assert_not_called() def test_reports_post_code_as_default(client, db, mocker): @@ -311,6 +311,7 @@ def test_reports_results_post_successful_github_oidc_auth( "endpoint": "create_report_results", "repo_visibility": "private", "is_using_shelter": "no", + "position": "end", }, ) @@ -401,7 +402,7 @@ def test_report_results_get_unsuccessful(client, db, mocker): name="the_repo", author__username="codecov", author__service="github" ) commit = CommitFactory(repository=repository) - commit_report = CommitReport.objects.create(commit=commit, code="code") + CommitReport.objects.create(commit=commit, code="code") repository.save() client = APIClient() diff --git a/upload/tests/views/test_test_results.py b/upload/tests/views/test_test_results.py index af473ce2df..9035c4e9d3 100644 --- a/upload/tests/views/test_test_results.py +++ b/upload/tests/views/test_test_results.py @@ -36,6 +36,8 @@ def test_upload_test_results(db, client, mocker, mock_redis): "buildURL": "test-build-url", "job": "test-job", "service": "test-service", + "ci_service": "github-actions", + "branch": "aaaaaa", }, format="json", headers={"User-Agent": "codecov-cli/0.4.7"}, @@ -63,6 +65,7 @@ def test_upload_test_results(db, client, mocker, mock_redis): # creates commit commit = Commit.objects.get(commitid=commit_sha) assert commit + assert commit.branch is not None # saves args in Redis redis = get_redis_connection() @@ -72,7 +75,7 @@ def test_upload_test_results(db, client, mocker, mock_redis): "build": "test-build", "build_url": "test-build-url", "job": "test-job", - "service": "test-service", + "service": "github-actions", "url": f"test_results/v1/raw/{date}/{repo_hash}/{commit_sha}/{reportid}.txt", "commit": commit_sha, "report_code": None, @@ -100,13 +103,14 @@ def test_upload_test_results(db, client, mocker, mock_redis): "endpoint": 
"test_results", "repo_visibility": "private", "is_using_shelter": "no", + "position": "end", }, ) def test_test_results_org_token(db, client, mocker, mock_redis): - upload = mocker.patch.object(TaskService, "upload") - create_presigned_put = mocker.patch( + mocker.patch.object(TaskService, "upload") + mocker.patch( "services.archive.StorageService.create_presigned_put", return_value="test-presigned-put", ) @@ -123,6 +127,7 @@ def test_test_results_org_token(db, client, mocker, mock_redis): { "commit": "6fd5b89357fc8cdf34d6197549ac7c6d7e5977ef", "slug": f"{repository.author.username}::::{repository.name}", + "branch": "aaaaaa", }, format="json", ) @@ -157,6 +162,7 @@ def test_test_results_github_oidc_token( { "commit": "6fd5b89357fc8cdf34d6197549ac7c6d7e5977ef", "slug": f"{repository.author.username}::::{repository.name}", + "branch": "aaaaaa", }, format="json", ) @@ -180,16 +186,39 @@ def test_test_results_no_auth(db, client, mocker, mock_redis): format="json", ) assert res.status_code == 401 - assert ( - res.json().get("detail") - == "Failed token authentication, please double-check that your repository token matches in the Codecov UI, " - "or review the docs https://docs.codecov.com/docs/adding-the-codecov-token" + assert res.json().get("detail") == "Not valid tokenless upload" + + +def test_upload_test_results_no_repo(db, client, mocker, mock_redis): + upload = mocker.patch.object(TaskService, "upload") + mocker.patch.object(TaskService, "upload") + mocker.patch( + "services.archive.StorageService.create_presigned_put", + return_value="test-presigned-put", + ) + + repository = RepositoryFactory.create() + org_token = OrganizationLevelTokenFactory.create(owner=repository.author) + + client = APIClient() + client.credentials(HTTP_AUTHORIZATION=f"token {org_token.token}") + + res = client.post( + reverse("upload-test-results"), + { + "commit": "6fd5b89357fc8cdf34d6197549ac7c6d7e5977ef", + "slug": "FakeUser::::NonExistentName", + }, + format="json", ) + assert 
res.status_code == 404 + assert res.json() == {"detail": "Repository not found."} + assert not upload.called def test_upload_test_results_missing_args(db, client, mocker, mock_redis): upload = mocker.patch.object(TaskService, "upload") - create_presigned_put = mocker.patch( + mocker.patch( "services.archive.StorageService.create_presigned_put", return_value="test-presigned-put", ) @@ -223,11 +252,70 @@ def test_upload_test_results_missing_args(db, client, mocker, mock_redis): assert not upload.called +def test_upload_test_results_missing_branch_no_commit(db, client, mocker, mock_redis): + upload = mocker.patch.object(TaskService, "upload") + mocker.patch( + "services.archive.StorageService.create_presigned_put", + return_value="test-presigned-put", + ) + + repository = RepositoryFactory.create() + + client = APIClient() + client.credentials(HTTP_AUTHORIZATION=f"token {repository.upload_token}") + + commit_sha = "aaaaaa" + res = client.post( + reverse("upload-test-results"), + { + "commit": "aaaaaa", + "slug": f"{repository.author.username}::::{repository.name}", + }, + format="json", + ) + assert res.status_code == 201 + + assert upload.called + + commit = Commit.objects.get(commitid=commit_sha) + assert commit.branch is not None + + +def test_upload_test_results_branch_none_no_commit(db, client, mocker, mock_redis): + upload = mocker.patch.object(TaskService, "upload") + mocker.patch( + "services.archive.StorageService.create_presigned_put", + return_value="test-presigned-put", + ) + + repository = RepositoryFactory.create() + + client = APIClient() + client.credentials(HTTP_AUTHORIZATION=f"token {repository.upload_token}") + + commit_sha = "aaaaaa" + res = client.post( + reverse("upload-test-results"), + { + "commit": "aaaaaa", + "slug": f"{repository.author.username}::::{repository.name}", + "branch": None, + }, + format="json", + ) + assert res.status_code == 201 + + assert upload.called + + commit = Commit.objects.get(commitid=commit_sha) + assert 
commit.branch is not None + + def test_update_repo_fields_when_upload_is_triggered( db, client, mocker, mock_redis ) -> None: - upload = mocker.patch.object(TaskService, "upload") - create_presigned_put = mocker.patch( + mocker.patch.object(TaskService, "upload") + mocker.patch( "services.archive.StorageService.create_presigned_put", return_value="test-presigned-put", ) diff --git a/upload/tests/views/test_upload_completion.py b/upload/tests/views/test_upload_completion.py index 59ba798d49..b1bf665259 100644 --- a/upload/tests/views/test_upload_completion.py +++ b/upload/tests/views/test_upload_completion.py @@ -1,6 +1,8 @@ from unittest.mock import patch +from django.http import HttpResponse from django.urls import reverse +from rest_framework import status from rest_framework.test import APIClient from core.tests.factories import CommitFactory, RepositoryFactory @@ -95,6 +97,7 @@ def test_upload_completion_view_processed_uploads(mocked_manual_trigger, db, moc "endpoint": "upload_complete", "repo_visibility": "private", "is_using_shelter": "no", + "position": "end", }, ) @@ -188,6 +191,48 @@ def test_upload_completion_view_no_auth(db, mocker): ) +@patch("codecov_auth.authentication.repo_auth.exception_handler") +def test_upload_completion_view_repo_auth_custom_exception_handler_error( + customized_error, db, mocker +): + mocked_response = HttpResponse( + "No content posted.", + status=status.HTTP_401_UNAUTHORIZED, + content_type="application/json", + ) + mocked_response.data = "invalid" + customized_error.return_value = mocked_response + repository = RepositoryFactory( + name="the_repo", author__username="codecov", author__service="github" + ) + token = "BAD" + commit = CommitFactory(repository=repository) + report = CommitReportFactory(commit=commit) + upload1 = UploadFactory(report=report) + upload2 = UploadFactory(report=report) + repository.save() + commit.save() + report.save() + upload1.save() + upload2.save() + + client = APIClient() + 
client.credentials(HTTP_AUTHORIZATION=f"token {token}") + url = reverse( + "new_upload.upload-complete", + args=[ + "github", + "codecov::::the_repo", + commit.commitid, + ], + ) + response = client.post( + url, + ) + assert response.status_code == 401 + assert response == mocked_response + + @patch("services.task.TaskService.manual_upload_completion_trigger") def test_upload_completion_view_still_processing_uploads( mocked_manual_trigger, db, mocker diff --git a/upload/tests/views/test_uploads.py b/upload/tests/views/test_uploads.py index a94e8b1f91..ca287ee4dc 100644 --- a/upload/tests/views/test_uploads.py +++ b/upload/tests/views/test_uploads.py @@ -1,11 +1,12 @@ -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import MagicMock, call, patch import pytest from django.conf import settings from django.test import override_settings from django.urls import reverse from rest_framework.exceptions import ValidationError -from rest_framework.test import APIClient +from rest_framework.test import APIClient, APITestCase +from shared.utils.test_utils import mock_config_helper from codecov_auth.authentication.repo_auth import OrgLevelTokenRepositoryAuth from codecov_auth.services.org_level_token_service import OrgLevelTokenService @@ -19,7 +20,6 @@ ) from reports.tests.factories import CommitReportFactory, UploadFactory from services.archive import ArchiveService, MinioEndpoints -from services.repo_providers import RepoProviderService from upload.views.uploads import CanDoCoverageUploadsPermission, UploadViews @@ -76,7 +76,7 @@ def test_uploads_get_not_allowed(client, db, mocker): repository = RepositoryFactory( name="the-repo", author__username="codecov", author__service="github" ) - commit = CommitFactory(repository=repository, commitid="commit-sha") + CommitFactory(repository=repository, commitid="commit-sha") owner = repository.author client = APIClient() client.force_authenticate(user=owner) @@ -568,7 +568,7 @@ def 
test_uploads_post_shelter(db, mocker, mock_redis): "services.archive.StorageService.create_presigned_put", return_value="presigned put", ) - upload_task_mock = mocker.patch( + mocker.patch( "upload.views.uploads.UploadViews.trigger_upload_task", return_value=True ) mock_sentry_metrics = mocker.patch("upload.views.uploads.sentry_metrics.incr") @@ -617,6 +617,7 @@ def test_uploads_post_shelter(db, mocker, mock_redis): "endpoint": "create_upload", "repo_visibility": "private", "is_using_shelter": "yes", + "position": "end", }, ) @@ -637,7 +638,7 @@ def test_uploads_post_shelter(db, mocker, mock_redis): report_id=commit_report.id, upload_extras={"format_version": "v1"} ).first() assert response.status_code == 201 - archive_service = ArchiveService(repository) + ArchiveService(repository) assert upload.storage_path == "shelter/test/path.txt" presigned_put_mock.assert_called_with("archive", upload.storage_path, 10) @@ -714,3 +715,75 @@ def test_activate_already_activated_repo(db): upload_views = UploadViews() upload_views.activate_repo(repo) assert repo.active + + +class TestGitlabEnterpriseOIDC(APITestCase): + @pytest.fixture(scope="function", autouse=True) + def inject_mocker(request, mocker): + request.mocker = mocker + + @pytest.fixture(autouse=True) + def mock_config(self, mocker): + mock_config_helper( + mocker, configs={"github_enterprise.url": "https://example.com/"} + ) + + @patch("upload.views.uploads.AnalyticsService") + @patch("upload.helpers.jwt.decode") + @patch("upload.helpers.PyJWKClient") + @patch("shared.metrics.metrics.incr") + def test_uploads_post_github_enterprise_oidc_auth_jwks_url( + self, + mock_metrics, + mock_jwks_client, + mock_jwt_decode, + analytics_service_mock, + ): + self.mocker.patch( + "services.archive.StorageService.create_presigned_put", + return_value="presigned put", + ) + self.mocker.patch( + "upload.views.uploads.UploadViews.trigger_upload_task", return_value=True + ) + + repository = RepositoryFactory( + name="the_repo", + 
author__username="codecov", + author__service="github_enterprise", + private=False, + ) + mock_jwt_decode.return_value = { + "repository": f"url/{repository.name}", + "repository_owner": repository.author.username, + "iss": "https://enterprise-client.actions.githubusercontent.com", + "audience": [settings.CODECOV_API_URL], + } + token = "ThisValueDoesNotMatterBecauseOf_mock_jwt_decode" + + commit = CommitFactory(repository=repository) + commit_report = CommitReport.objects.create(commit=commit, code="code") + + client = APIClient() + url = reverse( + "new_upload.uploads", + args=[ + "github_enterprise", + "codecov::::the_repo", + commit.commitid, + commit_report.code, + ], + ) + response = client.post( + url, + { + "state": "uploaded", + "flags": ["flag1", "flag2"], + "version": "version", + }, + headers={"Authorization": f"token {token}"}, + ) + assert response.status_code == 201 + mock_jwks_client.assert_called_with( + "https://example.com/_services/token/.well-known/jwks" + ) diff --git a/upload/tokenless/appveyor.py b/upload/tokenless/appveyor.py index 6b54a81a94..0a1fca232f 100644 --- a/upload/tokenless/appveyor.py +++ b/upload/tokenless/appveyor.py @@ -1,5 +1,4 @@ import logging -from datetime import datetime, timedelta import requests from requests.exceptions import ConnectionError, HTTPError diff --git a/upload/tokenless/azure.py b/upload/tokenless/azure.py index ef8e194928..9e6da9bded 100644 --- a/upload/tokenless/azure.py +++ b/upload/tokenless/azure.py @@ -82,7 +82,7 @@ def verify(self): finishTimestamp, "%Y-%m-%d %H:%M:%S.%f" ) finishTimeWithBuffer = buildFinishDateObj + timedelta(minutes=4) - now = datetime.utcnow() + now = datetime.now() if not now <= finishTimeWithBuffer: raise NotFound( "Azure build has already finished. Please upload with the Codecov repository upload token to resolve issue." 
diff --git a/upload/tokenless/circleci.py b/upload/tokenless/circleci.py index f473b8ede4..1aea3cb27e 100644 --- a/upload/tokenless/circleci.py +++ b/upload/tokenless/circleci.py @@ -1,5 +1,4 @@ import logging -from datetime import datetime, timedelta import requests from django.conf import settings diff --git a/upload/tokenless/github_actions.py b/upload/tokenless/github_actions.py index dd61c2919f..2a14c4d999 100644 --- a/upload/tokenless/github_actions.py +++ b/upload/tokenless/github_actions.py @@ -93,7 +93,7 @@ def verify(self): or build["slug"] != f"{owner}/{repo}" or ( build["commit_sha"] != self.upload_params.get("commit") - and self.upload_params.get("pr") == None + and self.upload_params.get("pr") is None ) ): self.log_warning( @@ -116,7 +116,7 @@ def verify(self): ) finish_time_with_buffer = build_finish_date_obj + timedelta(minutes=10) - now = datetime.utcnow() + now = datetime.now() if not now <= finish_time_with_buffer: log.warning( "Actions workflow run is stale", diff --git a/upload/tokenless/tokenless.py b/upload/tokenless/tokenless.py index 9d95435acc..54861e4048 100644 --- a/upload/tokenless/tokenless.py +++ b/upload/tokenless/tokenless.py @@ -1,12 +1,5 @@ import logging -import os -from datetime import datetime, timedelta -from json import load -import requests -from django.http import HttpResponse -from requests.exceptions import ConnectionError, HTTPError -from rest_framework import status from rest_framework.exceptions import NotFound from upload.tokenless.appveyor import TokenlessAppveyorHandler @@ -46,7 +39,7 @@ def verify_upload(self): ) try: return self.verifier(self.upload_params).verify() - except TypeError as e: + except TypeError: raise NotFound( "Your CI provider is not compatible with tokenless uploads, please upload using your repository token to resolve this." 
) diff --git a/upload/tokenless/travis.py b/upload/tokenless/travis.py index 0c29a90fac..6f3b50cd5a 100644 --- a/upload/tokenless/travis.py +++ b/upload/tokenless/travis.py @@ -1,6 +1,5 @@ import logging from datetime import datetime, timedelta -from json import load import requests from requests.exceptions import ConnectionError, HTTPError @@ -118,11 +117,11 @@ def verify(self): ) # Verify job finished within the last 4 minutes or is still in progress - if job["finished_at"] != None: + if job["finished_at"] is not None: finishTimestamp = job["finished_at"].replace("T", " ").replace("Z", "") buildFinishDateObj = datetime.strptime(finishTimestamp, "%Y-%m-%d %H:%M:%S") finishTimeWithBuffer = buildFinishDateObj + timedelta(minutes=4) - now = datetime.utcnow() + now = datetime.now() if not now <= finishTimeWithBuffer: log.warning( "Cancelling upload: 4 mins since build", diff --git a/upload/views/bundle_analysis.py b/upload/views/bundle_analysis.py index f0807bee24..6b2d095501 100644 --- a/upload/views/bundle_analysis.py +++ b/upload/views/bundle_analysis.py @@ -1,9 +1,11 @@ import logging import uuid +from typing import Any, Callable from django.conf import settings +from django.http import HttpRequest from rest_framework import serializers, status -from rest_framework.exceptions import NotAuthenticated +from rest_framework.exceptions import NotAuthenticated, NotFound from rest_framework.permissions import BasePermission from rest_framework.response import Response from rest_framework.views import APIView @@ -11,6 +13,7 @@ from shared.bundle_analysis.storage import StoragePaths, get_bucket_name from codecov_auth.authentication.repo_auth import ( + BundleAnalysisTokenlessAuthentication, GitHubOIDCTokenAuthentication, OrgLevelTokenAuthentication, RepositoryLegacyTokenAuthentication, @@ -31,7 +34,7 @@ class UploadBundleAnalysisPermission(BasePermission): - def has_permission(self, request, view): + def has_permission(self, request: HttpRequest, view: Any) -> bool: return 
request.auth is not None and "upload" in request.auth.get_scopes() @@ -44,6 +47,8 @@ class UploadSerializer(serializers.Serializer): pr = serializers.CharField(required=False, allow_null=True) service = serializers.CharField(required=False, allow_null=True) branch = serializers.CharField(required=False, allow_null=True) + compareSha = serializers.CharField(required=False, allow_null=True) + git_service = serializers.CharField(required=False, allow_null=True) class BundleAnalysisView(APIView, ShelterMixin): @@ -52,12 +57,23 @@ class BundleAnalysisView(APIView, ShelterMixin): OrgLevelTokenAuthentication, GitHubOIDCTokenAuthentication, RepositoryLegacyTokenAuthentication, + BundleAnalysisTokenlessAuthentication, ] - def get_exception_handler(self): + def get_exception_handler(self) -> Callable: return repo_auth_custom_exception_handler - def post(self, request): + def post(self, request: HttpRequest) -> Response: + sentry_metrics.incr( + "upload", + tags=generate_upload_sentry_metrics_tags( + action="bundle_analysis", + endpoint="bundle_analysis", + request=self.request, + is_shelter_request=self.is_shelter_request(), + position="start", + ), + ) serializer = UploadSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -73,6 +89,9 @@ def post(self, request): else: raise NotAuthenticated() + if repo is None: + raise NotFound("Repository not found.") + update_fields = [] if not repo.active or not repo.activated: repo.active = True @@ -117,6 +136,8 @@ def post(self, request): # these are used for dispatching the task below "commit": commit.commitid, "report_code": None, + # custom comparison sha for the current uploaded commit sha + "bundle_analysis_compare_sha": data.get("compareSha"), } log.info( @@ -128,12 +149,6 @@ def post(self, request): ), ) - dispatch_upload_task( - task_arguments, - repo, - get_redis_connection(), - report_type=CommitReport.ReportType.BUNDLE_ANALYSIS, - ) 
sentry_metrics.incr( "upload", tags=generate_upload_sentry_metrics_tags( @@ -142,9 +157,17 @@ def post(self, request): request=self.request, repository=repo, is_shelter_request=self.is_shelter_request(), + position="end", ), ) + dispatch_upload_task( + task_arguments, + repo, + get_redis_connection(), + report_type=CommitReport.ReportType.BUNDLE_ANALYSIS, + ) + if settings.TIMESERIES_ENABLED: supported_bundle_analysis_measurement_types = [ MeasurementName.BUNDLE_ANALYSIS_ASSET_SIZE, diff --git a/upload/views/commits.py b/upload/views/commits.py index c620b4f9cf..bf97e61fcc 100644 --- a/upload/views/commits.py +++ b/upload/views/commits.py @@ -1,6 +1,6 @@ import logging -from rest_framework.exceptions import NotAuthenticated, ValidationError +from rest_framework.exceptions import NotAuthenticated from rest_framework.generics import ListCreateAPIView from sentry_sdk import metrics as sentry_metrics @@ -13,7 +13,6 @@ repo_auth_custom_exception_handler, ) from core.models import Commit -from services.task import TaskService from upload.helpers import generate_upload_sentry_metrics_tags from upload.serializers import CommitSerializer from upload.views.base import GetterMixin @@ -52,15 +51,25 @@ def create(self, request, *args, **kwargs): return super().create(request, *args, **kwargs) def perform_create(self, serializer): + sentry_metrics.incr( + "upload", + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="create_commit", + request=self.request, + is_shelter_request=self.is_shelter_request(), + position="start", + ), + ) repository = self.get_repo() + commit = serializer.save(repository=repository) + log.info( "Request to create new commit", extra=dict(repo=repository.name, commit=commit.commitid), ) - TaskService().update_commit( - commitid=commit.commitid, repoid=commit.repository.repoid - ) + sentry_metrics.incr( "upload", tags=generate_upload_sentry_metrics_tags( @@ -69,6 +78,8 @@ def perform_create(self, serializer): request=self.request, 
repository=repository, is_shelter_request=self.is_shelter_request(), + position="end", ), ) + return commit diff --git a/upload/views/empty_upload.py b/upload/views/empty_upload.py index 0432b9f3e8..bf8ad4573d 100644 --- a/upload/views/empty_upload.py +++ b/upload/views/empty_upload.py @@ -81,6 +81,16 @@ def get_exception_handler(self): return repo_auth_custom_exception_handler def post(self, request, *args, **kwargs): + sentry_metrics.incr( + "upload", + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="empty_upload", + request=self.request, + is_shelter_request=self.is_shelter_request(), + position="start", + ), + ) serializer = EmptyUploadSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -141,6 +151,7 @@ def post(self, request, *args, **kwargs): request=self.request, repository=repo, is_shelter_request=self.is_shelter_request(), + position="end", ), ) if set(changed_files) == set(ignored_changed_files): diff --git a/upload/views/legacy.py b/upload/views/legacy.py index 996f78abd0..3922fd4c74 100644 --- a/upload/views/legacy.py +++ b/upload/views/legacy.py @@ -1,14 +1,11 @@ import asyncio import logging import re -from contextlib import suppress -from datetime import datetime from json import dumps from uuid import uuid4 import minio from django.conf import settings -from django.contrib.auth.models import AnonymousUser from django.core.exceptions import MultipleObjectsReturned from django.http import Http404, HttpResponse, HttpResponseServerError from django.utils import timezone @@ -16,7 +13,7 @@ from django.utils.encoding import smart_str from django.views import View from rest_framework import renderers, status -from rest_framework.exceptions import APIException, ValidationError +from rest_framework.exceptions import ValidationError from rest_framework.permissions import AllowAny from rest_framework.views import APIView from sentry_sdk import metrics as 
sentry_metrics @@ -78,6 +75,17 @@ def options(self, request, *args, **kwargs): def post(self, request, *args, **kwargs): # Extract the version version = self.kwargs["version"] + sentry_metrics.incr( + "upload", + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="legacy_upload", + request=self.request, + is_shelter_request=self.is_shelter_request(), + position="start", + upload_version=version, + ), + ) log.info( f"Received upload request {version}", @@ -135,12 +143,12 @@ def post(self, request, *args, **kwargs): try: repository = determine_repo_for_upload(upload_params) owner = repository.author - except ValidationError as e: + except ValidationError: response.status_code = status.HTTP_400_BAD_REQUEST response.content = "Could not determine repo and owner" metrics.incr("uploads.rejected", 1) return response - except MultipleObjectsReturned as e: + except MultipleObjectsReturned: response.status_code = status.HTTP_400_BAD_REQUEST response.content = "Found too many repos" metrics.incr("uploads.rejected", 1) @@ -157,23 +165,29 @@ def post(self, request, *args, **kwargs): ), ) - sentry_tags = generate_upload_sentry_metrics_tags( - action="coverage", - endpoint="legacy_upload", - request=self.request, - repository=repository, - is_shelter_request=self.is_shelter_request(), - ) - sentry_metrics.incr( "upload", - tags=sentry_tags, + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="legacy_upload", + request=self.request, + repository=repository, + is_shelter_request=self.is_shelter_request(), + position="end", + upload_version=version, + ), ) sentry_metrics.set( "upload_set", repository.author.ownerid, - tags=sentry_tags, + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="legacy_upload", + request=self.request, + repository=repository, + is_shelter_request=self.is_shelter_request(), + ), ) # Validate the upload to make sure the org has enough repo credits and is allowed to upload for this commit 
@@ -264,7 +278,7 @@ def post(self, request, *args, **kwargs): ), ) - headers = parse_headers(request.META, upload_params) + parse_headers(request.META, upload_params) archive_service = ArchiveService(repository) # only Shelter requests are allowed to set their own `storage_path` @@ -382,8 +396,12 @@ async def get_repo(self): if owner is None: raise Http404("Requested report could not be found") repo = await RepositoryCommands( - self.request.current_owner, self.service - ).fetch_repository(owner, self.repo_name) + self.request.current_owner, + self.service, + ).fetch_repository( + owner, self.repo_name, [], exclude_okta_enforced_repos=False + ) # Okta sign-in is only enforced on the UI for now. + if repo is None: raise Http404("Requested report could not be found") return repo diff --git a/upload/views/reports.py b/upload/views/reports.py index 7a52c1479f..16c4d67efb 100644 --- a/upload/views/reports.py +++ b/upload/views/reports.py @@ -38,6 +38,16 @@ def get_exception_handler(self): return repo_auth_custom_exception_handler def perform_create(self, serializer): + sentry_metrics.incr( + "upload", + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="create_report", + request=self.request, + is_shelter_request=self.is_shelter_request(), + position="start", + ), + ) repository = self.get_repo() commit = self.get_commit(repository) log.info( @@ -47,13 +57,15 @@ def perform_create(self, serializer): code = serializer.validated_data.get("code") if code == "default": serializer.validated_data["code"] = None - instance = serializer.save( + instance, was_created = serializer.save( commit_id=commit.id, report_type=CommitReport.ReportType.COVERAGE, ) - TaskService().preprocess_upload( - repository.repoid, commit.commitid, instance.code - ) + if was_created: + TaskService().preprocess_upload( + repository.repoid, commit.commitid, instance.code + ) + sentry_metrics.incr( "upload", tags=generate_upload_sentry_metrics_tags( @@ -62,6 +74,7 @@ def 
perform_create(self, serializer): request=self.request, repository=repository, is_shelter_request=self.is_shelter_request(), + position="end", ), ) return instance @@ -89,6 +102,16 @@ def get_exception_handler(self): return repo_auth_custom_exception_handler def perform_create(self, serializer): + sentry_metrics.incr( + "upload", + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="create_report_results", + request=self.request, + is_shelter_request=self.is_shelter_request(), + position="start", + ), + ) repository = self.get_repo() commit = self.get_commit(repository) report = self.get_report(commit) @@ -113,6 +136,7 @@ def perform_create(self, serializer): request=self.request, repository=repository, is_shelter_request=self.is_shelter_request(), + position="end", ), ) return instance diff --git a/upload/views/test_results.py b/upload/views/test_results.py index 8aa4698d71..b6b62c0979 100644 --- a/upload/views/test_results.py +++ b/upload/views/test_results.py @@ -3,7 +3,7 @@ from django.utils import timezone from rest_framework import serializers, status -from rest_framework.exceptions import NotAuthenticated +from rest_framework.exceptions import NotAuthenticated, NotFound from rest_framework.permissions import BasePermission from rest_framework.response import Response from rest_framework.views import APIView @@ -13,6 +13,7 @@ GitHubOIDCTokenAuthentication, OrgLevelTokenAuthentication, RepositoryLegacyTokenAuthentication, + TokenlessAuthentication, repo_auth_custom_exception_handler, ) from codecov_auth.authentication.types import RepositoryAsUser @@ -42,7 +43,8 @@ class UploadSerializer(serializers.Serializer): job = serializers.CharField(required=False) flags = FlagListField(required=False) pr = serializers.CharField(required=False) - service = serializers.CharField(required=False) + branch = serializers.CharField(required=False, allow_null=True) + ci_service = serializers.CharField(required=False) storage_path = 
serializers.CharField(required=False) @@ -55,12 +57,23 @@ class TestResultsView( OrgLevelTokenAuthentication, GitHubOIDCTokenAuthentication, RepositoryLegacyTokenAuthentication, + TokenlessAuthentication, ] def get_exception_handler(self): return repo_auth_custom_exception_handler def post(self, request): + metrics.incr( + "upload", + tags=generate_upload_sentry_metrics_tags( + action="test_results", + endpoint="test_results", + request=request, + is_shelter_request=self.is_shelter_request(), + position="start", + ), + ) serializer = UploadSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -76,6 +89,9 @@ def post(self, request): else: raise NotAuthenticated() + if repo is None: + raise NotFound("Repository not found.") + update_fields = [] if not repo.active or not repo.activated: repo.active = True @@ -97,6 +113,7 @@ def post(self, request): request=request, repository=repo, is_shelter_request=self.is_shelter_request(), + position="end", ), ) @@ -104,7 +121,7 @@ def post(self, request): commitid=data["commit"], repository=repo, defaults={ - "branch": data.get("branch"), + "branch": data.get("branch") or repo.branch, "pullid": data.get("pr"), "merged": False if data.get("pr") is not None else None, "state": "pending", @@ -135,7 +152,7 @@ def post(self, request): "build_url": data.get("buildURL"), # build_url "job": data.get("job"), # job_code "flags": data.get("flags"), - "service": data.get("service"), # provider + "service": data.get("ci_service"), # provider "url": storage_path, # storage_path # these are used for dispatching the task below "commit": commit.commitid, diff --git a/upload/views/upload_completion.py b/upload/views/upload_completion.py index 49d1d7aed5..2fe637ef83 100644 --- a/upload/views/upload_completion.py +++ b/upload/views/upload_completion.py @@ -34,6 +34,16 @@ def get_exception_handler(self): return repo_auth_custom_exception_handler def post(self, request, 
*args, **kwargs): + sentry_metrics.incr( + "upload", + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="upload_complete", + request=self.request, + is_shelter_request=self.is_shelter_request(), + position="start", + ), + ) repo = self.get_repo() commit = self.get_commit(repo) uploads_queryset = ReportSession.objects.filter( @@ -76,6 +86,7 @@ def post(self, request, *args, **kwargs): request=self.request, repository=repo, is_shelter_request=self.is_shelter_request(), + position="end", ), ) return Response( diff --git a/upload/views/uploads.py b/upload/views/uploads.py index 2bea0cdf12..5aa7300961 100644 --- a/upload/views/uploads.py +++ b/upload/views/uploads.py @@ -66,28 +66,31 @@ def get_exception_handler(self): return repo_auth_custom_exception_handler def perform_create(self, serializer: UploadSerializer): + sentry_metrics.incr( + "upload", + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="create_upload", + request=self.request, + is_shelter_request=self.is_shelter_request(), + position="start", + ), + ) repository: Repository = self.get_repo() validate_activated_repo(repository) commit: Commit = self.get_commit(repository) report: CommitReport = self.get_report(commit) - sentry_tags = generate_upload_sentry_metrics_tags( - action="coverage", - endpoint="create_upload", - request=self.request, - repository=repository, - is_shelter_request=self.is_shelter_request(), - ) - - sentry_metrics.incr( - "upload", - tags=sentry_tags, - ) - sentry_metrics.set( "upload_set", repository.author.ownerid, - tags=sentry_tags, + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="create_upload", + request=self.request, + repository=repository, + is_shelter_request=self.is_shelter_request(), + ), ) version = ( @@ -136,6 +139,17 @@ def perform_create(self, serializer: UploadSerializer): instance.storage_path = path instance.save() self.trigger_upload_task(repository, commit.commitid, instance, report) + 
sentry_metrics.incr( + "upload", + tags=generate_upload_sentry_metrics_tags( + action="coverage", + endpoint="create_upload", + request=self.request, + repository=repository, + is_shelter_request=self.is_shelter_request(), + position="end", + ), + ) metrics.incr("uploads.accepted", 1) self.activate_repo(repository) self.send_analytics_data(commit, instance, version) diff --git a/utils/__init__.py b/utils/__init__.py index 07b305e887..24819ae2ff 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,3 +1,4 @@ +import math import uuid from typing import Any @@ -8,3 +9,18 @@ def is_uuid(value: Any) -> bool: return True except ValueError: return False + + +def round_decimals_down(number: float, decimals: int = 2) -> float: + """ + Returns a value rounded down to a specific number of decimal places. + """ + if not isinstance(decimals, int): + raise TypeError("decimal places must be an integer") + elif decimals < 0: + raise ValueError("decimal places has to be 0 or more") + elif decimals == 0: + return math.floor(number) + + factor = 10**decimals + return math.floor(number * factor) / factor diff --git a/utils/cache.py b/utils/cache.py index 9d0ff5e59b..66c2d621cc 100644 --- a/utils/cache.py +++ b/utils/cache.py @@ -1,3 +1,3 @@ -from shared.helpers.cache import OurOwnCache, RedisBackend +from shared.helpers.cache import OurOwnCache cache = OurOwnCache() diff --git a/utils/logging_configuration.py b/utils/logging_configuration.py index ee9f2f0494..71c2c24e0f 100644 --- a/utils/logging_configuration.py +++ b/utils/logging_configuration.py @@ -1,8 +1,8 @@ import json from logging import Filter -from pythonjsonlogger.jsonlogger import JsonFormatter, merge_record_extra -from sentry_sdk import Hub +from pythonjsonlogger.jsonlogger import JsonFormatter +from sentry_sdk import get_current_span class BaseLogger(JsonFormatter): @@ -44,7 +44,7 @@ def add_fields(self, log_record, record, message_dict): else: log_record["level"] = record.levelname - span = 
Hub.current.scope.span + span = get_current_span() if span and span.trace_id: log_record["sentry_trace_id"] = span.trace_id diff --git a/utils/model_utils.py b/utils/model_utils.py deleted file mode 100644 index 65837dbeb0..0000000000 --- a/utils/model_utils.py +++ /dev/null @@ -1,147 +0,0 @@ -import json -import logging -from typing import Any, Callable, Optional - -from shared.storage.exceptions import FileNotInStorageError -from shared.utils.ReportEncoder import ReportEncoder - -from services.archive import ArchiveService - -log = logging.getLogger(__name__) - - -class ArchiveFieldInterfaceMeta(type): - def __subclasscheck__(cls, subclass): - return ( - hasattr(subclass, "get_repository") - and callable(subclass.get_repository) - and hasattr(subclass, "get_commitid") - and callable(subclass.get_commitid) - and hasattr(subclass, "external_id") - ) - - -class ArchiveFieldInterface(metaclass=ArchiveFieldInterfaceMeta): - """Any class that uses ArchiveField must implement this interface""" - - external_id: str - - def get_repository(self): - """Returns the repository object associated with self""" - raise NotImplementedError() - - def get_commitid(self) -> Optional[str]: - """Returns the commitid associated with self. - If no commitid is associated return None. - """ - raise NotImplementedError() - - -class ArchiveField: - """This is a helper class that transparently handles models' fields that are saved in storage. - Classes that use the ArchiveField MUST implement ArchiveFieldInterface. It ill throw an error otherwise. - It uses the Descriptor pattern: https://docs.python.org/3/howto/descriptor.html - - Arguments: - should_write_to_storage_fn: Callable function that decides if data should be written to storage. - It should take 1 argument: the object instance. - - rehydrate_fn: Callable function to allow you to decode your saved data into internal representations. - The default value does nothing. 
- Data retrieved both from DB and storage pass through this function to guarantee consistency. - It should take 2 arguments: the object instance and the encoded data. - - default_value: Any value that will be returned if we can't save the data for whatever reason - - Example: - archive_field = ArchiveField( - should_write_to_storage_fn=should_write_data, - rehydrate_fn=rehidrate_data, - default_value='default' - ) - For a full example check utils/tests/unit/test_model_utils.py - """ - - def __init__( - self, - should_write_to_storage_fn: Callable[[object], bool], - rehydrate_fn: Callable[[object, object], Any] = lambda self, x: x, - json_encoder=ReportEncoder, - default_value_class=lambda: None, - ): - self.default_value_class = default_value_class - self.rehydrate_fn = rehydrate_fn - self.should_write_to_storage_fn = should_write_to_storage_fn - self.json_encoder = json_encoder - - def __set_name__(self, owner, name): - # Validate that the owner class has the methods we need - assert issubclass( - owner, ArchiveFieldInterface - ), "Missing some required methods to use AchiveField" - self.public_name = name - self.db_field_name = "_" + name - self.archive_field_name = "_" + name + "_storage_path" - self.cached_value_property_name = f"__{self.public_name}_cached_value" - - def _get_value_from_archive(self, obj): - repository = obj.get_repository() - archive_service = ArchiveService(repository=repository) - archive_field = getattr(obj, self.archive_field_name) - if archive_field: - try: - file_str = archive_service.read_file(archive_field) - return self.rehydrate_fn(obj, json.loads(file_str)) - except FileNotInStorageError: - log.error( - "Archive enabled field not in storage", - extra=dict( - storage_path=archive_field, - object_id=obj.id, - commit=obj.get_commitid(), - ), - ) - else: - log.debug( - "Both db_field and archive_field are None", - extra=dict( - object_id=obj.id, - commit=obj.get_commitid(), - ), - ) - return self.default_value_class() - - def 
__get__(self, obj, objtype=None): - cached_value = getattr(obj, self.cached_value_property_name, None) - if cached_value: - return cached_value - db_field = getattr(obj, self.db_field_name) - if db_field is not None: - value = self.rehydrate_fn(obj, db_field) - else: - value = self._get_value_from_archive(obj) - setattr(obj, self.cached_value_property_name, value) - return value - - def __set__(self, obj, value): - # Set the new value - if self.should_write_to_storage_fn(obj): - repository = obj.get_repository() - archive_service = ArchiveService(repository=repository) - old_file_path = getattr(obj, self.archive_field_name) - table_name = obj._meta.db_table - path = archive_service.write_json_data_to_storage( - commit_id=obj.get_commitid(), - table=table_name, - field=self.public_name, - external_id=obj.external_id, - data=value, - encoder=self.json_encoder, - ) - if old_file_path is not None and path != old_file_path: - archive_service.delete_file(old_file_path) - setattr(obj, self.archive_field_name, path) - setattr(obj, self.db_field_name, None) - else: - setattr(obj, self.db_field_name, value) - setattr(obj, self.cached_value_property_name, value) diff --git a/utils/rollouts.py b/utils/rollouts.py index f67a9edfa8..a9ffa1aa57 100644 --- a/utils/rollouts.py +++ b/utils/rollouts.py @@ -1,5 +1,3 @@ -from shared.rollouts import Feature - from codecov_auth.models import Owner diff --git a/utils/test_results.py b/utils/test_results.py new file mode 100644 index 0000000000..e1c1c34a64 --- /dev/null +++ b/utils/test_results.py @@ -0,0 +1,101 @@ +import datetime as dt +from dataclasses import dataclass + +from django.contrib.postgres.aggregates import ArrayAgg +from django.db.models import ( + Avg, + Case, + F, + FloatField, + Func, + IntegerField, + Max, + Q, + QuerySet, + Value, + When, +) +from django.db.models.functions import Coalesce +from shared.django_apps.reports.models import TestInstance + +thirty_days_ago = dt.datetime.now(dt.UTC) - dt.timedelta(days=30) + + 
+@dataclass +class TestResultsAggregation: + failure_rate: float | None + commits_where_fail: list[str] | None + average_duration: float | None + + +class ArrayLength(Func): + function = "CARDINALITY" + + +def aggregate_test_results( + repoid: int, + branch: str | None = None, + history: dt.timedelta | None = None, +) -> QuerySet: + """ + Function that retrieves aggregated information about all tests in a given repository, for a given time range, optionally filtered by branch name. + The fields it calculates are: the test failure rate, commits where this test failed, and average duration of the test. + + :param repoid: repoid of the repository we want to calculate aggregates for + :param branch: optional name of the branch we want to filter on, if this is provided the aggregates calculated will only take into account test instances generated on that branch. By default branches will not be filtered and test instances on all branches wil be taken into account. + :param history: optional timedelta field for filtering test instances used to calculated the aggregates by time, the test instances used will be those with a created at larger than now - history. 
+ :returns: dictionary mapping test id to dictionary containing + + """ + time_ago = ( + dt.datetime.now(dt.UTC) - dt.timedelta(days=30) + if history is not None + else thirty_days_ago + ) + + pass_failure_error_test_instances = TestInstance.objects.filter( + repoid=repoid, + created_at__gt=time_ago, + outcome__in=["pass", "failure", "error"], + ) + + if branch is not None: + pass_failure_error_test_instances = pass_failure_error_test_instances.filter( + branch=branch + ) + + failure_rates_queryset = ( + pass_failure_error_test_instances.values("test") + .annotate( + failure_rate=Avg( + Case( + When(outcome="pass", then=Value(0.0)), + When(outcome__in=["failure", "error"], then=Value(1.0)), + output_field=FloatField(), + ) + ), + updated_at=Max("created_at"), + commits_where_fail=Coalesce( + ArrayLength( + ArrayAgg( + "commitid", + distinct=True, + filter=Q(outcome__in=["failure", "error"]), + ) + ), + 0, + output_field=IntegerField(), + ), + avg_duration=Avg("duration_seconds"), + name=F("test__name"), + ) + .values( + "failure_rate", + "commits_where_fail", + "avg_duration", + "name", + "updated_at", + ) + ) + + return failure_rates_queryset diff --git a/utils/tests/unit/test_logging.py b/utils/tests/unit/test_logging.py index 2d07749be9..940430f49a 100644 --- a/utils/tests/unit/test_logging.py +++ b/utils/tests/unit/test_logging.py @@ -1,5 +1,3 @@ -import os - from utils.logging_configuration import CustomLocalJsonFormatter diff --git a/utils/tests/unit/test_model_utils.py b/utils/tests/unit/test_model_utils.py deleted file mode 100644 index 4bcbc1a46c..0000000000 --- a/utils/tests/unit/test_model_utils.py +++ /dev/null @@ -1,135 +0,0 @@ -import json -from unittest.mock import MagicMock - -from shared.storage.exceptions import FileNotInStorageError -from shared.utils.ReportEncoder import ReportEncoder - -from core.models import Commit -from core.tests.factories import CommitFactory -from utils.model_utils import ArchiveField, ArchiveFieldInterface - - -class 
TestArchiveField(object): - class ClassWithArchiveField(object): - commit: Commit - id = 1 - external_id = "external_id" - _meta = MagicMock(db_table="test_table") - - _archive_field = "db_field" - _archive_field_storage_path = "archive_field_path" - - def should_write_to_storage(self): - return self.should_write_to_gcs - - def get_repository(self): - return self.commit.repository - - def get_commitid(self): - return self.commit.commitid - - def __init__( - self, commit, db_value, archive_value, should_write_to_gcs=False - ) -> None: - self.commit = commit - self._archive_field = db_value - self._archive_field_storage_path = archive_value - self.should_write_to_gcs = should_write_to_gcs - - archive_field = ArchiveField(should_write_to_storage_fn=should_write_to_storage) - - class ClassWithArchiveFieldMissingMethods: - commit: Commit - id = 1 - external_id = "external_id" - - def test_subclass_validation(self, mocker): - assert issubclass( - self.ClassWithArchiveField( - mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock() - ), - ArchiveFieldInterface, - ) - assert not issubclass( - self.ClassWithArchiveFieldMissingMethods, ArchiveFieldInterface - ) - - def test_archive_getter_db_field_set(self, db): - commit = CommitFactory() - test_class = self.ClassWithArchiveField(commit, "db_value", "gcs_path") - assert test_class._archive_field == "db_value" - assert test_class._archive_field_storage_path == "gcs_path" - assert test_class.archive_field == "db_value" - - def test_archive_getter_archive_field_set(self, db, mocker): - some_json = {"some": "data"} - mock_read_file = mocker.MagicMock(return_value=json.dumps(some_json)) - mock_archive_service = mocker.patch("utils.model_utils.ArchiveService") - mock_archive_service.return_value.read_file = mock_read_file - commit = CommitFactory() - test_class = self.ClassWithArchiveField(commit, None, "gcs_path") - - assert test_class._archive_field == None - assert test_class._archive_field_storage_path == "gcs_path" - 
assert test_class.archive_field == some_json - mock_read_file.assert_called_with("gcs_path") - mock_archive_service.assert_called_with(repository=commit.repository) - assert mock_read_file.call_count == 1 - # Test that caching also works - assert test_class.archive_field == some_json - assert mock_read_file.call_count == 1 - - def test_archive_getter_file_not_in_storage(self, db, mocker): - mock_read_file = mocker.MagicMock(side_effect=FileNotInStorageError()) - mock_archive_service = mocker.patch("utils.model_utils.ArchiveService") - mock_archive_service.return_value.read_file = mock_read_file - commit = CommitFactory() - test_class = self.ClassWithArchiveField(commit, None, "gcs_path") - - assert test_class._archive_field == None - assert test_class._archive_field_storage_path == "gcs_path" - assert test_class.archive_field == None - mock_read_file.assert_called_with("gcs_path") - mock_archive_service.assert_called_with(repository=commit.repository) - - def test_archive_setter_db_field(self, db, mocker): - commit = CommitFactory() - test_class = self.ClassWithArchiveField(commit, "db_value", "gcs_path", False) - assert test_class._archive_field == "db_value" - assert test_class._archive_field_storage_path == "gcs_path" - assert test_class.archive_field == "db_value" - mock_archive_service = mocker.patch("utils.model_utils.ArchiveService") - test_class.archive_field = "batata frita" - mock_archive_service.assert_not_called() - assert test_class._archive_field == "batata frita" - assert test_class.archive_field == "batata frita" - - def test_archive_setter_archive_field(self, db, mocker): - commit = CommitFactory() - test_class = self.ClassWithArchiveField(commit, "db_value", None, True) - some_json = {"some": "data"} - mock_read_file = mocker.MagicMock(return_value=json.dumps(some_json)) - mock_write_file = mocker.MagicMock(return_value="path/to/written/object") - mock_archive_service = mocker.patch("utils.model_utils.ArchiveService") - 
mock_archive_service.return_value.read_file = mock_read_file - mock_archive_service.return_value.write_json_data_to_storage = mock_write_file - - assert test_class._archive_field == "db_value" - assert test_class._archive_field_storage_path == None - assert test_class.archive_field == "db_value" - assert mock_read_file.call_count == 0 - - # Pretend there was something in the path. - # This should happen, but it will help us test the deletion of old data saved - test_class._archive_field_storage_path = "path/to/old/data" - - # Now we write to the property - test_class.archive_field = some_json - assert test_class._archive_field == None - assert test_class._archive_field_storage_path == "path/to/written/object" - assert test_class.archive_field == some_json - # The cache is updated on write, so reading doesn't trigger another read - assert mock_read_file.call_count == 0 - mock_archive_service.return_value.delete_file.assert_called_with( - "path/to/old/data" - ) diff --git a/utils/tests/unit/test_services.py b/utils/tests/unit/test_services.py index 5b841cc7ce..125384027a 100644 --- a/utils/tests/unit/test_services.py +++ b/utils/tests/unit/test_services.py @@ -1,5 +1,3 @@ -import os - from utils.services import get_long_service_name, get_short_service_name diff --git a/validate/views.py b/validate/views.py index 7164bbc532..e0d885e0bc 100644 --- a/validate/views.py +++ b/validate/views.py @@ -50,7 +50,7 @@ def post(self, request, *args, **kwargs): content_type="text/plain", ) - except YAMLError as e: + except YAMLError: return HttpResponse( "Can't parse YAML\n", status=status.HTTP_400_BAD_REQUEST, diff --git a/webhook_handlers/tests/test_bitbucket.py b/webhook_handlers/tests/test_bitbucket.py index e83cfeb857..87edd89e00 100644 --- a/webhook_handlers/tests/test_bitbucket.py +++ b/webhook_handlers/tests/test_bitbucket.py @@ -1,13 +1,11 @@ -import uuid from unittest.mock import patch -import pytest from rest_framework import status from rest_framework.reverse import 
reverse from rest_framework.test import APITestCase from codecov_auth.tests.factories import OwnerFactory -from core.models import Branch, Commit, Pull, PullStates, Repository +from core.models import Branch, Commit, PullStates from core.tests.factories import ( BranchFactory, CommitFactory, @@ -105,7 +103,7 @@ def test_pull_request_rejected(self): assert self.pull.state == PullStates.CLOSED def test_repo_push_branch_deleted(self): - branch = BranchFactory(repository=self.repo, name="name-of-branch") + BranchFactory(repository=self.repo, name="name-of-branch") response = self._post_event_data( event=BitbucketWebhookEvents.REPO_PUSH, data={ @@ -284,7 +282,7 @@ def test_repo_commit_status_change_in_progress(self): def test_repo_commit_status_change_commit_skip_processing(self): commitid = "9fec847784abb10b2fa567ee63b85bd238955d0e" - commit = CommitFactory( + CommitFactory( commitid=commitid, repository=self.repo, state=Commit.CommitStates.PENDING ) response = self._post_event_data( @@ -318,7 +316,7 @@ def test_repo_commit_status_change_commit_skip_processing(self): @patch("services.task.TaskService.notify") def test_repo_commit_status_change_commit_notifies(self, notify_mock): commitid = "9fec847784abb10b2fa567ee63b85bd238955d0e" - commit = CommitFactory( + CommitFactory( commitid=commitid, repository=self.repo, state=Commit.CommitStates.COMPLETE ) response = self._post_event_data( diff --git a/webhook_handlers/tests/test_bitbucket_server.py b/webhook_handlers/tests/test_bitbucket_server.py index 7fb7f02a9b..8cd461170d 100644 --- a/webhook_handlers/tests/test_bitbucket_server.py +++ b/webhook_handlers/tests/test_bitbucket_server.py @@ -1,16 +1,13 @@ -import uuid from unittest.mock import patch -import pytest from rest_framework import status from rest_framework.reverse import reverse from rest_framework.test import APITestCase from codecov_auth.tests.factories import OwnerFactory -from core.models import Branch, Commit, Pull, PullStates, Repository +from core.models 
import Branch, PullStates from core.tests.factories import ( BranchFactory, - CommitFactory, PullFactory, RepositoryFactory, ) @@ -132,7 +129,7 @@ def test_pull_request_rejected(self): assert self.pull.state == PullStates.CLOSED def test_repo_push_branch_deleted(self): - branch = BranchFactory(repository=self.repo, name="name-of-branch") + BranchFactory(repository=self.repo, name="name-of-branch") response = self._post_event_data( event=BitbucketServerWebhookEvents.REPO_REFS_CHANGED, data={ diff --git a/webhook_handlers/tests/test_github.py b/webhook_handlers/tests/test_github.py index d4a947002d..04d58df6f4 100644 --- a/webhook_handlers/tests/test_github.py +++ b/webhook_handlers/tests/test_github.py @@ -195,7 +195,7 @@ def test_webhook_counters(self): def test_get_repo_paths_dont_crash(self): with self.subTest("with ownerid success"): - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", @@ -207,7 +207,7 @@ def test_get_repo_paths_dont_crash(self): ) with self.subTest("with not found owner"): - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", @@ -219,7 +219,7 @@ def test_get_repo_paths_dont_crash(self): ) with self.subTest("with not found owner and not found repo"): - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", @@ -228,7 +228,7 @@ def test_get_repo_paths_dont_crash(self): ) with self.subTest("with owner and not found repo"): - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", @@ -277,8 +277,6 @@ def test_repository_privatized_sets_private_true(self): assert self.repo.private == True def test_repository_deleted_sets_deleted_activated_and_active(self): - repository_id = self.repo.repoid - response = self._post_event_data( 
event=GitHubWebhookEvents.REPOSITORY, data={"action": "deleted", "repository": {"id": self.repo.service_id}}, @@ -359,7 +357,7 @@ def test_push_updates_only_unmerged_commits_with_branch_name(self): merged=True, repository=self.repo, branch=merged_branch_name ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.PUSH, data={ "ref": "refs/heads/" + unmerged_branch_name, @@ -395,7 +393,7 @@ def test_push_updates_commit_on_default_branch(self): merged=True, repository=self.repo, branch=merged_branch_name ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.PUSH, data={ "ref": "refs/heads/" + repo_branch, @@ -475,7 +473,7 @@ def test_push_triggers_set_pending_task_on_most_recent_commit( commit2 = CommitFactory(merged=False, repository=self.repo) unmerged_branch_name = "unmerged" - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.PUSH, data={ "ref": "refs/heads/" + unmerged_branch_name, @@ -500,9 +498,8 @@ def test_push_doesnt_trigger_task_if_repo_not_part_of_beta_set( self, set_pending_mock ): commit1 = CommitFactory(merged=False, repository=self.repo) - unmerged_branch_name = "unmerged" - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.PUSH, data={ "ref": "refs/heads/" + "derp", @@ -517,7 +514,6 @@ def test_push_doesnt_trigger_task_if_repo_not_part_of_beta_set( @patch("services.task.TaskService.status_set_pending") def test_push_doesnt_trigger_task_if_ci_skipped(self, set_pending_mock): commit1 = CommitFactory(merged=False, repository=self.repo, message="[ci skip]") - unmerged_branch_name = "unmerged" response = self._post_event_data( event=GitHubWebhookEvents.PUSH, @@ -607,7 +603,7 @@ def test_pull_request_triggers_pulls_sync_task_for_valid_actions( valid_actions = ["opened", "closed", "reopened", "synchronize"] for action in valid_actions: - response = self._post_event_data( + self._post_event_data( 
event=GitHubWebhookEvents.PULL_REQUEST, data={ "repository": {"id": self.repo.service_id}, @@ -643,7 +639,7 @@ def test_pull_request_updates_title_if_edited(self): def test_installation_creates_new_owner_if_dne_default_app(self, mock_refresh): username, service_id = "newuser", 123456 - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -706,7 +702,7 @@ def test_installation_creates_new_owner_if_dne_default_app(self, mock_refresh): def test_installation_creates_new_owner_if_dne_all_repos_non_default_app(self): username, service_id = "newuser", 123456 - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -739,7 +735,7 @@ def test_installation_creates_new_owner_if_dne_all_repos_non_default_app(self): assert installation.installation_id == 4 assert installation.app_id == 15 assert installation.name == "unconfigured_app" - assert installation.repository_service_ids == None + assert installation.repository_service_ids is None @patch( "services.task.TaskService.refresh", @@ -754,7 +750,7 @@ def test_installation_creates_new_owner_if_dne_all_repos_non_default_app(self): def test_installation_repositories_creates_new_owner_if_dne(self): username, service_id = "newuser", 123456 - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION_REPOSITORIES, data={ "installation": { @@ -768,9 +764,7 @@ def test_installation_repositories_creates_new_owner_if_dne(self): }, ) - owner_set = Owner.objects.filter( - service="github", service_id=service_id, username=username - ) + owner_set = Owner.objects.filter(service="github", service_id=service_id) assert owner_set.exists() @@ -784,7 +778,7 @@ def test_installation_repositories_creates_new_owner_if_dne(self): assert installation.installation_id == 4 assert installation.app_id == 15 assert installation.name == "unconfigured_app" - assert 
installation.repository_service_ids == None + assert installation.repository_service_ids is None @patch( "services.task.TaskService.refresh", @@ -808,7 +802,7 @@ def test_installation_update_repos_existing_ghapp_installation(self): installation.save() assert owner.github_app_installations.count() == 1 - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -864,7 +858,7 @@ def test_installation_with_deleted_action_nulls_values(self): assert owner.github_app_installations.exists() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -886,12 +880,12 @@ def test_installation_with_deleted_action_nulls_values(self): repo1.refresh_from_db() repo2.refresh_from_db() - assert owner.integration_id == None + assert owner.integration_id is None assert repo1.using_integration == False assert repo2.using_integration == False - assert repo1.bot == None - assert repo2.bot == None + assert repo1.bot is None + assert repo2.bot is None assert not owner.github_app_installations.exists() @@ -919,23 +913,15 @@ def test_installation_repositories_update_existing_ghapp(self): name=GITHUB_APP_INSTALLATION_DEFAULT_NAME, pem_path="some_path", ) - - owner.integration_id = 12 owner.save() - - repo1.using_integration, repo2.using_integration = True, True - repo1.bot, repo2.bot = owner, owner - repo1.save() repo2.save() - installation.save() assert owner.github_app_installations.exists() - assert installation.is_repo_covered_by_integration(repo2) is False - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION_REPOSITORIES, data={ "installation": { @@ -979,20 +965,14 @@ def test_installation_repositories_update_existing_ghapp_all_repos(self): owner=owner, repository_service_ids=[repo1.service_id], installation_id=12 ) - owner.integration_id = 12 owner.save() - - repo1.using_integration, 
repo2.using_integration = True, True - repo1.bot, repo2.bot = owner, owner - repo1.save() repo2.save() - installation.save() assert owner.github_app_installations.exists() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION_REPOSITORIES, data={ "installation": { @@ -1011,7 +991,7 @@ def test_installation_repositories_update_existing_ghapp_all_repos(self): installation.refresh_from_db() assert installation.installation_id == 12 - assert installation.repository_service_ids == None + assert installation.repository_service_ids is None @patch( "services.task.TaskService.refresh", @@ -1032,7 +1012,7 @@ def test_installation_with_other_actions_sets_owner_integration_id_if_none( owner.integration_id = None owner.save() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -1045,7 +1025,7 @@ def test_installation_with_other_actions_sets_owner_integration_id_if_none( {"id": "12321", "node_id": "R_kgDOG2tZYQ"}, {"id": "12343", "node_id": "R_kgDOG2tABC"}, ], - "action": "added", + "action": "suspend", "sender": {"type": "User"}, }, ) @@ -1062,6 +1042,7 @@ def test_installation_with_other_actions_sets_owner_integration_id_if_none( assert installation.installation_id == installation_id assert installation.app_id == DEFAULT_APP_ID assert installation.name == GITHUB_APP_INSTALLATION_DEFAULT_NAME + assert installation.is_suspended == True assert installation.repository_service_ids == ["12321", "12343"] @patch( @@ -1083,7 +1064,7 @@ def test_installation_repositories_with_other_actions_sets_owner_itegration_id_i owner.integration_id = None owner.save() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION_REPOSITORIES, data={ "installation": { @@ -1108,13 +1089,13 @@ def test_installation_repositories_with_other_actions_sets_owner_itegration_id_i assert ghapp_installations_set.count() == 1 installation = 
ghapp_installations_set.first() assert installation.installation_id == installation_id - assert installation.repository_service_ids == None + assert installation.repository_service_ids is None @patch("services.task.TaskService.refresh") def test_installation_trigger_refresh_with_other_actions(self, refresh_mock): owner = OwnerFactory(service=Service.GITHUB.value) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -1156,7 +1137,7 @@ def test_organization_with_removed_action_removes_user_from_org_and_activated_us org.plan_activated_users = [user.ownerid] org.save() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.ORGANIZATION, data={ "action": "member_removed", @@ -1227,7 +1208,7 @@ def test_marketplace_purchase_triggers_sync_plans_task( action = "purchased" account = {"type": "Organization", "id": 54678, "login": "username"} subscription_retrieve_mock.return_value = None - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.MARKETPLACE_PURCHASE, data={ "action": action, @@ -1257,7 +1238,7 @@ def test_marketplace_purchase_but_user_has_stripe_subscription( subscription_retrieve_mock.return_value = MockedSubscription( "active", plan, quantity ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.MARKETPLACE_PURCHASE, data={ "action": action, @@ -1353,7 +1334,7 @@ def test_member_removes_repo_permissions_if_member_removed(self): member = OwnerFactory( permission=[self.repo.repoid], service_id=6098, service=Service.GITHUB.value ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.MEMBER, data={ "action": "removed", @@ -1369,7 +1350,7 @@ def test_member_doesnt_crash_if_member_permission_array_is_None(self): member = OwnerFactory( permission=None, service_id=6098, service=Service.GITHUB.value ) - response = self._post_event_data( + self._post_event_data( 
event=GitHubWebhookEvents.MEMBER, data={ "action": "removed", @@ -1384,7 +1365,7 @@ def test_member_doesnt_crash_if_member_didnt_have_permission(self): service_id=6098, service=Service.GITHUB.value, ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.MEMBER, data={ "action": "removed", @@ -1417,7 +1398,7 @@ def test_repo_not_found_when_owner_has_integration_creates_repo(self): owner = OwnerFactory( integration_id=4850403, service_id=97968493, service=Service.GITHUB.value ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", @@ -1437,7 +1418,7 @@ def test_repo_creation_doesnt_crash_for_forked_repo(self): owner = OwnerFactory( integration_id=4850403, service_id=97968493, service=Service.GITHUB.value ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", diff --git a/webhook_handlers/tests/test_github_enterprise.py b/webhook_handlers/tests/test_github_enterprise.py index c32482c506..1f13e1f8a3 100644 --- a/webhook_handlers/tests/test_github_enterprise.py +++ b/webhook_handlers/tests/test_github_enterprise.py @@ -1,7 +1,6 @@ import hmac import json import uuid -from collections import namedtuple from hashlib import sha256 from unittest.mock import call, patch @@ -13,7 +12,6 @@ from codecov_auth.models import GithubAppInstallation, Owner, Service from codecov_auth.tests.factories import OwnerFactory -from core.models import Repository from core.tests.factories import ( BranchFactory, CommitFactory, @@ -72,7 +70,7 @@ def setUp(self): def test_get_repo_paths_dont_crash(self): with self.subTest("with ownerid success"): - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", @@ -84,7 +82,7 @@ def test_get_repo_paths_dont_crash(self): ) with self.subTest("with not found owner"): - response = self._post_event_data( 
+ self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", @@ -96,7 +94,7 @@ def test_get_repo_paths_dont_crash(self): ) with self.subTest("with not found owner and not found repo"): - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", @@ -105,7 +103,7 @@ def test_get_repo_paths_dont_crash(self): ) with self.subTest("with owner and not found repo"): - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", @@ -154,8 +152,6 @@ def test_repository_privatized_sets_private_true(self): assert self.repo.private == True def test_repository_deleted_sets_deleted_activated_and_active(self): - repository_id = self.repo.repoid - response = self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={"action": "deleted", "repository": {"id": self.repo.service_id}}, @@ -209,7 +205,7 @@ def test_push_updates_only_unmerged_commits_with_branch_name(self): merged=True, repository=self.repo, branch=merged_branch_name ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.PUSH, data={ "ref": "refs/heads/" + unmerged_branch_name, @@ -245,7 +241,7 @@ def test_push_updates_commit_on_default_branch(self): merged=True, repository=self.repo, branch=merged_branch_name ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.PUSH, data={ "ref": "refs/heads/" + repo_branch, @@ -300,7 +296,7 @@ def test_push_triggers_set_pending_task_on_most_recent_commit( commit2 = CommitFactory(merged=False, repository=self.repo) unmerged_branch_name = "unmerged" - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.PUSH, data={ "ref": "refs/heads/" + unmerged_branch_name, @@ -325,9 +321,8 @@ def test_push_doesnt_trigger_task_if_repo_not_part_of_beta_set( self, set_pending_mock ): commit1 = CommitFactory(merged=False, 
repository=self.repo) - unmerged_branch_name = "unmerged" - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.PUSH, data={ "ref": "refs/heads/" + "derp", @@ -342,7 +337,6 @@ def test_push_doesnt_trigger_task_if_repo_not_part_of_beta_set( @patch("services.task.TaskService.status_set_pending") def test_push_doesnt_trigger_task_if_ci_skipped(self, set_pending_mock): commit1 = CommitFactory(merged=False, repository=self.repo, message="[ci skip]") - unmerged_branch_name = "unmerged" response = self._post_event_data( event=GitHubWebhookEvents.PUSH, @@ -432,7 +426,7 @@ def test_pull_request_triggers_pulls_sync_task_for_valid_actions( valid_actions = ["opened", "closed", "reopened", "synchronize"] for action in valid_actions: - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.PULL_REQUEST, data={ "repository": {"id": self.repo.service_id}, @@ -477,7 +471,7 @@ def test_pull_request_updates_title_if_edited(self): def test_installation_creates_new_owner_if_dne(self): username, service_id = "newuser", 123456 - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -526,7 +520,7 @@ def test_installation_creates_new_owner_if_dne(self): def test_installation_creates_new_owner_if_dne_all_repos(self): username, service_id = "newuser", 123456 - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -559,7 +553,7 @@ def test_installation_creates_new_owner_if_dne_all_repos(self): assert ghapp_installations_set.count() == 1 installation = ghapp_installations_set.first() assert installation.installation_id == 4 - assert installation.repository_service_ids == None + assert installation.repository_service_ids is None @freeze_time("2024-03-28T00:00:00") @patch( @@ -575,7 +569,7 @@ def test_installation_creates_new_owner_if_dne_all_repos(self): def 
test_installation_repositories_creates_new_owner_if_dne(self): username, service_id = "newuser", 123456 - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION_REPOSITORIES, data={ "installation": { @@ -606,7 +600,7 @@ def test_installation_repositories_creates_new_owner_if_dne(self): assert ghapp_installations_set.count() == 1 installation = ghapp_installations_set.first() assert installation.installation_id == 4 - assert installation.repository_service_ids == None + assert installation.repository_service_ids is None def test_installation_with_deleted_action_nulls_values(self): # Should set integration_id to null for owner, @@ -633,7 +627,7 @@ def test_installation_with_deleted_action_nulls_values(self): assert owner.github_app_installations.exists() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -655,12 +649,12 @@ def test_installation_with_deleted_action_nulls_values(self): repo1.refresh_from_db() repo2.refresh_from_db() - assert owner.integration_id == None + assert owner.integration_id is None assert repo1.using_integration == False assert repo2.using_integration == False - assert repo1.bot == None - assert repo2.bot == None + assert repo1.bot is None + assert repo2.bot is None assert not owner.github_app_installations.exists() @@ -697,7 +691,7 @@ def test_installation_repositories_update_existing_ghapp(self): assert owner.github_app_installations.exists() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION_REPOSITORIES, data={ "installation": { @@ -753,7 +747,7 @@ def test_installation_repositories_update_existing_ghapp_all_repos(self): assert owner.github_app_installations.exists() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION_REPOSITORIES, data={ "installation": { @@ -772,7 +766,7 @@ def 
test_installation_repositories_update_existing_ghapp_all_repos(self): installation.refresh_from_db() assert installation.installation_id == 12 - assert installation.repository_service_ids == None + assert installation.repository_service_ids is None @patch( "services.task.TaskService.refresh", @@ -793,7 +787,7 @@ def test_installation_with_other_actions_sets_owner_itegration_id_if_none( owner.integration_id = None owner.save() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -842,7 +836,7 @@ def test_installation_repositories_with_other_actions_sets_owner_itegration_id_i owner.integration_id = None owner.save() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION_REPOSITORIES, data={ "installation": { @@ -867,13 +861,13 @@ def test_installation_repositories_with_other_actions_sets_owner_itegration_id_i assert ghapp_installations_set.count() == 1 installation = ghapp_installations_set.first() assert installation.installation_id == installation_id - assert installation.repository_service_ids == None + assert installation.repository_service_ids is None @patch("services.task.TaskService.refresh") def test_installation_trigger_refresh_with_other_actions(self, refresh_mock): owner = OwnerFactory(service=Service.GITHUB_ENTERPRISE.value) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.INSTALLATION, data={ "installation": { @@ -917,7 +911,7 @@ def test_organization_with_removed_action_removes_user_from_org_and_activated_us org.plan_activated_users = [user.ownerid] org.save() - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.ORGANIZATION, data={ "action": "member_removed", @@ -990,7 +984,7 @@ def test_marketplace_purchase_triggers_sync_plans_task( action = "purchased" account = {"type": "Organization", "id": 54678, "login": "username"} subscription_retrieve_mock.return_value = 
None - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.MARKETPLACE_PURCHASE, data={ "action": action, @@ -1022,7 +1016,7 @@ def test_marketplace_purchase_but_user_has_stripe_subscription( subscription_retrieve_mock.return_value = MockedSubscription( "active", plan, quantity ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.MARKETPLACE_PURCHASE, data={ "action": action, @@ -1077,7 +1071,7 @@ def test_member_removes_repo_permissions_if_member_removed(self): service_id=6098, service=Service.GITHUB_ENTERPRISE.value, ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.MEMBER, data={ "action": "removed", @@ -1093,7 +1087,7 @@ def test_member_doesnt_crash_if_member_permission_array_is_None(self): member = OwnerFactory( permission=None, service_id=6098, service=Service.GITHUB_ENTERPRISE.value ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.MEMBER, data={ "action": "removed", @@ -1108,7 +1102,7 @@ def test_member_doesnt_crash_if_member_didnt_have_permission(self): service_id=6098, service=Service.GITHUB_ENTERPRISE.value, ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.MEMBER, data={ "action": "removed", @@ -1143,7 +1137,7 @@ def test_repo_not_found_when_owner_has_integration_creates_repo(self): service_id=97968493, service=Service.GITHUB_ENTERPRISE.value, ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", @@ -1165,7 +1159,7 @@ def test_repo_creation_doesnt_crash_for_forked_repo(self): service_id=97968493, service=Service.GITHUB_ENTERPRISE.value, ) - response = self._post_event_data( + self._post_event_data( event=GitHubWebhookEvents.REPOSITORY, data={ "action": "publicized", diff --git a/webhook_handlers/tests/test_gitlab.py b/webhook_handlers/tests/test_gitlab.py index be6c5f8440..2000dcc383 100644 --- 
a/webhook_handlers/tests/test_gitlab.py +++ b/webhook_handlers/tests/test_gitlab.py @@ -5,9 +5,8 @@ from rest_framework.reverse import reverse from rest_framework.test import APITestCase -from codecov_auth.models import Owner from codecov_auth.tests.factories import OwnerFactory -from core.models import Commit, Pull, PullStates, Repository +from core.models import Commit, PullStates from core.tests.factories import CommitFactory, PullFactory, RepositoryFactory from webhook_handlers.constants import ( GitLabHTTPHeaders, @@ -18,15 +17,15 @@ def get_config_mock(*args, **kwargs): if args == ("setup", "enterprise_license"): - return True + return False elif args == ("gitlab", "webhook_validation"): - return True + return False else: return kwargs.get("default") class TestGitlabWebhookHandler(APITestCase): - def _post_event_data(self, event, data={}): + def _post_event_data(self, event, data): return self.client.post( reverse("gitlab-webhook"), data=data, @@ -111,7 +110,7 @@ def test_job_event_commit_not_found(self): def test_job_event_commit_not_complete(self): commit_sha = "2293ada6b400935a1378653304eaf6221e0fdb8f" - commit = CommitFactory( + CommitFactory( author=self.repo.author, repository=self.repo, commitid=commit_sha, @@ -247,373 +246,38 @@ def test_merge_request_event_action_update(self, pulls_sync_mock): pulls_sync_mock.assert_called_once_with(repoid=self.repo.repoid, pullid=pullid) - def test_handle_system_hook_not_enterprise(self): - def side_effect(*args, **kwargs): - if args == ("setup", "enterprise_license"): - return None - else: - return kwargs.get("default") - - self.get_config_mock.side_effect = side_effect - - username = "jsmith" - project_id = 74 - owner = OwnerFactory(service="gitlab", username=username) - - response = self._post_event_data( - event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2020-01-21T07:30:54Z", - "updated_at": "2020-01-21T07:38:22Z", - "event_name": "project_create", - "name": "StoreCloud", - "owner_email": 
"johnsmith@gmail.com", - "owner_name": "John Smith", - "path": "storecloud", - "path_with_namespace": f"{username}/storecloud", - "project_id": project_id, - "project_visibility": "private", - }, - ) - assert response.status_code == status.HTTP_403_FORBIDDEN - assert response.data.get("detail") == "No enterprise license detected" - - new_repo = Repository.objects.filter( - author__ownerid=owner.ownerid, service_id=project_id - ).first() - assert new_repo is None - - def test_handle_system_hook_project_create(self): - username = "jsmith" - project_id = 74 - owner = OwnerFactory(service="gitlab", username=username) - - response = self._post_event_data( - event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2020-01-21T07:30:54Z", - "updated_at": "2020-01-21T07:38:22Z", + def test_handle_system_hook_when_not_enterprise(self): + owner = OwnerFactory(service="gitlab") + repo = RepositoryFactory(author=owner) + + system_hook_events = [ + "project_create", + "project_destroy", + "project_rename", + "project_transfer", + "user_add_to_team", + "user_remove_from_team", + ] + + event_data = { + "event": GitLabWebhookEvents.SYSTEM, + "data": { "event_name": "project_create", - "name": "StoreCloud", - "owner_email": "johnsmith@gmail.com", - "owner_name": "John Smith", - "path": "storecloud", - "path_with_namespace": f"{username}/storecloud", - "project_id": project_id, - "project_visibility": "private", - }, - ) - assert response.status_code == status.HTTP_200_OK - assert response.data == "Repository created" - - new_repo = Repository.objects.get( - author__ownerid=owner.ownerid, service_id=project_id - ) - assert new_repo is not None - assert new_repo.private is True - assert new_repo.name == "storecloud" - - def test_handle_system_hook_project_destroy(self): - username = "jsmith" - project_id = 73 - owner = OwnerFactory(service="gitlab", username=username) - repo = RepositoryFactory( - name="testing", - author=owner, - service_id=project_id, - active=True, - 
activated=True, - deleted=False, - ) - - response = self._post_event_data( - event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2020-01-21T07:30:58Z", - "updated_at": "2020-01-21T07:38:22Z", - "event_name": "project_destroy", - "name": "Underscore", - "owner_email": "johnsmith@gmail.com", - "owner_name": "John Smith", - "path": "underscore", - "path_with_namespace": f"{username}/underscore", - "project_id": project_id, - "project_visibility": "internal", - }, - ) - assert response.status_code == status.HTTP_200_OK - assert response.data == "Repository deleted" - - repo.refresh_from_db() - assert repo.active is False - assert repo.activated is False - assert repo.deleted is True - assert repo.name == "testing-deleted" - - def test_handle_system_hook_project_rename(self): - username = "jsmith" - project_id = 73 - owner = OwnerFactory(service="gitlab", username=username) - repo = RepositoryFactory( - author=owner, - service_id=project_id, - name="overscore", - active=True, - activated=True, - deleted=False, - ) - - response = self._post_event_data( - event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2020-01-21T07:30:58Z", - "updated_at": "2020-01-21T07:38:22Z", - "event_name": "project_rename", - "name": "Underscore", - "path": "underscore", - "path_with_namespace": f"{username}/underscore", - "project_id": 73, - "owner_name": "John Smith", - "owner_email": "johnsmith@gmail.com", - "project_visibility": "internal", - "old_path_with_namespace": "jsmith/overscore", - }, - ) - assert response.status_code == status.HTTP_200_OK - assert response.data == "Repository renamed" - - repo.refresh_from_db() - assert repo.name == "underscore" - - def test_handle_system_hook_project_transfer(self): - old_owner_username = "jsmith" - new_owner_username = "scores" - project_id = 73 - new_owner = OwnerFactory(service="gitlab", username=new_owner_username) - old_owner = OwnerFactory(service="gitlab", username=old_owner_username) - repo = RepositoryFactory( - 
author=old_owner, - service_id=project_id, - name="overscore", - active=True, - activated=True, - deleted=False, - ) - - response = self._post_event_data( - event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2020-01-21T07:30:58Z", - "updated_at": "2020-01-21T07:38:22Z", - "event_name": "project_transfer", - "name": "Underscore", - "path": "underscore", - "path_with_namespace": f"{new_owner_username}/underscore", - "project_id": project_id, - "owner_name": "John Smith", - "owner_email": "johnsmith@gmail.com", - "project_visibility": "internal", - "old_path_with_namespace": f"{old_owner_username}/overscore", - }, - ) - assert response.status_code == status.HTTP_200_OK - assert response.data == "Repository transfered" - - repo.refresh_from_db() - assert repo.name == "underscore" - assert repo.author == new_owner - - def test_handle_system_hook_user_create(self): - gl_user_id = 41 - response = self._post_event_data( - event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2012-07-21T07:44:07Z", - "updated_at": "2012-07-21T07:38:22Z", - "email": "js@gitlabhq.com", - "event_name": "user_create", - "name": "John Smith", - "username": "js", - "user_id": gl_user_id, - }, - ) - assert response.status_code == status.HTTP_200_OK - assert response.data == "User created" - - new_user = Owner.objects.get(service="gitlab", service_id=gl_user_id) - assert new_user.name == "John Smith" - assert new_user.email == "js@gitlabhq.com" - assert new_user.username == "js" - - def test_handle_system_hook_user_add_to_team_no_existing_permissions(self): - gl_user_id = 41 - project_id = 74 - username = "johnsmith" - user = OwnerFactory( - service="gitlab", service_id=gl_user_id, username=username, permission=None - ) - repo = RepositoryFactory( - author=user, - service_id=project_id, - active=True, - activated=True, - deleted=False, - ) - response = self._post_event_data( - event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2012-07-21T07:30:56Z", - "updated_at": 
"2012-07-21T07:38:22Z", - "event_name": "user_add_to_team", - "access_level": "Maintainer", - "project_id": project_id, - "project_name": "StoreCloud", - "project_path": "storecloud", - "project_path_with_namespace": "jsmith/storecloud", - "user_email": "johnsmith@gmail.com", - "user_name": "John Smith", - "user_username": username, - "user_id": gl_user_id, - "project_visibility": "private", - }, - ) - assert response.status_code == status.HTTP_200_OK - assert response.data == "Permission added" - - user.refresh_from_db() - assert user.permission == [repo.repoid] - - def test_handle_system_hook_user_add_to_team(self): - gl_user_id = 41 - project_id = 74 - username = "johnsmith" - user = OwnerFactory( - service="gitlab", - service_id=gl_user_id, - username="johnsmith", - permission=[1, 2, 3, 100], - ) - repo = RepositoryFactory( - author=user, - service_id=project_id, - active=True, - activated=True, - deleted=False, - ) - response = self._post_event_data( - event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2012-07-21T07:30:56Z", - "updated_at": "2012-07-21T07:38:22Z", - "event_name": "user_add_to_team", - "access_level": "Maintainer", - "project_id": project_id, - "project_name": "StoreCloud", - "project_path": "storecloud", - "project_path_with_namespace": "jsmith/storecloud", - "user_email": "johnsmith@gmail.com", - "user_name": "John Smith", - "user_username": username, - "user_id": gl_user_id, - "project_visibility": "private", - }, - ) - assert response.status_code == status.HTTP_200_OK - assert response.data == "Permission added" - - user.refresh_from_db() - assert len(user.permission) == 5 - assert repo.repoid in user.permission - - def test_handle_system_hook_user_add_to_team_repo_public(self): - gl_user_id = 41 - project_id = 74 - username = "johnsmith" - user = OwnerFactory( - service="gitlab", - service_id=gl_user_id, - username=username, - permission=[1, 2, 3, 100], - ) - repo = RepositoryFactory( - author=user, - service_id=project_id, - 
active=True, - activated=True, - deleted=False, - ) - response = self._post_event_data( - event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2012-07-21T07:30:56Z", - "updated_at": "2012-07-21T07:38:22Z", - "event_name": "user_add_to_team", - "access_level": "Maintainer", - "project_id": project_id, - "project_name": "StoreCloud", - "project_path": "storecloud", - "project_path_with_namespace": "jsmith/storecloud", - "user_email": "johnsmith@gmail.com", - "user_name": "John Smith", - "user_username": username, - "user_id": gl_user_id, - "project_visibility": "public", - }, - ) - assert response.status_code == status.HTTP_200_OK - assert response.data is None - - user.refresh_from_db() - - assert user.permission == [1, 2, 3, 100] # no change - - def test_handle_system_hook_user_remove_from_team(self): - gl_user_id = 41 - project_id = 74 - username = "johnsmith" - user = OwnerFactory( - service="gitlab", service_id=gl_user_id, username=username, permission=None - ) - repo = RepositoryFactory( - author=user, - service_id=project_id, - active=True, - activated=True, - deleted=False, - ) - user.permission = [1, 2, 3, repo.repoid] - user.save() - - response = self._post_event_data( - event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2012-07-21T07:30:56Z", - "updated_at": "2012-07-21T07:38:22Z", - "event_name": "user_remove_from_team", - "access_level": "Maintainer", - "project_id": project_id, - "project_name": "StoreCloud", - "project_path": "storecloud", - "project_path_with_namespace": "jsmith/storecloud", - "user_email": "johnsmith@gmail.com", - "user_name": "John Smith", - "user_username": username, - "user_id": gl_user_id, - "project_visibility": "private", + "project_id": repo.service_id, }, - ) - assert response.status_code == status.HTTP_200_OK - assert response.data == "Permission removed" + } - user.refresh_from_db() - assert user.permission == [1, 2, 3] + for event in system_hook_events: + event_data["data"]["event_name"] = event + response 
= self._post_event_data(**event_data) + assert response.status_code == status.HTTP_403_FORBIDDEN def test_secret_validation(self): owner = OwnerFactory(service="gitlab") repo = RepositoryFactory( author=owner, service_id=uuid.uuid4(), - webhook_secret=uuid.uuid4(), + webhook_secret=uuid.uuid4(), # if repo has webhook secret, requires validation ) owner.permission = [repo.repoid] owner.save() diff --git a/webhook_handlers/tests/test_gitlab_enterprise.py b/webhook_handlers/tests/test_gitlab_enterprise.py index 7d7737d7c5..1e490ed3d7 100644 --- a/webhook_handlers/tests/test_gitlab_enterprise.py +++ b/webhook_handlers/tests/test_gitlab_enterprise.py @@ -7,9 +7,8 @@ from rest_framework.test import APITestCase from shared.utils.test_utils import mock_config_helper -from codecov_auth.models import Owner from codecov_auth.tests.factories import OwnerFactory -from core.models import Commit, Pull, PullStates, Repository +from core.models import Commit, PullStates, Repository from core.tests.factories import CommitFactory, PullFactory, RepositoryFactory from webhook_handlers.constants import ( GitLabHTTPHeaders, @@ -17,8 +16,6 @@ WebhookHandlerErrorMessages, ) -webhook_secret = "test-46204fb3-374e-4cfc-8cae-d7ca43371096" - class TestGitlabEnterpriseWebhookHandler(APITestCase): @pytest.fixture(scope="function", autouse=True) @@ -31,17 +28,16 @@ def mock_config(self, mocker): mocker, configs={ "setup.enterprise_license": True, - "gitlab_enterprise.webhook_secret": webhook_secret, - "gitlab_enterprise.webhook_validation": True, + "gitlab_enterprise.webhook_validation": False, }, ) - def _post_event_data(self, event, data={}): + def _post_event_data(self, event, data, token=None): return self.client.post( reverse("gitlab_enterprise-webhook"), **{ GitLabHTTPHeaders.EVENT: event, - GitLabHTTPHeaders.TOKEN: webhook_secret, + GitLabHTTPHeaders.TOKEN: token, }, data=data, format="json", @@ -117,7 +113,7 @@ def test_job_event_commit_not_found(self): def 
test_job_event_commit_not_complete(self): commit_sha = "2293ada6b400935a1378653304eaf6221e0fdb8f" - commit = CommitFactory( + CommitFactory( author=self.repo.author, repository=self.repo, commitid=commit_sha, @@ -255,9 +251,7 @@ def test_merge_request_event_action_update(self, pulls_sync_mock): def test_handle_system_hook_not_enterprise(self): mock_config_helper(self.mocker, configs={"setup.enterprise_license": None}) - username = "jsmith" - project_id = 74 - owner = OwnerFactory(service="gitlab_enterprise", username=username) + owner = OwnerFactory(service="gitlab_enterprise", username="jsmith") response = self._post_event_data( event=GitLabWebhookEvents.SYSTEM, @@ -269,56 +263,88 @@ def test_handle_system_hook_not_enterprise(self): "owner_email": "johnsmith@gmail.com", "owner_name": "John Smith", "path": "storecloud", - "path_with_namespace": f"{username}/storecloud", - "project_id": project_id, + "path_with_namespace": f"{owner.username}/storecloud", + "project_id": 74, "project_visibility": "private", }, ) assert response.status_code == status.HTTP_403_FORBIDDEN - assert response.data.get("detail") == "No enterprise license detected" new_repo = Repository.objects.filter( - author__ownerid=owner.ownerid, service_id=project_id + author__ownerid=owner.ownerid, service_id=74 ).first() assert new_repo is None - def test_handle_system_hook_project_create(self): - username = "jsmith" - project_id = 74 - owner = OwnerFactory(service="gitlab_enterprise", username=username) + @patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_project_create(self, mock_refresh_task): + sample_payload_from_gitlab_docs = { + "created_at": "2012-07-21T07:30:54Z", + "updated_at": "2012-07-21T07:38:22Z", + "event_name": "project_create", + "name": "StoreCloud", + "owner_email": "johnsmith@example.com", + "owner_name": "John Smith", + "owners": [{"name": "John", "email": "user1@example.com"}], + "path": "storecloud", + "path_with_namespace": 
"jsmith/storecloud", + "project_id": 74, + "project_visibility": "private", + } + + owner = OwnerFactory( + service="gitlab_enterprise", + username="jsmith", + name=sample_payload_from_gitlab_docs["owner_name"], + email=sample_payload_from_gitlab_docs["owner_email"], + oauth_token="123", + ) response = self._post_event_data( event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2020-01-21T07:30:54Z", - "updated_at": "2020-01-21T07:38:22Z", - "event_name": "project_create", - "name": "StoreCloud", - "owner_email": "johnsmith@gmail.com", - "owner_name": "John Smith", - "path": "storecloud", - "path_with_namespace": f"{username}/storecloud", - "project_id": project_id, - "project_visibility": "private", - }, + data=sample_payload_from_gitlab_docs, ) assert response.status_code == status.HTTP_200_OK - assert response.data == "Repository created" + assert response.data == "Sync initiated" + + mock_refresh_task.assert_called_once_with( + ownerid=owner.ownerid, + username=owner.username, + using_integration=False, + manual_trigger=False, + ) + + @patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_project_destroy(self, mock_refresh_task): + sample_payload_from_gitlab_docs = { + "created_at": "2012-07-21T07:30:58Z", + "updated_at": "2012-07-21T07:38:22Z", + "event_name": "project_destroy", + "name": "Underscore", + "owner_email": "johnsmith@example.com", + "owner_name": "John Smith", + "owners": [{"name": "John", "email": "user1@example.com"}], + "path": "underscore", + "path_with_namespace": "jsmith/underscore", + "project_id": 73, + "project_visibility": "internal", + } + + OwnerFactory( + service="gitlab_enterprise", + username="jsmith", + name=sample_payload_from_gitlab_docs["owner_name"], + email=sample_payload_from_gitlab_docs["owner_email"], + oauth_token="123", + ) - new_repo = Repository.objects.get( - author__ownerid=owner.ownerid, service_id=project_id + owner_org = OwnerFactory( + service="gitlab_enterprise", + 
oauth_token=None, ) - assert new_repo is not None - assert new_repo.private is True - assert new_repo.name == "storecloud" - def test_handle_system_hook_project_destroy(self): - username = "jsmith" - project_id = 73 - owner = OwnerFactory(service="gitlab_enterprise", username=username) repo = RepositoryFactory( - author=owner, - service_id=project_id, + author=owner_org, + service_id=sample_payload_from_gitlab_docs["project_id"], active=True, activated=True, deleted=False, @@ -326,301 +352,452 @@ def test_handle_system_hook_project_destroy(self): response = self._post_event_data( event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2020-01-21T07:30:58Z", - "updated_at": "2020-01-21T07:38:22Z", - "event_name": "project_destroy", - "name": "Underscore", - "owner_email": "johnsmith@gmail.com", - "owner_name": "John Smith", - "path": "underscore", - "path_with_namespace": f"{username}/underscore", - "project_id": project_id, - "project_visibility": "internal", - }, + data=sample_payload_from_gitlab_docs, ) assert response.status_code == status.HTTP_200_OK assert response.data == "Repository deleted" + mock_refresh_task.assert_not_called() + repo.refresh_from_db() assert repo.active is False assert repo.activated is False assert repo.deleted is True - def test_handle_system_hook_project_rename(self): - username = "jsmith" - project_id = 73 - owner = OwnerFactory(service="gitlab_enterprise", username=username) - repo = RepositoryFactory( - author=owner, - service_id=project_id, - name="overscore", + @patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_project_rename(self, mock_refresh_task): + # testing get owner by namespace in payload + sample_payload_from_gitlab_docs = { + "created_at": "2012-07-21T07:30:58Z", + "updated_at": "2012-07-21T07:38:22Z", + "event_name": "project_rename", + "name": "Underscore", + "path": "underscore", + "path_with_namespace": "jsmith/underscore", + "project_id": 73, + "owner_name": "John Smith", + 
"owner_email": "johnsmith@example.com", + "owners": [{"name": "John", "email": "user1@example.com"}], + "project_visibility": "internal", + "old_path_with_namespace": "jsmith/overscore", + } + + OwnerFactory( + service="gitlab_enterprise", + oauth_token="123", + username="jsmith", + ) + + owner_bot = OwnerFactory( + service="gitlab_enterprise", + oauth_token="123", + ) + + owner_org = OwnerFactory( + service="gitlab_enterprise", + oauth_token=None, + ) + + RepositoryFactory( + author=owner_org, + service_id=sample_payload_from_gitlab_docs["project_id"], active=True, activated=True, deleted=False, + bot=owner_bot, ) response = self._post_event_data( event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2020-01-21T07:30:58Z", - "updated_at": "2020-01-21T07:38:22Z", - "event_name": "project_rename", - "name": "Underscore", - "path": "underscore", - "path_with_namespace": f"{username}/underscore", - "project_id": 73, - "owner_name": "John Smith", - "owner_email": "johnsmith@gmail.com", - "project_visibility": "internal", - "old_path_with_namespace": "jsmith/overscore", - }, + data=sample_payload_from_gitlab_docs, ) assert response.status_code == status.HTTP_200_OK - assert response.data == "Repository renamed" - - repo.refresh_from_db() - assert repo.name == "underscore" + assert response.data == "Sync initiated" + + mock_refresh_task.assert_called_once_with( + ownerid=owner_bot.ownerid, + username=owner_bot.username, + using_integration=False, + manual_trigger=False, + ) + + @patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_project_transfer(self, mock_refresh_task): + # moving this repo from one namespace to another + sample_payload_from_gitlab_docs = { + "created_at": "2012-07-21T07:30:58Z", + "updated_at": "2012-07-21T07:38:22Z", + "event_name": "project_transfer", + "name": "Underscore", + "path": "underscore", + "path_with_namespace": "scores/underscore", + "project_id": 73, + "owner_name": "John Smith", + "owner_email": 
"johnsmith@example.com", + "owners": [{"name": "John", "email": "user1@example.com"}], + "project_visibility": "internal", + "old_path_with_namespace": "jsmith/overscore", + } + + owner_user = OwnerFactory( + service="gitlab_enterprise", + name=sample_payload_from_gitlab_docs["owner_name"], + email=sample_payload_from_gitlab_docs["owner_email"], + oauth_token="123", + ) - def test_handle_system_hook_project_transfer(self): - old_owner_username = "jsmith" - new_owner_username = "scores" - project_id = 73 - new_owner = OwnerFactory( - service="gitlab_enterprise", username=new_owner_username + non_usable_bot = OwnerFactory( + service="gitlab_enterprise", + oauth_token=None, ) - old_owner = OwnerFactory( - service="gitlab_enterprise", username=old_owner_username + + owner_org = OwnerFactory( + service="gitlab_enterprise", + oauth_token=None, + username="jsmith", ) - repo = RepositoryFactory( - author=old_owner, - service_id=project_id, - name="overscore", + + RepositoryFactory( + author=owner_org, + service_id=sample_payload_from_gitlab_docs["project_id"], active=True, activated=True, deleted=False, + bot=non_usable_bot, ) response = self._post_event_data( event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2020-01-21T07:30:58Z", - "updated_at": "2020-01-21T07:38:22Z", - "event_name": "project_transfer", - "name": "Underscore", - "path": "underscore", - "path_with_namespace": f"{new_owner_username}/underscore", - "project_id": project_id, - "owner_name": "John Smith", - "owner_email": "johnsmith@gmail.com", - "project_visibility": "internal", - "old_path_with_namespace": f"{old_owner_username}/overscore", - }, + data=sample_payload_from_gitlab_docs, ) assert response.status_code == status.HTTP_200_OK - assert response.data == "Repository transfered" + assert response.data == "Sync initiated" + + mock_refresh_task.assert_called_once_with( + ownerid=owner_user.ownerid, + username=owner_user.username, + using_integration=False, + manual_trigger=False, + ) + + 
@patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_user_create(self, mock_refresh_task): + sample_payload_from_gitlab_docs = { + "created_at": "2012-07-21T07:44:07Z", + "updated_at": "2012-07-21T07:38:22Z", + "email": "js@gitlabhq.com", + "event_name": "user_create", + "name": "John Smith", + "username": "js", + "user_id": 41, + } + response = self._post_event_data( + event=GitLabWebhookEvents.SYSTEM, + data=sample_payload_from_gitlab_docs, + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + mock_refresh_task.assert_not_called() + + @patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_user_add_to_team(self, mock_refresh_task): + sample_payload_from_gitlab_docs = { + "created_at": "2012-07-21T07:30:56Z", + "updated_at": "2012-07-21T07:38:22Z", + "event_name": "user_add_to_team", + "access_level": "Maintainer", + "project_id": 74, + "project_name": "StoreCloud", + "project_path": "storecloud", + "project_path_with_namespace": "jsmith/storecloud", + "user_email": "johnsmith@example.com", + "user_name": "John Smith", + "user_username": "johnsmith", + "user_id": 41, + "project_visibility": "private", + } + + owner_user = OwnerFactory( + service="gitlab_enterprise", + name=sample_payload_from_gitlab_docs["user_name"], + email=sample_payload_from_gitlab_docs["user_email"], + oauth_token="123", + username=sample_payload_from_gitlab_docs["user_username"], + service_id=sample_payload_from_gitlab_docs["user_id"], + ) - repo.refresh_from_db() - assert repo.name == "underscore" - assert repo.author == new_owner + owner_org = OwnerFactory( + service="gitlab_enterprise", + oauth_token=None, + username="jsmith", + ) + + RepositoryFactory( + author=owner_org, + service_id=sample_payload_from_gitlab_docs["project_id"], + active=True, + activated=True, + deleted=False, + ) - def test_handle_system_hook_user_create(self): - gl_user_id = 41 response = self._post_event_data( 
event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2012-07-21T07:44:07Z", - "updated_at": "2012-07-21T07:38:22Z", - "email": "js@gitlabhq.com", - "event_name": "user_create", - "name": "John Smith", - "username": "js", - "user_id": gl_user_id, - }, + data=sample_payload_from_gitlab_docs, ) assert response.status_code == status.HTTP_200_OK - assert response.data == "User created" - - new_user = Owner.objects.get(service="gitlab_enterprise", service_id=gl_user_id) - assert new_user.name == "John Smith" - assert new_user.email == "js@gitlabhq.com" - assert new_user.username == "js" - - def test_handle_system_hook_user_add_to_team_no_existing_permissions(self): - gl_user_id = 41 - project_id = 74 - username = "johnsmith" - user = OwnerFactory( + assert response.data == "Sync initiated" + + mock_refresh_task.assert_called_once_with( + ownerid=owner_user.ownerid, + username=owner_user.username, + using_integration=False, + manual_trigger=False, + ) + + @patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_user_add_to_team_repo_public(self, mock_refresh_task): + sample_payload_from_gitlab_docs = { + "created_at": "2012-07-21T07:30:56Z", + "updated_at": "2012-07-21T07:38:22Z", + "event_name": "user_add_to_team", + "access_level": "Maintainer", + "project_id": 74, + "project_name": "StoreCloud", + "project_path": "storecloud", + "project_path_with_namespace": "jsmith/storecloud", + "user_email": "johnsmith@example.com", + "user_name": "John Smith", + "user_username": "johnsmith", + "user_id": 41, + "project_visibility": "public", + } + + OwnerFactory( service="gitlab_enterprise", - service_id=gl_user_id, - username=username, - permission=None, + name=sample_payload_from_gitlab_docs["user_name"], + email=sample_payload_from_gitlab_docs["user_email"], + oauth_token="123", + username=sample_payload_from_gitlab_docs["user_username"], + service_id=sample_payload_from_gitlab_docs["user_id"], ) - repo = RepositoryFactory( - author=user, - 
service_id=project_id, + + owner_org = OwnerFactory( + service="gitlab_enterprise", + oauth_token=None, + username="jsmith", + ) + + RepositoryFactory( + author=owner_org, + service_id=sample_payload_from_gitlab_docs["project_id"], active=True, activated=True, deleted=False, ) + response = self._post_event_data( event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2012-07-21T07:30:56Z", - "updated_at": "2012-07-21T07:38:22Z", - "event_name": "user_add_to_team", - "access_level": "Maintainer", - "project_id": project_id, - "project_name": "StoreCloud", - "project_path": "storecloud", - "project_path_with_namespace": "jsmith/storecloud", - "user_email": "johnsmith@gmail.com", - "user_name": "John Smith", - "user_username": username, - "user_id": gl_user_id, - "project_visibility": "private", - }, + data=sample_payload_from_gitlab_docs, ) assert response.status_code == status.HTTP_200_OK - assert response.data == "Permission added" + assert response.data is None - user.refresh_from_db() - assert user.permission == [repo.repoid] + mock_refresh_task.assert_not_called() + + @patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_user_remove_from_team(self, mock_refresh_task): + sample_payload_from_gitlab_docs = { + "created_at": "2012-07-21T07:30:56Z", + "updated_at": "2012-07-21T07:38:22Z", + "event_name": "user_remove_from_team", + "access_level": "Maintainer", + "project_id": 74, + "project_name": "StoreCloud", + "project_path": "storecloud", + "project_path_with_namespace": "jsmith/storecloud", + "user_email": "johnsmith@example.com", + "user_name": "John Smith", + "user_username": "johnsmith", + "user_id": 41, + "project_visibility": "private", + } + + owner_user = OwnerFactory( + service="gitlab_enterprise", + name=sample_payload_from_gitlab_docs["user_name"], + email=sample_payload_from_gitlab_docs["user_email"], + oauth_token="123", + username=sample_payload_from_gitlab_docs["user_username"], + 
service_id=sample_payload_from_gitlab_docs["user_id"], + ) - def test_handle_system_hook_user_add_to_team(self): - gl_user_id = 41 - project_id = 74 - username = "johnsmith" - user = OwnerFactory( + owner_org = OwnerFactory( service="gitlab_enterprise", - service_id=gl_user_id, - username="johnsmith", - permission=[1, 2, 3, 100], + oauth_token=None, + username="jsmith", ) - repo = RepositoryFactory( - author=user, - service_id=project_id, + + RepositoryFactory( + author=owner_org, + service_id=sample_payload_from_gitlab_docs["project_id"], active=True, activated=True, deleted=False, ) + response = self._post_event_data( event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2012-07-21T07:30:56Z", - "updated_at": "2012-07-21T07:38:22Z", - "event_name": "user_add_to_team", - "access_level": "Maintainer", - "project_id": project_id, - "project_name": "StoreCloud", - "project_path": "storecloud", - "project_path_with_namespace": "jsmith/storecloud", - "user_email": "johnsmith@gmail.com", - "user_name": "John Smith", - "user_username": username, - "user_id": gl_user_id, - "project_visibility": "private", - }, + data=sample_payload_from_gitlab_docs, ) assert response.status_code == status.HTTP_200_OK - assert response.data == "Permission added" - - user.refresh_from_db() - assert len(user.permission) == 5 - assert repo.repoid in user.permission + assert response.data == "Sync initiated" + + mock_refresh_task.assert_called_once_with( + ownerid=owner_user.ownerid, + username=owner_user.username, + using_integration=False, + manual_trigger=False, + ) + + @patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_unknown_repo(self, mock_refresh_task): + sample_payload_from_gitlab_docs = { + "created_at": "2012-07-21T07:30:56Z", + "updated_at": "2012-07-21T07:38:22Z", + "event_name": "user_add_to_team", + "access_level": "Maintainer", + "project_id": 74, + "project_name": "StoreCloud", + "project_path": "storecloud", + 
"project_path_with_namespace": "jsmith/storecloud", + "user_email": "johnsmith@example.com", + "user_name": "John Smith", + "user_username": "johnsmith", + "user_id": 41, + "project_visibility": "private", + } + + OwnerFactory( + service="gitlab_enterprise", + name=sample_payload_from_gitlab_docs["user_name"], + email=sample_payload_from_gitlab_docs["user_email"], + oauth_token="123", + username=sample_payload_from_gitlab_docs["user_username"], + service_id=sample_payload_from_gitlab_docs["user_id"], + ) - def test_handle_system_hook_user_add_to_team_repo_public(self): - gl_user_id = 41 - project_id = 74 - username = "johnsmith" - user = OwnerFactory( + owner_org = OwnerFactory( service="gitlab_enterprise", - service_id=gl_user_id, - username=username, - permission=[1, 2, 3, 100], + oauth_token=None, + username="jsmith", ) - repo = RepositoryFactory( - author=user, - service_id=project_id, + + RepositoryFactory( + author=owner_org, + service_id=sample_payload_from_gitlab_docs["project_id"] + 1, active=True, activated=True, deleted=False, ) + response = self._post_event_data( event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2012-07-21T07:30:56Z", - "updated_at": "2012-07-21T07:38:22Z", - "event_name": "user_add_to_team", - "access_level": "Maintainer", - "project_id": project_id, - "project_name": "StoreCloud", - "project_path": "storecloud", - "project_path_with_namespace": "jsmith/storecloud", - "user_email": "johnsmith@gmail.com", - "user_name": "John Smith", - "user_username": username, - "user_id": gl_user_id, - "project_visibility": "public", - }, + data=sample_payload_from_gitlab_docs, ) - assert response.status_code == status.HTTP_200_OK - assert response.data is None + assert response.status_code == status.HTTP_404_NOT_FOUND - user.refresh_from_db() + @patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_user_add_to_team_unknown_user(self, mock_refresh_task): + sample_payload_from_gitlab_docs = { + "created_at": 
"2012-07-21T07:30:56Z", + "updated_at": "2012-07-21T07:38:22Z", + "event_name": "user_add_to_team", + "access_level": "Maintainer", + "project_id": 74, + "project_name": "StoreCloud", + "project_path": "storecloud", + "project_path_with_namespace": "jsmith/storecloud", + "user_email": "johnsmith@example.com", + "user_name": "John Smith", + "user_username": "johnsmith", + "user_id": 41, + "project_visibility": "private", + } + + owner_org = OwnerFactory( + service="gitlab_enterprise", + oauth_token=None, + username="jsmith", + ) - assert user.permission == [1, 2, 3, 100] # no change + RepositoryFactory( + author=owner_org, + service_id=sample_payload_from_gitlab_docs["project_id"], + active=True, + activated=True, + deleted=False, + ) - def test_handle_system_hook_user_remove_from_team(self): - gl_user_id = 41 - project_id = 74 - username = "johnsmith" - user = OwnerFactory( + response = self._post_event_data( + event=GitLabWebhookEvents.SYSTEM, + data=sample_payload_from_gitlab_docs, + ) + assert response.status_code == status.HTTP_200_OK + assert response.data == "Sync initiated" + + mock_refresh_task.assert_not_called() + + @patch("services.refresh.RefreshService.trigger_refresh") + def test_handle_system_hook_no_bot_or_user_match(self, mock_refresh_task): + sample_payload_from_gitlab_docs = { + "created_at": "2012-07-21T07:30:58Z", + "updated_at": "2012-07-21T07:38:22Z", + "event_name": "project_rename", + "name": "Underscore", + "path": "underscore", + "path_with_namespace": "jsmith/underscore", + "project_id": 73, + "owner_name": "John Smith", + "owner_email": "johnsmith@example.com", + "owners": [{"name": "John", "email": "user1@example.com"}], + "project_visibility": "internal", + "old_path_with_namespace": "jsmith/overscore", + } + + OwnerFactory( service="gitlab_enterprise", - service_id=gl_user_id, - username=username, - permission=None, + name=sample_payload_from_gitlab_docs["owner_name"], + oauth_token="123", ) - repo = RepositoryFactory( - author=user, 
- service_id=project_id, + + owner_org = OwnerFactory( + service="gitlab_enterprise", + oauth_token=None, + username="jsmith", + ) + + RepositoryFactory( + author=owner_org, + service_id=sample_payload_from_gitlab_docs["project_id"], active=True, activated=True, deleted=False, ) - user.permission = [1, 2, 3, repo.repoid] - user.save() response = self._post_event_data( event=GitLabWebhookEvents.SYSTEM, - data={ - "created_at": "2012-07-21T07:30:56Z", - "updated_at": "2012-07-21T07:38:22Z", - "event_name": "user_remove_from_team", - "access_level": "Maintainer", - "project_id": project_id, - "project_name": "StoreCloud", - "project_path": "storecloud", - "project_path_with_namespace": "jsmith/storecloud", - "user_email": "johnsmith@gmail.com", - "user_name": "John Smith", - "user_username": username, - "user_id": gl_user_id, - "project_visibility": "private", - }, + data=sample_payload_from_gitlab_docs, ) assert response.status_code == status.HTTP_200_OK - assert response.data == "Permission removed" + assert response.data == "Sync initiated" - user.refresh_from_db() - assert user.permission == [1, 2, 3] + mock_refresh_task.assert_not_called() def test_secret_validation(self): owner = OwnerFactory(service="gitlab_enterprise") repo = RepositoryFactory( author=owner, service_id=uuid.uuid4(), - webhook_secret=uuid.uuid4(), + webhook_secret=uuid.uuid4(), # if repo has webhook secret, requires validation ) owner.permission = [repo.repoid] owner.save() @@ -650,3 +827,62 @@ def test_secret_validation(self): format="json", ) assert response.status_code == status.HTTP_200_OK + + def test_secret_validation_required_by_config(self): + webhook_secret = uuid.uuid4() + # if repo has webhook_validation config set to True, requires validation + mock_config_helper( + self.mocker, + configs={ + "gitlab_enterprise.webhook_validation": True, + }, + ) + owner = OwnerFactory(service="gitlab_enterprise") + repo = RepositoryFactory( + author=owner, + service_id=uuid.uuid4(), + 
webhook_secret=None, + ) + owner.permission = [repo.repoid] + owner.save() + + response = self.client.post( + reverse("gitlab_enterprise-webhook"), + **{ + GitLabHTTPHeaders.EVENT: "", + GitLabHTTPHeaders.TOKEN: "", + }, + data={ + "project_id": repo.service_id, + }, + format="json", + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + response = self.client.post( + reverse("gitlab_enterprise-webhook"), + **{ + GitLabHTTPHeaders.EVENT: "", + GitLabHTTPHeaders.TOKEN: webhook_secret, + }, + data={ + "project_id": repo.service_id, + }, + format="json", + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + repo.webhook_secret = webhook_secret + repo.save() + response = self.client.post( + reverse("gitlab_enterprise-webhook"), + **{ + GitLabHTTPHeaders.EVENT: "", + GitLabHTTPHeaders.TOKEN: webhook_secret, + }, + data={ + "project_id": repo.service_id, + }, + format="json", + ) + assert response.status_code == status.HTTP_200_OK diff --git a/webhook_handlers/urls.py b/webhook_handlers/urls.py index 914ca44796..ff73bb7a64 100644 --- a/webhook_handlers/urls.py +++ b/webhook_handlers/urls.py @@ -1,4 +1,4 @@ -from django.urls import include, path +from django.urls import path # to remove when in production we send the webhooks to /billing/stripe/webhooks from billing.views import StripeWebhookHandler diff --git a/webhook_handlers/views/bitbucket.py b/webhook_handlers/views/bitbucket.py index 1d76848eba..bd6d2d33b5 100644 --- a/webhook_handlers/views/bitbucket.py +++ b/webhook_handlers/views/bitbucket.py @@ -1,13 +1,11 @@ import logging from django.shortcuts import get_object_or_404 -from rest_framework import status from rest_framework.permissions import AllowAny from rest_framework.response import Response from rest_framework.views import APIView from shared.helpers.yaml import walk -from codecov_auth.models import Owner from core.models import Branch, Commit, Pull, PullStates, Repository from services.task import TaskService from 
webhook_handlers.constants import ( diff --git a/webhook_handlers/views/bitbucket_server.py b/webhook_handlers/views/bitbucket_server.py index 364413b90a..72c58a833a 100644 --- a/webhook_handlers/views/bitbucket_server.py +++ b/webhook_handlers/views/bitbucket_server.py @@ -1,14 +1,11 @@ import logging from django.shortcuts import get_object_or_404 -from rest_framework import status from rest_framework.permissions import AllowAny from rest_framework.response import Response from rest_framework.views import APIView -from shared.helpers.yaml import walk -from codecov_auth.models import Owner -from core.models import Branch, Commit, Pull, PullStates, Repository +from core.models import Branch, Pull, PullStates, Repository from services.task import TaskService from webhook_handlers.constants import ( BitbucketServerHTTPHeaders, diff --git a/webhook_handlers/views/github.py b/webhook_handlers/views/github.py index 2e5bc34b3e..0e5eaac4a0 100644 --- a/webhook_handlers/views/github.py +++ b/webhook_handlers/views/github.py @@ -3,7 +3,7 @@ import re from contextlib import suppress from hashlib import sha1, sha256 -from typing import Optional, Union +from typing import Optional from django.utils import timezone from django.utils.crypto import constant_time_compare @@ -20,7 +20,6 @@ Owner, ) from core.models import Branch, Commit, Pull, Repository -from services.archive import ArchiveService from services.billing import BillingService from services.redis_configuration import get_redis_connection from services.task import TaskService @@ -202,7 +201,7 @@ def repository(self, request, *args, **kwargs): return Response() def delete(self, request, *args, **kwargs): - ref_type = request.data.get("ref_type") + ref_type = request.data.get("ref_type", "") _incr_event(GitHubWebhookEvents.DELETE + "." 
+ ref_type) repo = self._get_repo(request) if ref_type != "branch": @@ -233,7 +232,7 @@ def public(self, request, *args, **kwargs): return Response() def push(self, request, *args, **kwargs): - ref_type = "branch" if request.data.get("ref")[5:10] == "heads" else "tag" + ref_type = "branch" if request.data.get("ref", "")[5:10] == "heads" else "tag" _incr_event(GitHubWebhookEvents.PUSH + "." + ref_type) repo = self._get_repo(request) if ref_type != "branch": @@ -438,9 +437,12 @@ def _handle_installation_repository_events(self, request, *args, **kwargs): owner, _ = Owner.objects.get_or_create( service=self.service_name, service_id=service_id, - username=username, - defaults={"createstamp": timezone.now()}, + defaults={ + "username": username, + "createstamp": timezone.now(), + }, ) + installation_id = request.data["installation"]["id"] ghapp_installation, _ = GithubAppInstallation.objects.get_or_create( @@ -479,8 +481,10 @@ def _handle_installation_events( owner, _ = Owner.objects.get_or_create( service=self.service_name, service_id=service_id, - username=username, - defaults={"createstamp": timezone.now()}, + defaults={ + "username": username, + "createstamp": timezone.now(), + }, ) installation_id = request.data["installation"]["id"] @@ -527,6 +531,19 @@ def _handle_installation_events( map(lambda obj: obj["id"], request.data.get("repositories", [])) ) ghapp_installation.repository_service_ids = repositories_service_ids + + if action in ["suspend", "unsuspend"]: + log.info( + "Request to suspend/unsuspend App", + extra=dict( + action=action, + is_currently_suspended=ghapp_installation.is_suspended, + ownerid=owner.ownerid, + installation_id=request.data["installation"]["id"], + ), + ) + ghapp_installation.is_suspended = action == "suspend" + ghapp_installation.save() # This flow is deprecated and should be removed once the @@ -537,18 +554,6 @@ def _handle_installation_events( owner.save() # Deprecated flow - END - # We need to understand if users are suspending / 
not-suspending apps - # and if this is related to RepositoryWithoutValidBot errors we see - if action in ["suspend", "unsuspend"]: - log.info( - "Request to suspend/unsuspend App", - extra=dict( - action=action, - ownerid=owner.ownerid, - installation_id=request.data["installation"]["id"], - ), - ) - log.info( "Triggering refresh task to sync repos", extra=dict(ownerid=owner.ownerid, github_webhook_event=self.event), @@ -729,7 +734,7 @@ def member(self, request, *args, **kwargs): def post(self, request, *args, **kwargs): self.event = self.request.META.get(GitHubHTTPHeaders.EVENT) - log.debug( + log.info( "GitHub Webhook Handler invoked", extra=dict( github_webhook_event=self.event, diff --git a/webhook_handlers/views/gitlab.py b/webhook_handlers/views/gitlab.py index 54135b2e52..a540437a56 100644 --- a/webhook_handlers/views/gitlab.py +++ b/webhook_handlers/views/gitlab.py @@ -3,14 +3,14 @@ from django.http import HttpRequest from django.shortcuts import get_object_or_404 from django.utils.crypto import constant_time_compare -from rest_framework import status from rest_framework.exceptions import PermissionDenied from rest_framework.permissions import AllowAny from rest_framework.response import Response from rest_framework.views import APIView from codecov_auth.models import Owner -from core.models import Branch, Commit, Pull, PullStates, Repository +from core.models import Commit, Pull, PullStates, Repository +from services.refresh import RefreshService from services.task import TaskService from utils.config import get_config from webhook_handlers.constants import ( @@ -27,22 +27,44 @@ class GitLabWebhookHandler(APIView): service_name = "gitlab" def post(self, request, *args, **kwargs): + """ + Helpful docs for working with GitLab webhooks + https://docs.gitlab.com/ee/user/project/integrations/webhooks.html#webhook-receiver-requirements + for those special system hooks: https://docs.gitlab.com/ee/administration/system_hooks.html#hooks-request-example + all the 
other hooks: https://docs.gitlab.com/ee/user/project/integrations/webhook_events.html + """ event = self.request.META.get(GitLabHTTPHeaders.EVENT) - repo = None log.info("GitLab webhook message received", extra=dict(event=event)) project_id = request.data.get("project_id") or request.data.get( "object_attributes", {} ).get("target_project_id") - repo = None - if project_id and request.data.get("event_name") != "project_create": - # make sure the repo exists in the repos table - repo = get_object_or_404( - Repository, author__service=self.service_name, service_id=project_id - ) - if repo is not None and repo.webhook_secret is not None: + event_name = self.request.data.get( + "event_name", self.request.data.get("object_kind") + ) + + is_enterprise = True if get_config("setup", "enterprise_license") else False + + # special case - only event that doesn't have a repo yet + if event_name == "project_create": + if event == GitLabWebhookEvents.SYSTEM and is_enterprise: + return self._handle_system_project_create_hook_event() + else: + raise PermissionDenied() + + # all other events should correspond to a repo in the db + repo = get_object_or_404( + Repository, author__service=self.service_name, service_id=project_id + ) + + webhook_validation = bool( + get_config( + self.service_name, "webhook_validation", default=False + ) # TODO: backfill migration then switch to True + ) + if webhook_validation or repo.webhook_secret: self._validate_secret(request, repo.webhook_secret) if event == GitLabWebhookEvents.PUSH: @@ -52,7 +74,10 @@ def post(self, request, *args, **kwargs): elif event == GitLabWebhookEvents.MERGE_REQUEST: return self._handle_merge_request_event(repo) elif event == GitLabWebhookEvents.SYSTEM: - return self._handle_system_hook_event(repo) + # SYSTEM events have always been gated behind is_enterprise, requires an enterprise_license + if not is_enterprise: + raise PermissionDenied() + return self._handle_system_hook_event(repo, event_name) return Response() @@ 
-129,39 +154,70 @@ def _handle_merge_request_event(self, repo): return Response(data=message) - def _handle_system_hook_event(self, repo): + def _initiate_sync_for_owner(self, owner): """ - GitLab Enterprise instance can send system hooks for changes on user, group, project, etc - - http://doc.gitlab.com/ee/system_hooks/system_hooks.html + default: will sync_teams and sync_repos for owner + sync_teams to update owner.organizations list (expired memberships are removed and new memberships are added), + and username, name, email, and avatar of each Org in owner.organizations. + sync_repos to update owner.permission list (private repo access), + and name, language, private, repoid, and deleted=False for each repo the owner has access to. """ - if not get_config("setup", "enterprise_license"): - raise PermissionDenied("No enterprise license detected") - - event_name = self.request.data.get("event_name") - message = None + RefreshService().trigger_refresh( + ownerid=owner.ownerid, + username=owner.username, + using_integration=False, + manual_trigger=False, + ) - if event_name == "project_create": - owner_username, repo_name = self.request.data.get( - "path_with_namespace" - ).split("/", 2) + def _try_initiate_sync_for_owner(self): + owner_email = self.request.data.get("owner_email") + # email is a strong identifier (GL users must have a unique email) + try: + owner = Owner.objects.get( + service=self.service_name, + oauth_token__isnull=False, + email=owner_email, + ) + except (Owner.DoesNotExist, Owner.MultipleObjectsReturned): + # could be the username of the OwnerUser or OwnerOrg. Sync only works with an OwnerUser. 
+ owner_username_best_guess = self.request.data.get( + "path_with_namespace", "" + ).split("/")[0] try: owner = Owner.objects.get( - service=self.service_name, username=owner_username + service=self.service_name, + oauth_token__isnull=False, + username=owner_username_best_guess, ) + except (Owner.DoesNotExist, Owner.MultipleObjectsReturned): + return - obj, created = Repository.objects.get_or_create( - author=owner, - service_id=self.request.data.get("project_id"), - name=repo_name, - private=self.request.data.get("project_visibility") == "private", - ) - message = "Repository created" - except Owner.DoesNotExist: - message = "Repository not created - unknown owner" + self._initiate_sync_for_owner(owner) + + def _handle_system_project_create_hook_event(self): + self._try_initiate_sync_for_owner() + return Response(data="Sync initiated") + + def _try_initiate_sync_for_repo(self, repo): + # most GL repos have bots - try to sync with bot as Owner + if repo.bot: + bot_owner = Owner.objects.filter( + service=self.service_name, + ownerid=repo.bot.ownerid, + oauth_token__isnull=False, + ).first() + if bot_owner: + return self._initiate_sync_for_owner(owner=bot_owner) + self._try_initiate_sync_for_owner() - elif event_name == "project_destroy": + def _handle_system_hook_event(self, repo, event_name): + """ + GitLab Enterprise instance can send system hooks for changes on user, group, project, etc + """ + message = None + + if event_name == "project_destroy": repo.deleted = True repo.activated = False repo.active = False @@ -169,72 +225,33 @@ def _handle_system_hook_event(self, repo): repo.save(update_fields=["deleted", "activated", "active", "name"]) message = "Repository deleted" - elif event_name == "project_rename": - new_name = self.request.data.get("path_with_namespace").split("/")[-1] - repo.name = new_name - repo.save(update_fields=["name"]) - message = "Repository renamed" - - elif event_name == "project_transfer": - owner_username, repo_name = 
self.request.data.get( - "path_with_namespace" - ).split("/") - new_owner = Owner.objects.filter( - service=self.service_name, username=owner_username - ).first() - - if new_owner: - repo.author = new_owner - repo.name = repo_name - repo.save(update_fields=["author", "name"]) - message = "Repository transfered" - - elif event_name == "user_create": - obj, created = Owner.objects.update_or_create( - service=self.service_name, - service_id=self.request.data.get("user_id"), - username=self.request.data.get("username"), - email=self.request.data.get("email"), - name=self.request.data.get("name"), - ) - message = "User created" + elif event_name in ("project_rename", "project_transfer"): + self._try_initiate_sync_for_repo(repo=repo) + message = "Sync initiated" elif ( event_name in ("user_add_to_team", "user_remove_from_team") and self.request.data.get("project_visibility") == "private" ): + # the payload from these hooks includes the ownerid + ownerid = self.request.data.get("user_id") user = Owner.objects.filter( service=self.service_name, - service_id=self.request.data.get("user_id"), + service_id=ownerid, oauth_token__isnull=False, ).first() - + message = "Sync initiated" if user: - if event_name == "user_add_to_team": - user.permission = list( - set((user.permission or []) + [int(repo.repoid)]) - ) - user.save(update_fields=["permission"]) - message = "Permission added" - else: - new_permissions = set((user.permission or [])) - new_permissions.remove(int(repo.repoid)) - user.permission = list(new_permissions) - user.save(update_fields=["permission"]) - message = "Permission removed" - else: - message = "User not found or not active" + self._initiate_sync_for_owner(owner=user) return Response(data=message) def _validate_secret(self, request: HttpRequest, webhook_secret: str): - webhook_validation = bool( - get_config(self.service_name, "webhook_validation", default=False) - ) - if webhook_validation: - token = request.META.get(GitLabHTTPHeaders.TOKEN) - if not 
constant_time_compare(webhook_secret, token): - raise PermissionDenied() + token = request.META.get(GitLabHTTPHeaders.TOKEN) + if token and webhook_secret: + if constant_time_compare(webhook_secret, token): + return + raise PermissionDenied() class GitLabEnterpriseWebhookHandler(GitLabWebhookHandler):