Skip to content

Commit

Permalink
chore(integrations): oauth2 identity pipeline metrics round 2 (#80216)
Browse files Browse the repository at this point in the history
  • Loading branch information
cathteng authored Nov 11, 2024
1 parent c3b32a7 commit 482dbd6
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 72 deletions.
167 changes: 104 additions & 63 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
from requests.exceptions import SSLError

from sentry.auth.exceptions import IdentityNotValid
from sentry.exceptions import NotRegistered
from sentry.http import safe_urlopen, safe_urlread
from sentry.integrations.base import IntegrationDomain
from sentry.integrations.utils.metrics import (
IntegrationPipelineViewEvent,
IntegrationPipelineViewType,
)
from sentry.pipeline import PipelineView
from sentry.shared_integrations.exceptions import ApiError
from sentry.utils.http import absolute_uri
Expand All @@ -23,6 +29,7 @@

logger = logging.getLogger(__name__)
ERR_INVALID_STATE = "An error occurred while validating your request."
ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service."


class OAuth2Provider(Provider):
Expand Down Expand Up @@ -207,6 +214,19 @@ def refresh_identity(self, identity, *args, **kwargs):
from rest_framework.request import Request


def record_event(event: IntegrationPipelineViewType, provider: str):
from sentry.identity import default_manager as identity_manager

try:
identity_manager.get(provider)
except NotRegistered:
logger.exception("oauth2.record_event.invalid_provider", extra={"provider": provider})

return IntegrationPipelineViewEvent(
event, domain=IntegrationDomain.IDENTITY, provider_key=provider
)


class OAuth2LoginView(PipelineView):
authorize_url = None
client_id = None
Expand Down Expand Up @@ -238,22 +258,23 @@ def get_authorize_params(self, state, redirect_uri):

@method_decorator(csrf_exempt)
def dispatch(self, request: Request, pipeline) -> HttpResponse:
for param in ("code", "error", "state"):
if param in request.GET:
return pipeline.next_step()
with record_event(IntegrationPipelineViewType.OAUTH_LOGIN, pipeline.provider.key).capture():
for param in ("code", "error", "state"):
if param in request.GET:
return pipeline.next_step()

state = secrets.token_hex()
state = secrets.token_hex()

params = self.get_authorize_params(
state=state, redirect_uri=absolute_uri(pipeline.redirect_url())
)
redirect_uri = f"{self.get_authorize_url()}?{urlencode(params)}"
params = self.get_authorize_params(
state=state, redirect_uri=absolute_uri(pipeline.redirect_url())
)
redirect_uri = f"{self.get_authorize_url()}?{urlencode(params)}"

pipeline.bind_state("state", state)
if request.subdomain:
pipeline.bind_state("subdomain", request.subdomain)
pipeline.bind_state("state", state)
if request.subdomain:
pipeline.bind_state("subdomain", request.subdomain)

return self.redirect(redirect_uri)
return self.redirect(redirect_uri)


class OAuth2CallbackView(PipelineView):
Expand All @@ -280,70 +301,90 @@ def get_token_params(self, code, redirect_uri):
}

def exchange_token(self, request: Request, pipeline, code):
# TODO: this needs the auth yet
data = self.get_token_params(code=code, redirect_uri=absolute_uri(pipeline.redirect_url()))
verify_ssl = pipeline.config.get("verify_ssl", True)
try:
req = safe_urlopen(self.access_token_url, data=data, verify_ssl=verify_ssl)
body = safe_urlread(req)
if req.headers.get("Content-Type", "").startswith("application/x-www-form-urlencoded"):
return dict(parse_qsl(body))
return orjson.loads(body)
except SSLError:
logger.info(
"identity.oauth2.ssl-error",
extra={"url": self.access_token_url, "verify_ssl": verify_ssl},
with record_event(
IntegrationPipelineViewType.TOKEN_EXCHANGE, pipeline.provider.key
).capture() as lifecycle:
# TODO: this needs the auth yet
data = self.get_token_params(
code=code, redirect_uri=absolute_uri(pipeline.redirect_url())
)
url = self.access_token_url
return {
"error": "Could not verify SSL certificate",
"error_description": f"Ensure that {url} has a valid SSL certificate",
}
except ConnectionError:
url = self.access_token_url
logger.info("identity.oauth2.connection-error", extra={"url": url})
return {
"error": "Could not connect to host or service",
"error_description": f"Ensure that {url} is open to connections",
}
except orjson.JSONDecodeError:
logger.info("identity.oauth2.json-error", extra={"url": self.access_token_url})
return {
"error": "Could not decode a JSON Response",
"error_description": "We were not able to parse a JSON response, please try again.",
}
verify_ssl = pipeline.config.get("verify_ssl", True)
try:
req = safe_urlopen(self.access_token_url, data=data, verify_ssl=verify_ssl)
body = safe_urlread(req)
if req.headers.get("Content-Type", "").startswith(
"application/x-www-form-urlencoded"
):
return dict(parse_qsl(body))
return orjson.loads(body)
except SSLError:
logger.info(
"identity.oauth2.ssl-error",
extra={"url": self.access_token_url, "verify_ssl": verify_ssl},
)
lifecycle.record_failure({"failure_reason": "ssl_error"})
url = self.access_token_url
return {
"error": "Could not verify SSL certificate",
"error_description": f"Ensure that {url} has a valid SSL certificate",
}
except ConnectionError:
url = self.access_token_url
logger.info("identity.oauth2.connection-error", extra={"url": url})
lifecycle.record_failure({"failure_reason": "connection_error"})
return {
"error": "Could not connect to host or service",
"error_description": f"Ensure that {url} is open to connections",
}
except orjson.JSONDecodeError:
logger.info("identity.oauth2.json-error", extra={"url": self.access_token_url})
lifecycle.record_failure({"failure_reason": "json_error"})
return {
"error": "Could not decode a JSON Response",
"error_description": "We were not able to parse a JSON response, please try again.",
}

def dispatch(self, request: Request, pipeline) -> HttpResponse:
error = request.GET.get("error")
state = request.GET.get("state")
code = request.GET.get("code")

if error:
pipeline.logger.info("identity.token-exchange-error", extra={"error": error})
return pipeline.error(ERR_INVALID_STATE)
with record_event(
IntegrationPipelineViewType.OAUTH_CALLBACK, pipeline.provider.key
).capture() as lifecycle:
error = request.GET.get("error")
state = request.GET.get("state")
code = request.GET.get("code")

if error:
pipeline.logger.info("identity.token-exchange-error", extra={"error": error})
lifecycle.record_failure(
{"failure_reason": "token_exchange_error", "msg": ERR_INVALID_STATE}
)
return pipeline.error(ERR_INVALID_STATE)

if state != pipeline.fetch_state("state"):
pipeline.logger.info(
"identity.token-exchange-error",
extra={
"error": "invalid_state",
"state": state,
"pipeline_state": pipeline.fetch_state("state"),
"code": code,
},
)
return pipeline.error(ERR_INVALID_STATE)
if state != pipeline.fetch_state("state"):
pipeline.logger.info(
"identity.token-exchange-error",
extra={
"error": "invalid_state",
"state": state,
"pipeline_state": pipeline.fetch_state("state"),
"code": code,
},
)
lifecycle.record_failure(
{"failure_reason": "token_exchange_error", "msg": ERR_INVALID_STATE}
)
return pipeline.error(ERR_INVALID_STATE)

# separate lifecycle event inside exchange_token
data = self.exchange_token(request, pipeline, code)

# these errors are based off of the results of exchange_token, lifecycle errors are captured inside
if "error_description" in data:
error = data.get("error")
pipeline.logger.info("identity.token-exchange-error", extra={"error": error})
return pipeline.error(data["error_description"])

if "error" in data:
pipeline.logger.info("identity.token-exchange-error", extra={"error": data["error"]})
return pipeline.error("Failed to retrieve token from the upstream service.")
return pipeline.error(ERR_TOKEN_RETRIEVAL)

# we can either expect the API to be implicit and say "im looking for
# blah within state data" or we need to pass implementation + call a
Expand Down
10 changes: 6 additions & 4 deletions src/sentry/integrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class IntegrationProviderSlug(StrEnum):
GITHUB_ENTERPRISE = "github_enterprise"
GITLAB = "gitlab"
BITBUCKET = "bitbucket"
BITBUCKET_SERVER = "bitbucket_server"
PAGERDUTY = "pagerduty"
OPSGENIE = "opsgenie"

Expand All @@ -159,16 +160,13 @@ class IntegrationProviderSlug(StrEnum):
IntegrationDomain.PROJECT_MANAGEMENT: [
IntegrationProviderSlug.JIRA,
IntegrationProviderSlug.JIRA_SERVER,
IntegrationProviderSlug.GITHUB,
IntegrationProviderSlug.GITHUB_ENTERPRISE,
IntegrationProviderSlug.GITLAB,
IntegrationProviderSlug.AZURE_DEVOPS,
],
IntegrationDomain.SOURCE_CODE_MANAGEMENT: [
IntegrationProviderSlug.GITHUB,
IntegrationProviderSlug.GITHUB_ENTERPRISE,
IntegrationProviderSlug.GITLAB,
IntegrationProviderSlug.BITBUCKET,
IntegrationProviderSlug.BITBUCKET_SERVER,
IntegrationProviderSlug.AZURE_DEVOPS,
],
IntegrationDomain.ON_CALL_SCHEDULING: [
Expand All @@ -177,6 +175,10 @@ class IntegrationProviderSlug(StrEnum):
],
}

INTEGRATION_PROVIDER_TO_TYPE = {
v: k for k, values in INTEGRATION_TYPE_TO_PROVIDER.items() for v in values
}


class IntegrationProvider(PipelineProvider, abc.ABC):
"""
Expand Down
1 change: 1 addition & 0 deletions src/sentry/integrations/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ class IntegrationPipelineViewType(Enum):
# IdentityProviderPipeline
IDENTITY_LOGIN = "IDENTITY_LOGIN"
IDENTITY_LINK = "IDENTITY_LINK"
TOKEN_EXCHANGE = "TOKEN_EXCHANGE"

# GitHub
OAUTH_LOGIN = "OAUTH_LOGIN"
Expand Down
40 changes: 35 additions & 5 deletions tests/sentry/identity/test_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import namedtuple
from functools import cached_property
from unittest.mock import patch
from urllib.parse import parse_qs, parse_qsl, urlparse

import responses
Expand All @@ -10,13 +11,16 @@
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
from sentry.identity.pipeline import IdentityProviderPipeline
from sentry.identity.providers.dummy import DummyProvider
from sentry.integrations.utils.metrics import EventLifecycleOutcome
from sentry.testutils.cases import TestCase
from sentry.testutils.silo import control_silo_test

MockResponse = namedtuple("MockResponse", ["headers", "content"])


@control_silo_test
@patch("sentry.integrations.base.INTEGRATION_PROVIDER_TO_TYPE", return_value={"dummy": "dummy"})
@patch("sentry.integrations.utils.metrics.EventLifecycle.record_event")
class OAuth2CallbackViewTest(TestCase):
def setUp(self):
sentry.identity.register(DummyProvider)
Expand All @@ -36,8 +40,18 @@ def view(self):
client_secret="secret-value",
)

def assert_failure_metric(self, mock_record, error_msg):
(event_failures,) = (
call for call in mock_record.mock_calls if call.args[0] == EventLifecycleOutcome.FAILURE
)
assert event_failures.args[1]["failure_reason"] == error_msg

@responses.activate
def test_exchange_token_success(self):
def test_exchange_token_success(
self,
mock_record,
mock_integration_const,
):
responses.add(
responses.POST, "https://example.org/oauth/token", json={"token": "a-fake-token"}
)
Expand All @@ -59,8 +73,13 @@ def test_exchange_token_success(self):
"redirect_uri": "http://testserver/extensions/default/setup/",
}

assert len(mock_record.mock_calls) == 2
start, success = mock_record.mock_calls
assert start.args[0] == EventLifecycleOutcome.STARTED
assert success.args[0] == EventLifecycleOutcome.SUCCESS

@responses.activate
def test_exchange_token_success_customer_domains(self):
def test_exchange_token_success_customer_domains(self, mock_record, mock_integration_const):
responses.add(
responses.POST, "https://example.org/oauth/token", json={"token": "a-fake-token"}
)
Expand All @@ -82,8 +101,13 @@ def test_exchange_token_success_customer_domains(self):
"redirect_uri": "http://testserver/extensions/default/setup/",
}

assert len(mock_record.mock_calls) == 2
start, success = mock_record.mock_calls
assert start.args[0] == EventLifecycleOutcome.STARTED
assert success.args[0] == EventLifecycleOutcome.SUCCESS

@responses.activate
def test_exchange_token_ssl_error(self):
def test_exchange_token_ssl_error(self, mock_record, mock_integration_const):
def ssl_error(request):
raise SSLError("Could not build connection")

Expand All @@ -98,8 +122,10 @@ def ssl_error(request):
assert "error_description" in result
assert "SSL" in result["error_description"]

self.assert_failure_metric(mock_record, "ssl_error")

@responses.activate
def test_connection_error(self):
def test_connection_error(self, mock_record, mock_integration_const):
def connection_error(request):
raise ConnectionError("Name or service not known")

Expand All @@ -114,8 +140,10 @@ def connection_error(request):
assert "connect" in result["error"]
assert "error_description" in result

self.assert_failure_metric(mock_record, "connection_error")

@responses.activate
def test_exchange_token_no_json(self):
def test_exchange_token_no_json(self, mock_record, mock_integration_const):
responses.add(responses.POST, "https://example.org/oauth/token", body="")
pipeline = IdentityProviderPipeline(request=self.request, provider_key="dummy")
code = "auth-code"
Expand All @@ -125,6 +153,8 @@ def test_exchange_token_no_json(self):
assert "error_description" in result
assert "JSON" in result["error_description"]

self.assert_failure_metric(mock_record, "json_error")


@control_silo_test
class OAuth2LoginViewTest(TestCase):
Expand Down

0 comments on commit 482dbd6

Please sign in to comment.