Skip to content

Commit

Permalink
refactor tests and move update_institiutions to it's own file
Browse files Browse the repository at this point in the history
  • Loading branch information
John Tordoff committed Jul 8, 2024
1 parent 268d5c6 commit 164aa66
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 46 deletions.
2 changes: 1 addition & 1 deletion api/draft_registrations/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from api.nodes.serializers import (
DraftRegistrationLegacySerializer,
DraftRegistrationDetailLegacySerializer,
update_institutions,
get_license_details,
NodeSerializer,
NodeLicenseSerializer,
Expand All @@ -18,6 +17,7 @@
NodeContributorDetailSerializer,
RegistrationSchemaRelationshipField,
)
from api.institutions.utils import update_institutions
from api.taxonomies.serializers import TaxonomizableSerializerMixin
from osf.exceptions import DraftRegistrationStateError
from osf.models import Node
Expand Down
38 changes: 38 additions & 0 deletions api/institutions/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from rest_framework import exceptions
from osf.models import Institution
from osf.utils import permissions as osf_permissions
from api.base.serializers import relationship_diff


def get_institutions_to_add_remove(institutions, new_institutions):
diff = relationship_diff(
current_items={inst._id: inst for inst in institutions.all()},
new_items={inst['_id']: inst for inst in new_institutions},
)

insts_to_add = []
for inst_id in diff['add']:
inst = Institution.load(inst_id)
if not inst:
raise exceptions.NotFound(detail=f'Institution with id "{inst_id}" was not found')
insts_to_add.append(inst)

return insts_to_add, diff['remove'].values()


def update_institutions(node, new_institutions, user, post=False):
add, remove = get_institutions_to_add_remove(
institutions=node.affiliated_institutions,
new_institutions=new_institutions,
)

if not post:
for inst in remove:
if not user.is_affiliated_with_institution(inst) and not node.has_permission(user, osf_permissions.ADMIN):
raise exceptions.PermissionDenied(detail=f'User needs to be affiliated with {inst.name}')
node.remove_affiliated_institution(inst, user)

for inst in add:
if not user.is_affiliated_with_institution(inst):
raise exceptions.PermissionDenied(detail=f'User needs to be affiliated with {inst.name}',)
node.add_affiliated_institution(inst, user)
43 changes: 3 additions & 40 deletions api/nodes/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
JSONAPISerializer, LinksField,
NodeFileHyperLinkField, RelationshipField,
ShowIfVersion, TargetTypeField, TypeField,
WaterbutlerLink, relationship_diff, BaseAPISerializer,
WaterbutlerLink, BaseAPISerializer,
HideIfWikiDisabled, ShowIfAdminScopeOrAnonymous,
ValuesListField, TargetField,
)
Expand All @@ -20,6 +20,7 @@
get_user_auth, is_truthy,
)
from api.base.versioning import get_kebab_snake_case_field
from api.institutions.utils import update_institutions
from api.taxonomies.serializers import TaxonomizableSerializerMixin
from django.apps import apps
from django.conf import settings
Expand All @@ -33,7 +34,7 @@
from addons.osfstorage.models import Region
from osf.exceptions import NodeStateError
from osf.models import (
Comment, DraftRegistration, ExternalAccount, Institution,
Comment, DraftRegistration, ExternalAccount,
RegistrationSchema, AbstractNode, PrivateLink, Preprint,
RegistrationProvider, OSFGroup, NodeLicense, DraftNode,
Registration, Node,
Expand All @@ -51,44 +52,6 @@ def to_internal_value(self, data):
return self.get_object(data)


def get_institutions_to_add_remove(institutions, new_institutions):
diff = relationship_diff(
current_items={inst._id: inst for inst in institutions.all()},
new_items={inst['_id']: inst for inst in new_institutions},
)

insts_to_add = []
for inst_id in diff['add']:
inst = Institution.load(inst_id)
if not inst:
raise exceptions.NotFound(detail='Institution with id "{}" was not found'.format(inst_id))
insts_to_add.append(inst)

return insts_to_add, diff['remove'].values()


def update_institutions(node, new_institutions, user, post=False):
add, remove = get_institutions_to_add_remove(
institutions=node.affiliated_institutions,
new_institutions=new_institutions,
)

if not post:
for inst in remove:
if not user.is_affiliated_with_institution(inst) and not node.has_permission(user, osf_permissions.ADMIN):
raise exceptions.PermissionDenied(
detail='User needs to be affiliated with {}'.format(inst.name),
)
node.remove_affiliated_institution(inst, user)

for inst in add:
if not user.is_affiliated_with_institution(inst):
raise exceptions.PermissionDenied(
detail='User needs to be affiliated with {}'.format(inst.name),
)
node.add_affiliated_institution(inst, user)


class RegionRelationshipField(RelationshipField):

def to_internal_value(self, data):
Expand Down
2 changes: 1 addition & 1 deletion api/preprints/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@
re_path(r'^(?P<preprint_id>\w+)/requests/$', views.PreprintRequestListCreate.as_view(), name=views.PreprintRequestListCreate.view_name),
re_path(r'^(?P<preprint_id>\w+)/subjects/$', views.PreprintSubjectsList.as_view(), name=views.PreprintSubjectsList.view_name),
re_path(r'^(?P<preprint_id>\w+)/institutions/$', views.PreprintInstitutionsList.as_view(), name=views.PreprintInstitutionsList.view_name),
re_path(r'^(?P<preprint_id>\w+)/relationships/institutions/$', views.PreprintInstitutionsRelationshipList.as_view(), name=views.PreprintInstitutionsRelationshipList.view_name),
re_path(r'^(?P<preprint_id>\w+)/relationships/institutions/$', views.PreprintInstitutionsRelationship.as_view(), name=views.PreprintInstitutionsRelationship.view_name),
]
2 changes: 1 addition & 1 deletion api/preprints/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def get_queryset(self):
return self.get_resource().affiliated_institutions.all()


class PreprintInstitutionsRelationshipList(JSONAPIBaseView, generics.RetrieveUpdateDestroyAPIView, generics.CreateAPIView, PreprintMixin):
class PreprintInstitutionsRelationship(JSONAPIBaseView, generics.RetrieveUpdateDestroyAPIView, generics.CreateAPIView, PreprintMixin):
""" """
permission_classes = (
drf_permissions.IsAuthenticatedOrReadOnly,
Expand Down
4 changes: 2 additions & 2 deletions api/registrations/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
NodeStorageProviderSerializer,
NodeLicenseRelationshipField,
NodeLinksSerializer,
update_institutions,
NodeLicenseSerializer,
NodeContributorsSerializer,
RegistrationProviderRelationshipField,
Expand All @@ -31,13 +30,14 @@
ShowIfVersion, VersionedDateTimeField, ValuesListField,
HideIfWithdrawalOrWikiDisabled,
)
from api.institutions.utils import update_institutions

from framework.auth.core import Auth
from osf.exceptions import NodeStateError
from osf.models import Node
from osf.utils.registrations import strip_registered_meta_comments
from osf.utils.workflows import ApprovalStates


class RegistrationSerializer(NodeSerializer):
admin_only_editable_fields = [
'custom_citation',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@pytest.mark.django_db
class TestPreprintInstitutionsList:
class TestPreprintInstitutionsRelationship:

@pytest.fixture()
def user(self):
Expand All @@ -34,6 +34,10 @@ def admin_without_institutional_affilation(self, institution, preprint):
preprint.add_permission(user, 'admin')
return user

@pytest.fixture()
def institutions(self):
return [InstitutionFactory() for _ in range(3)]

@pytest.fixture()
def institution(self):
return InstitutionFactory()
Expand Down Expand Up @@ -161,3 +165,90 @@ def test_preprint_institutions_list_get(self, app, user, admin_with_institutiona

assert res.json['data'][0]['id'] == institution._id
assert res.json['data'][0]['type'] == 'institutions'

def test_post_affiliated_institutions(self, app, user, admin_with_institutional_affilation, preprint, url,
institutions, institution):
add_institutions_payload = {
'data': [{'type': 'institutions', 'id': institution._id} for institution in institutions]
}

res = app.post_json_api(
url,
add_institutions_payload,
auth=admin_with_institutional_affilation.auth,
expect_errors=True
)
assert res.status_code == 403 # Adding affilations you don't have

add_institutions_payload = {
'data': [{'type': 'institutions', 'id': institution._id}],
}

res = app.post_json_api(
url,
add_institutions_payload,
auth=admin_with_institutional_affilation.auth
)
assert res.status_code == 201

preprint.reload()
assert preprint.affiliated_institutions.all()[0] == institution

def test_delete_affiliated_institution(self, app, user, admin_with_institutional_affilation, admin_without_institutional_affilation, preprint, url,
institution):

preprint.affiliated_institutions.add(institution)
preprint.save()

res = app.delete_json_api(
url,
{'data': [{'type': 'institutions', 'id': institution._id}]},
auth=admin_with_institutional_affilation.auth
)
assert res.status_code == 204

preprint.reload()
assert institution not in preprint.affiliated_institutions.all()

def test_complex_institutional_affiliations(self, app, user, admin_with_institutional_affilation, admin_without_institutional_affilation, preprint, url,
institutions):
# Add multiple institutions
add_institutions_payload = {
'data': [{'type': 'institutions', 'id': institution._id} for institution in institutions]
}

res = app.post_json_api(
url,
add_institutions_payload,
auth=admin_with_institutional_affilation.auth,
expect_errors=True
)
assert res.status_code == 403 # Adding affilations you don't have

preprint.reload()
assert len(preprint.affiliated_institutions.all()) == 0

# add one institution
remove_institution_payload = {
'data': [{'type': 'institutions', 'id': institutions[0]._id}]
}

res = app.put_json_api(
url,
remove_institution_payload,
auth=admin_with_institutional_affilation.auth
)
assert res.status_code == 201

preprint.reload()
assert len(preprint.affiliated_institutions.all()) == 1
assert len(preprint.affiliated_institutions.all()) == 1

# Check user affiliations
other_user = AuthUserFactory()
other_user.add_or_update_affiliated_institution(institutions[1])
other_user.add_or_update_affiliated_institution(institutions[2])

res = app.get(url, auth=other_user.auth, expect_errors=True)
assert res.status_code == 200
assert len(res.json['data']) == 2

0 comments on commit 164aa66

Please sign in to comment.