diff --git a/server/utils/email.py b/server/utils/email.py index a88bc2b..5b3353b 100644 --- a/server/utils/email.py +++ b/server/utils/email.py @@ -18,11 +18,17 @@ class Domains: aliases: tuple[str, ...] +def subdomain_of(domain: str, parent: str) -> bool: + """Check if the `domain` is a subdomain of `parent`.""" + return domain == parent or domain.endswith(f".{parent}") + + def normalize_email( address: str, tag: str | None = "+", dots: bool = True, domains: Domains | None = None, + allow_subdomains: bool = True, ) -> str: """ Normalize an email address. @@ -32,6 +38,7 @@ def normalize_email( - If provided, remove the `tag` character (+) and everything after it - If requested, remove dots (.) from the address - If provided, replace the domain with the primary domain if it is an alias + - If requested, remove subdomains from the domain You must have previously validated the email address. @@ -45,8 +52,13 @@ def normalize_email( local = local.split(tag, maxsplit=1)[0] if dots: local = local.replace(".", "") - if domains and domain in domains.aliases: - domain = domains.primary + if domains: + if allow_subdomains and any( + subdomain_of(domain, d) for d in [domains.primary, *domains.aliases] + ): + domain = domains.primary + elif not allow_subdomains and domain in domains.aliases: + domain = domains.primary # FORCE ascii for now (yes, this is absurd). local = local.encode("ascii", "ignore").decode("ascii") domain = domain.encode("ascii", "ignore").decode("ascii") diff --git a/server/utils/test_email.py b/server/utils/test_email.py index d7dbe6e..6187264 100644 --- a/server/utils/test_email.py +++ b/server/utils/test_email.py @@ -77,3 +77,43 @@ def test_force_ascii(self): expected = "tst@xample.com" result = e.normalize_email(email) self.assertEqual(result, expected) + + def test_allow_subdomains(self): + """Test an email address with a subdomain.""" + email = "test@subdomain.example.com" + expected = "test@example.com" + domains = e.Domains("example.com", ()) + result = e.normalize_email(email, domains=domains, allow_subdomains=True) + self.assertEqual(result, expected) + + def test_allow_submdomains_invalid(self): + """Test an email address with a non-subdomain.""" + email = "test@subdomainexample.com" + expected = "test@subdomainexample.com" + domains = e.Domains("example.com", ()) + result = e.normalize_email(email, domains=domains, allow_subdomains=True) + self.assertEqual(result, expected) + + def test_allow_subdomains_aliases(self): + """Test an email address with domain aliases and subdomains.""" + email = "test@party.example.edu" + expected = "test@example.com" + domains = e.Domains("example.com", ("example.edu",)) + result = e.normalize_email(email, domains=domains, allow_subdomains=True) + self.assertEqual(result, expected) + + def test_disallow_subdomains(self): + """Test an email address with a subdomain.""" + email = "test@subdomain.example.com" + expected = "test@subdomain.example.com" + domains = e.Domains("example.com", ()) + result = e.normalize_email(email, domains=domains, allow_subdomains=False) + self.assertEqual(result, expected) + + def test_disallow_subdomains_aliases(self): + """Test an email address with domain aliases and subdomains.""" + email = "test@party.example.edu" + expected = "test@party.example.edu" + domains = e.Domains("example.com", ("example.edu",)) + result = e.normalize_email(email, domains=domains, allow_subdomains=False) + self.assertEqual(result, expected) diff --git a/server/vb/migrations/0013_add_subdomains_flag.py b/server/vb/migrations/0013_add_subdomains_flag.py new file mode 100644 index 0000000..957097e --- /dev/null +++ b/server/vb/migrations/0013_add_subdomains_flag.py @@ -0,0 +1,18 @@ +# Generated by Django 5.0.3 on 2024-05-17 17:20 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('vb', '0012_add_percent_voted'), + ] + + operations = [ + migrations.AddField( + model_name='school', + name='allow_subdomains', + field=models.BooleanField(default=True, help_text='Whether to allow arbitrary subdomains in school emails.'), + ), + ] diff --git a/server/vb/models.py b/server/vb/models.py index 2e871ae..81577f5 100644 --- a/server/vb/models.py +++ b/server/vb/models.py @@ -51,6 +51,10 @@ class School(models.Model): default=True, help_text="Whether to remove dots from the local part of school emails.", ) # noqa + allow_subdomains = models.BooleanField( + default=True, + help_text="Whether to allow arbitrary subdomains in school emails.", + ) logo: "Logo" contests: "ContestManager" @@ -64,6 +68,7 @@ def normalize_email(self, address: str) -> str: tag=self.mail_tag if self.mail_tag else None, dots=self.mail_dots, domains=domains, + allow_subdomains=self.allow_subdomains, ) def hash_email(self, address: str) -> str: