Skip to content

Commit

Permalink
Backend Security Audit fixes (#352)
Browse files Browse the repository at this point in the history
  • Loading branch information
oudeismetis authored Sep 12, 2024
1 parent 7c23c32 commit 8ac3f59
Show file tree
Hide file tree
Showing 13 changed files with 134 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ jobs:
mkdir -p client/dist/static
pipenv run python server/manage.py collectstatic
pipenv run pytest --mccabe --cov=my_project -vv server/my_project
pipenv run coverage report --fail-under=20
pipenv run coverage report --fail-under=60
2 changes: 1 addition & 1 deletion {{cookiecutter.project_slug}}/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
DJANGO_SETTINGS_MODULE = {{ cookiecutter.project_slug }}.test_settings
python_files = tests.py test_*.py *_tests.py
addopts = --strict-markers --no-migrations
mccabe-complexity=10
mccabe-complexity=8
filterwarnings =
ignore::DeprecationWarning

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ class Meta:
abstract = True

def __str__(self):
return "ah yes"
return "__str__ not defined for this model"
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class CustomUserAdmin(UserAdmin):
)
},
),
("Admin Options", {"classes": ("collapse",), "fields": ("is_staff", "groups")}),
("Admin Options", {"classes": ("collapse",), "fields": ("is_active", "is_staff", "is_superuser", "groups")}),
)
add_fieldsets = (
(
Expand All @@ -35,7 +35,15 @@ class CustomUserAdmin(UserAdmin):
},
),
)
list_display = ("email", "first_name", "last_name", "is_active", "is_staff", "is_superuser", "permissions")
list_display = (
"email",
"permissions",
"is_active",
"is_staff",
"is_superuser",
"first_name",
"last_name",
)
list_display_links = (
"is_active",
"email",
Expand All @@ -57,7 +65,7 @@ class CustomUserAdmin(UserAdmin):
ordering = []

def permissions(self, obj):
return ", ".join([g.name for g in obj.groups.all()])
return [g.name for g in obj.groups.all()]

class Media(AutocompleteAdminMedia):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
class UserFactory(factory.Factory):
email = factory.faker.Faker("email")
password = factory.PostGenerationMethodCall("set_password", "password")
first_name = factory.faker.Faker("first_name")
last_name = factory.faker.Faker("last_name")

class Meta:
model = User
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,26 @@
logger = logging.getLogger(__name__)


class UserQuerySet(models.QuerySet):
def for_user(self, user):
if not user or user.is_anonymous:
return self.none()
elif user.is_staff:
return self.all()
return self.filter(pk=user.pk)


class UserManager(BaseUserManager):
"""Custom User model manager, eliminating the 'username' field."""

use_in_migrations = True

def get_queryset(self):
return UserQuerySet(self.model, using=self.db)

def for_user(self, user):
return self.get_queryset().for_user(user)

def _create_user(self, email, password, **extra_fields):
"""
Create and save a User with the given email and password.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from rest_framework import permissions


class CreateOnlyPermissions(permissions.BasePermission):
def has_permission(self, request, view):
if view.action == "create":
return True
return False
class HasUserPermissions(permissions.BasePermission):
"""Admins should be able to perform any action, regular users should be able to edit and delete self."""

def has_object_permission(self, request, view, obj):
return request.user.is_authenticated and (request.user.is_staff or obj == request.user)
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Meta:
"last_name",
"full_name",
)
read_only_fields = ["email"]


class UserLoginSerializer(serializers.ModelSerializer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.contrib.auth import authenticate
from django.test import override_settings
from django.test.client import RequestFactory
from rest_framework import status
from rest_framework.response import Response

from .models import User
Expand Down Expand Up @@ -34,6 +35,13 @@ def test_create_user():
assert not user.is_superuser


@pytest.mark.django_db
def test_create_user_api(api_client):
data = {"email": "example@example.com", "password": "password", "first_name": "Test", "last_name": "User"}
res = api_client.post("/api/users/", data, format="json")
assert res.status_code == status.HTTP_201_CREATED, res.data


@pytest.mark.django_db
def test_create_superuser():
superuser = User.objects.create_superuser(email="test@example.com", password="password", first_name="Leslie", last_name="Burke")
Expand All @@ -50,7 +58,59 @@ def test_create_user_from_factory(sample_user):
@pytest.mark.django_db
def test_user_can_login(api_client, sample_user):
res = api_client.post("/api/login/", {"email": sample_user.email, "password": "password"}, format="json")
assert res.status_code == 200
assert res.status_code == status.HTTP_200_OK


@pytest.mark.django_db
def test_wrong_email(api_client, sample_user):
res = api_client.post("/api/login/", {"email": "wrong@example.com", "password": "password"}, format="json")
assert res.status_code == status.HTTP_400_BAD_REQUEST


@pytest.mark.django_db
def test_wrong_password(api_client, sample_user):
res = api_client.post("/api/login/", {"email": sample_user.email, "password": "wrong"}, format="json")
assert res.status_code == status.HTTP_400_BAD_REQUEST


@pytest.mark.django_db
def test_get_user(api_client, sample_user):
api_client.force_authenticate(sample_user)
res = api_client.get(f"/api/users/{sample_user.pk}/")
assert res.status_code == status.HTTP_200_OK
assert res.data["email"] == sample_user.email


@pytest.mark.django_db
def test_get_other_user(api_client, sample_user, user_factory):
api_client.force_authenticate(sample_user)
other_user = user_factory()
other_user.save()
res = api_client.get(f"/api/users/{other_user.pk}/")
assert res.status_code == status.HTTP_404_NOT_FOUND


@pytest.mark.django_db
def test_update_user(api_client, sample_user):
existing_email = sample_user.email
api_client.force_authenticate(sample_user)
data = {"email": "example@example.com", "password": "password", "first_name": "Test", "last_name": "User"}
res = api_client.put(f"/api/users/{sample_user.pk}/", data, format="json")
assert res.status_code == status.HTTP_200_OK
sample_user.refresh_from_db()
# Email should NOT have changed
assert sample_user.email == existing_email
assert sample_user.first_name == data["first_name"] == res.data["first_name"]
assert sample_user.last_name == data["last_name"] == res.data["last_name"]


@pytest.mark.django_db
def test_delete_user(api_client, sample_user):
api_client.force_authenticate(sample_user)
res = api_client.delete(f"/api/users/{sample_user.pk}/")
assert res.status_code == status.HTTP_204_NO_CONTENT
sample_user.refresh_from_db()
assert sample_user.is_active is False


@pytest.mark.use_requests
Expand All @@ -67,7 +127,7 @@ def test_password_reset(caplog, api_client, sample_user):

# Verify the link works for reseting the password
response = api_client.post(password_reset_url, data={"password": "new_password"}, format="json")
assert response.status_code == 200
assert response.status_code == status.HTTP_200_OK

# New Password should now work for authentication
serializer = UserLoginSerializer(data={"email": sample_user.email, "password": "new_password"})
Expand All @@ -87,7 +147,7 @@ class TestPreviewTemplateView:
@override_settings(DEBUG=False)
def test_disabled_if_not_debug(self, client):
response = client.post(self.url)
assert response.status_code == 404
assert response.status_code == status.HTTP_404_NOT_FOUND

@override_settings(DEBUG=True)
def test_enabled_if_debug(self, client):
Expand All @@ -98,19 +158,19 @@ def test_enabled_if_debug(self, client):
@override_settings(DEBUG=True)
def test_no_template_provided(self, client):
response = client.post(self.url, data={"_send_to": "someone@example.com"})
assert response.status_code == 400
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert any("You must provide a template name" in e for e in response.json())

@override_settings(DEBUG=True)
def test_invalid_template_provided(self, client):
response = client.post(f"{self.url}?template=SOME_TEMPLATE/WHICH_DOES_NOT/EXIST", data={"_send_to": "someone@example.com"})
assert response.status_code == 400
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert any("Invalid template name" in e for e in response.json())

@override_settings(DEBUG=True)
def test_missing_send_to(self, client):
response = client.post(f"{self.url}?template=SOME_TEMPLATE/WHICH_DOES_NOT/EXIST")
assert response.status_code == 400
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "This field is required." in response.json()["_send_to"]

def test_parse_value_without_model(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from rest_framework.exceptions import ValidationError
from rest_framework.response import Response

from {{ cookiecutter.project_slug }}.core.forms import PreviewTemplateForm
from {{ cookiecutter.project_slug }}.utils.emails import send_html_email

from .forms import PreviewTemplateForm
from .models import User
from .permissions import CreateOnlyPermissions
from .permissions import HasUserPermissions
from .serializers import UserLoginSerializer, UserRegistrationSerializer, UserSerializer

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -50,45 +50,45 @@ def post(self, request, *args, **kwargs):
return Response(response_data)


class UserViewSet(
viewsets.GenericViewSet,
mixins.RetrieveModelMixin,
mixins.ListModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
):
queryset = User.objects.all()
class UserViewSet(viewsets.GenericViewSet, mixins.RetrieveModelMixin):
queryset = User.objects
serializer_class = UserSerializer

# No auth required to create user
# Auth required for all other actions
permission_classes = (permissions.IsAuthenticated | CreateOnlyPermissions,)
permission_classes = (HasUserPermissions,)

@transaction.atomic
def create(self, request, *args, **kwargs):
def get_queryset(self):
"""
Endpoint to create/register a new user.
Users should only find themselves by default
"""
return super().get_queryset().for_user(self.request.user)

@transaction.atomic
def create(self, request, *args, **kwargs):
serializer = UserRegistrationSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
serializer.save() # This calls .create() on serializer
user = serializer.instance

user = serializer.save()
# Log-in user and re-serialize response
response_data = UserLoginSerializer.login(user, request)
return Response(response_data, status=status.HTTP_201_CREATED)

def update(self, request, *args, **kwargs):
"""
Endpoint to create/register a new user.
"""
serializer = UserSerializer(data=request.data, instance=self.get_object(), partial=True)

serializer.is_valid(raise_exception=True)
serializer.save()
user = serializer.data
return Response(serializer.data, status=status.HTTP_200_OK)

return Response(user, status=status.HTTP_200_OK)
def destroy(self, request, *args, **kwargs):
"""
When deleting a user's account, just disable their account first
The user may have a regret and try to get their account back
A background job should then properly delete the data after X days
"""
user = self.get_object()
user.is_active = False
user.save()
return Response(status=status.HTTP_204_NO_CONTENT)


@api_view(["post"])
Expand Down Expand Up @@ -117,7 +117,7 @@ def request_reset_link(request, *args, **kwargs):
def reset_password(request, *args, **kwargs):
user_id = kwargs.get("uid")
token = kwargs.get("token")
user = User.objects.filter(id=user_id).first()
user = User.objects.filter(pk=user_id).first()
if not user or not token:
raise ValidationError(detail={"non-field-error": "Invalid or expired token"})
is_valid = default_token_generator.check_token(user, token)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
if CURRENT_DOMAIN not in ALLOWED_HOSTS:
ALLOWED_HOSTS.append(CURRENT_DOMAIN)

# Used by the corsheaders app/middleware (django-cors-headers) to allow multiple domains to access the backend
CORS_ALLOWED_ORIGINS = [f"https://{host}" for host in ALLOWED_HOSTS]

# Application definition

INSTALLED_APPS = [
Expand Down Expand Up @@ -306,7 +309,6 @@
if not IN_DEV:
SECURE_SSL_REDIRECT = True
SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https")
MIDDLEWARE += ["django.middleware.security.SecurityMiddleware"]

#
# Custom logging configuration
Expand Down Expand Up @@ -390,11 +392,6 @@ def filter(self, record):
# Popular testing framework that allows logging to stdout while running unit tests
TEST_RUNNER = "django_nose.NoseTestSuiteRunner"

CORS_ALLOWED_ORIGINS = ["https://{{ cookiecutter.project_slug|replace('_', '-') }}-staging.herokuapp.com", "https://{{ cookiecutter.project_slug|replace('_', '-') }}.herokuapp.com"]
{% if cookiecutter.client_app.lower() != 'none' -%}
CORS_ALLOWED_ORIGINS.append("http://localhost:8080")
{% endif -%}

SWAGGER_SETTINGS = {
"LOGIN_URL": "/login",
"USE_SESSION_AUTH": False,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from decouple import config

from {{ cookiecutter.project_slug }}.settings import LOGGING
from {{ cookiecutter.project_slug }}.settings import * # noqa
from {{ cookiecutter.project_slug }}.settings import LOGGING

# Override staticfiles setting to avoid cache issues with whitenoise Manifest staticfiles storage
# See: https://stackoverflow.com/a/69123932
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from django.contrib import admin
from django.urls import include, path

admin.site.site_header = "{{ cookiecutter.project_name }} Admin"
admin.site.site_title = "{{ cookiecutter.project_name }}"

urlpatterns = [
path(r"staff/", admin.site.urls),
path(r"", include("{{ cookiecutter.project_slug }}.common.favicon_urls")),
Expand Down

0 comments on commit 8ac3f59

Please sign in to comment.