Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
felixrindt committed Nov 25, 2024
1 parent 45ee25b commit 91fb9c0
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# Generated by Django 5.0.8 on 2024-11-22 09:25

from django.db import migrations, models
from django.db.models import Value
from django.db.models.functions import Concat, Random, Round


def set_idp_internal_names(apps, schema_editor):
IdentityProvider = apps.get_model("core", "IdentityProvider")
for idp in IdentityProvider.objects.all():
idp.internal_name = idp.label
idp.save()


class Migration(migrations.Migration):
Expand All @@ -16,10 +21,9 @@ class Migration(migrations.Migration):
model_name="identityprovider",
name="internal_name",
field=models.CharField(
db_default=Concat(Value("Identity Provider "), (Round(Random() * Value(10000)))),
help_text="Internal name for this provider.",
max_length=255,
unique=True,
unique=False,
verbose_name="internal name",
),
),
Expand All @@ -38,6 +42,7 @@ class Migration(migrations.Migration):
name="qualification_codename_to_uuid",
field=models.JSONField(
default=dict,
blank=True,
help_text="A json encoded dictionary containing mappings of qualification names as they appear in thequalification claim to the qualification uuid. If a key is not found, use the key directly.",
verbose_name="qualification codename to uuid",
),
Expand All @@ -54,4 +59,5 @@ class Migration(migrations.Migration):
),
preserve_default=False,
),
migrations.RunPython(set_idp_internal_names),
]
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Generated by Django 5.0.8 on 2024-11-23 12:48
# Generated by Django 5.0.9 on 2024-11-24 21:52

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("auth", "0012_alter_user_first_name_max_length"),
("core", "0030_identityprovider_internal_name_and_more"),
]

Expand All @@ -20,16 +19,6 @@ class Migration(migrations.Migration):
verbose_name="externally managed",
),
),
migrations.AlterField(
model_name="identityprovider",
name="default_groups",
field=models.ManyToManyField(
blank=True,
help_text="The groups that users logging in with this provider will be added to. ",
to="auth.group",
verbose_name="default groups",
),
),
migrations.AlterField(
model_name="identityprovider",
name="internal_name",
Expand Down
1 change: 1 addition & 0 deletions ephios/core/models/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ class IdentityProvider(Model):
"qualification claim to the qualification uuid. If a key is not found, use the key directly."
),
default=dict,
blank=True,
)

def __str__(self):
Expand Down
78 changes: 39 additions & 39 deletions ephios/extra/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,50 +49,50 @@ def update_user(self, user, claims):
except ValueError:
pass
user.save()
self._update_user_groups(user, claims)
self._update_user_qualifications(user, claims)
oidc_update_user.send(self, user=user, claims=claims, provider=self.provider)
return user

if self.provider.group_claim:
groups = set(self.provider.default_groups.all())
groups_in_claims = dotted_get(claims, self.provider.group_claim, [])
for group_name in groups_in_claims:
try:
groups.add(Group.objects.get(name__iexact=group_name))
except Group.DoesNotExist:
if self.provider.create_missing_groups:
groups.add(Group.objects.create(name=group_name))
user.groups.set(groups)
else:
user.groups.add(*self.provider.default_groups.all())

if self.provider.qualification_claim:
target_qualification_uuids = []
for codename in dotted_get(claims, self.provider.qualification_claim, []):
try:
target_qualification_uuids.append(
uuid.UUID(
str(
self.provider.qualification_codename_to_uuid.get(codename, codename)
)
)
def _update_user_qualifications(self, user, claims):
if not self.provider.qualification_claim:
return
target_qualification_uuids = []
for codename in dotted_get(claims, self.provider.qualification_claim, []):
try:
target_qualification_uuids.append(
uuid.UUID(
str(self.provider.qualification_codename_to_uuid.get(codename, codename))
)
except ValueError:
pass
)
except ValueError:
pass

target_qualifications = Qualification.objects.filter(
uuid__in=target_qualification_uuids
)
QualificationGrant.objects.filter(
target_qualifications = Qualification.objects.filter(uuid__in=target_qualification_uuids)
QualificationGrant.objects.filter(
user=user,
externally_managed=True,
).exclude(qualification__in=target_qualifications).delete()
for qualification in target_qualifications:
QualificationGrant.objects.get_or_create(
defaults={"expires": None, "externally_managed": True},
user=user,
externally_managed=True,
).exclude(qualification__in=target_qualifications).delete()
for qualification in target_qualifications:
QualificationGrant.objects.get_or_create(
defaults={"expires": None, "externally_managed": True},
user=user,
qualification=qualification,
)
qualification=qualification,
)

oidc_update_user.send(self, user=user, claims=claims, provider=self.provider)
return user
def _update_user_groups(self, user, claims):
if not self.provider.group_claim:
user.groups.add(*self.provider.default_groups.all())
return
groups = set(self.provider.default_groups.all())
groups_in_claims = dotted_get(claims, self.provider.group_claim, [])
for group_name in groups_in_claims:
try:
groups.add(Group.objects.get(name__iexact=group_name))
except Group.DoesNotExist:
if self.provider.create_missing_groups:
groups.add(Group.objects.create(name=group_name))
user.groups.set(groups)

def authenticate(self, request, username=None, password=None, **kwargs):
try:
Expand Down

0 comments on commit 91fb9c0

Please sign in to comment.