Skip to content

Commit

Permalink
[change] Implemented-suggested-changes openwisp#523
Browse files Browse the repository at this point in the history
-Added Validation for email
-If validation is failed we try to get email from attributes
-Added tests to see if Exception is raised when invalid mail is provided

Fixes openwisp#523
  • Loading branch information
kaushikaryan04 committed Aug 8, 2024
1 parent 5d9faeb commit 124db52
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 10 deletions.
13 changes: 13 additions & 0 deletions openwisp_radius/saml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,16 @@ def get_url_or_path(url):
if parsed_url.netloc:
return f'{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}'
return parsed_url.path


def get_email_from_ava(ava):
email_keys = (
'email',
'mail',
'uid',
)
for key in email_keys:
email = ava.get(key, None)
if email is not None:
return email[0]
return None
32 changes: 24 additions & 8 deletions openwisp_radius/saml/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import swapper
from allauth.account.models import EmailAddress
from allauth.utils import ValidationError
from django.conf import settings
from django.contrib.auth import logout
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
Expand All @@ -17,7 +18,7 @@
from .. import settings as app_settings
from ..api.views import RadiusTokenMixin
from ..utils import get_organization_radius_settings, load_model
from .utils import get_url_or_path
from .utils import get_email_from_ava, get_url_or_path

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -74,15 +75,30 @@ def post_login_hook(self, request, user, session_info):
if uid_is_email:
email = session_info['name_id'].text
if email is None:
email = session_info['ava'].get('email', [None])[0]
email = get_email_from_ava(session_info['ava'])
if email:
user.email = email
user.save()
email_address = EmailAddress.objects.create(
user=user, email=email, verified=True, primary=True
)
email_address.save()

try:
user.full_clean()
user.save()
EmailAddress.objects.create(
user=user, email=email, verified=True, primary=True
)
except ValidationError:
assertion_email = get_email_from_ava(session_info['ava'])
if assertion_email and assertion_email != email:
user.email = assertion_email
try:
user.full_clean()
user.save()
EmailAddress.objects.create(
user=user,
email=assertion_email,
verified=True,
primary=True,
)
except ValidationError:
raise ValidationError('Email Verification Failed')
registered_user = RegisteredUser(
user=user, method='saml', is_verified=app_settings.SAML_IS_VERIFIED
)
Expand Down
23 changes: 21 additions & 2 deletions openwisp_radius/tests/test_saml/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import swapper
from allauth.account.models import EmailAddress
from django.contrib.auth import SESSION_KEY, get_user_model
from django.core.validators import ValidationError
from django.test import TestCase, override_settings
from django.urls import reverse
from djangosaml2.tests import auth_response, conf
Expand Down Expand Up @@ -59,12 +60,12 @@ class TestAssertionConsumerServiceView(TestSamlMixin, TestCase):
def _get_relay_state(self, redirect_url, org_slug):
return f'{redirect_url}?org={org_slug}'

def _get_saml_response_for_acs_view(self, relay_state):
def _get_saml_response_for_acs_view(self, relay_state, uid='org_user@example.com'):
response = self.client.get(self.login_url, {'RelayState': relay_state})
saml2_req = saml2_from_httpredirect_request(response.url)
session_id = get_session_id_from_saml2(saml2_req)
self.add_outstanding_query(session_id, relay_state)
return auth_response(session_id, 'org_user@example.com'), relay_state
return auth_response(session_id, uid), relay_state

def _post_successful_auth_assertions(self, query_params, org_slug):
self.assertEqual(User.objects.count(), 1)
Expand Down Expand Up @@ -103,6 +104,24 @@ def test_organization_slug_present(self):
query_params = parse_qs(urlparse(response.url).query)
self._post_successful_auth_assertions(query_params, org_slug)

@capture_any_output()
def test_invalid_email_raise_validation_error(self):
invalid_email = 'invalid_email@example'
relay_state = self._get_relay_state(
redirect_url='https://captive-portal.example.com', org_slug='default'
)
saml_response, relay_state = self._get_saml_response_for_acs_view(
relay_state, uid=invalid_email
)
with self.assertRaises(ValidationError):
self.client.post(
reverse('radius:saml2_acs'),
{
'SAMLResponse': self.b64_for_post(saml_response),
'RelayState': relay_state,
},
)

@capture_any_output()
def test_relay_state_relative_path(self):
expected_redirect_path = '/captive/portal/page'
Expand Down

0 comments on commit 124db52

Please sign in to comment.