Skip to content

Commit

Permalink
test: bump up test coverage from 23% to 29%
Browse files Browse the repository at this point in the history
  • Loading branch information
rhodinemma committed May 22, 2024
1 parent 77c4f23 commit 22a432a
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 67 deletions.
2 changes: 1 addition & 1 deletion app/helpers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_current_user(access_token: str):
raise HTTPException(status_code=401, detail="Invalid token")
except Exception as e:
print(f"An unexpected error occurred: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
raise HTTPException(status_code=401, detail="Unauthorized")


def check_authentication(current_user):
Expand Down
53 changes: 28 additions & 25 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from unittest.mock import patch
from fastapi.testclient import TestClient
import pytest
from main import app
from app.helpers.auth import has_role, get_current_user, check_authentication
from app.database import database_exists, create_database
from fastapi import HTTPException
from jose import jwt
from jose import JWTError, jwt
from config import settings
import os
from functools import wraps
Expand All @@ -13,13 +14,15 @@

client = TestClient(app)


def test_has_role():
roles = [{'name': 'admin'}, {'name': 'user'}]
assert has_role(roles, 'admin') == True
assert has_role(roles, 'guest') == False


def test_get_current_user():
access_token = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
access_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
mock_jwt_decode = patch('jose.jwt.decode', return_value={
'user_claims': {
'roles': [{'name': 'admin'}],
Expand All @@ -33,6 +36,7 @@ def test_get_current_user():
assert user.id == 'user_id'
assert user.email == 'test@example.com'


def test_check_authentication():
current_user = None
try:
Expand All @@ -43,21 +47,27 @@ def test_check_authentication():
else:
assert False, "HTTPException not raised"


def test_invalid_access_token():
access_token = "Bearer invalid_token"
result = get_current_user(access_token)
assert result.role is None
assert result.id is None
assert result.email is None
access_token = "invalid_token"
with pytest.raises(HTTPException) as excinfo:
result = get_current_user(access_token)

assert result.id is None
assert result.role is None
assert result.email is None
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Unauthorized"


@patch('jwt.decode')
def test_jwt_decode_error(mock_jwt_decode):
mock_jwt_decode.side_effect = Exception("JWT decode error")
access_token = "Bearer valid_token"
result = get_current_user(access_token)
assert result.role is None
assert result.id is None
assert result.email is None
mock_jwt_decode.side_effect = JWTError("JWT decode error")
access_token = "valid_token"
with pytest.raises(HTTPException) as excinfo:
get_current_user(access_token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Unauthorized"


def admin_or_user_required(fn):
Expand All @@ -80,23 +90,23 @@ def wrapper(*args, **kwargs):
# kwargs['current_user'] = {'roles': ['admin']}
# return fn(*args, **kwargs)
# return wrapper

# def dummy_has_role(user, role):
# return 'administrator' in user.get('roles', [])

# @admin_or_user_required
# def dummy_function(*args, **kwargs):
# return "Admin access granted" if kwargs.get('is_admin') else "Access denied"

# # Patching authenticate and has_role functions
# authenticate = dummy_authenticate
# has_role = dummy_has_role

# result = dummy_function(access_token="Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpYXQiOjE3MTI1ODc4MTQsIm5iZiI6MTcxMjU4NzgxNCwianRpIjoiMjhkZmFiZmEtNjFiZS00OTE4LWJlMDMtNWE1ZGJlMTI4MzllIiwiZXhwIjoxNzEzNDUxODE0LCJpZGVudGl0eSI6IjRkYmJlNDhjLWU1NWYtNDQ2Mi05Nzk5LTBmOTFmOTJjY2NiNCIsImZyZXNoIjpmYWxzZSwidHlwZSI6ImFjY2VzcyIsInVzZXJfY2xhaW1zIjp7InJvbGVzIjpbeyJuYW1lIjoiY3VzdG9tZXIiLCJpZCI6IjdmN2ZiZGQ5LWMxMGQtNGRiMC1iOTQ3LWUyZDc0MmE2MTlhOSJ9XSwiZW1haWwiOiJoZW5yeUBjcmFuZWNsb3VkLmlvIn19.ej5-HNioEPrVT6oZ2mdKamTGVQiBt7LSAbALP1Jde0g", current_user={'user_claims': {
# 'roles': [{'name': 'admin'}],
# 'email': 'test@example.com'
# }})

# assert result == "Access denied"

# def test_user_required():
Expand All @@ -122,10 +132,3 @@ def wrapper(*args, **kwargs):
# # Testing user access
# result = dummy_function(access_token="Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpYXQiOjE3MTI1ODU4NzEsIm5iZiI6MTcxMjU4NTg3MSwianRpIjoiZDFjODIzZGUtZGE4OC00MDI5LTg0NDktZWQ0ZmVlMWUyNjExIiwiZXhwIjoxNzEzNDQ5ODcxLCJpZGVudGl0eSI6IjMzMDIyMmFmLWJkYTktNDlhYy04MWMzLWQ3ZGQ0MDI1NjlhYSIsImZyZXNoIjpmYWxzZSwidHlwZSI6ImFjY2VzcyIsInVzZXJfY2xhaW1zIjp7InJvbGVzIjpbeyJuYW1lIjoiYWRtaW5pc3RyYXRvciIsImlkIjoiYWYyM2Y2ZDgtZWRlZi00NWY0LTg1ZWMtZGE2Y2Q3ZDUxOWJiIn1dLCJlbWFpbCI6ImFkbWluQGNyYW5lY2xvdWQuaW8ifX0.FXrs1icgXrPwGsH4m6EW9iNIV0uuXLykRNFLWthoyMM")
# assert result == "Access denied"







177 changes: 136 additions & 41 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from fastapi.testclient import TestClient
from main import app
import pytest
Expand All @@ -9,24 +10,24 @@
from sqlalchemy.orm import Session
from types import SimpleNamespace
from app.models import Database
from app.routes import fetch_database_stats
from app.routes import fetch_database_stats, get_all_databases
from app.helpers.auth import get_current_user, check_authentication



client = TestClient(app)


def test_validate_set_by():
valid_value = 'year'
user_graph_schema = UserGraphSchema(start='2022-01-01', end='2022-12-31', set_by=valid_value)
user_graph_schema = UserGraphSchema(
start='2022-01-01', end='2022-12-31', set_by=valid_value)
assert user_graph_schema.set_by == valid_value

# Test with an invalid value
invalid_value = 'week'
with pytest.raises(ValueError):
UserGraphSchema(start='2022-01-01', end='2022-12-31', set_by=invalid_value)


UserGraphSchema(start='2022-01-01', end='2022-12-31',
set_by=invalid_value)


def test_index():
Expand All @@ -35,65 +36,159 @@ def test_index():
assert response.json() == ['Welcome to the Database Service']


@patch('app.helpers.auth.get_current_user')
@patch('app.routes.get_current_user')
@patch('app.routes.check_authentication')
def test_get_all_databases(mock_check_authentication, mock_get_current_user):
from app.routes import get_all_databases
from app.models import Database

access_token = "dummy_access_token"
@patch('app.helpers.database_session.get_db')
def test_get_all_databases(
mock_get_db,
mock_check_authentication,
mock_get_current_user
):
# Mock current user
current_user = Mock()
current_user.id = 1
current_user.role = "administrator"
mock_get_current_user.return_value = current_user

# Mock authentication check
mock_check_authentication.return_value = None

# Mock database session and data
db = Mock(spec=Session)
databases = [Database(id=1, name="Test Database 1"), Database(id=2, name="Test Database 2")]
databases = [
Database(id=1, name="Test Database 1", user="user1",
password="password1", database_flavour_name="mysql"),
Database(id=2, name="Test Database 2", user="user2",
password="password2", database_flavour_name="mysql")
]
db.query(Database).all.return_value = databases
mock_get_db.return_value = db

mock_get_current_user.return_value = current_user
# Perform the request
response = client.get(
"/databases", headers={"Authorization": "Bearer dummy_access_token"})

# Validate the response
assert response.status_code == 200
data = response.json()
assert "data" in data
assert "databases" in data["data"]

result = get_all_databases(access_token=access_token, db=db)

expected_response = SimpleNamespace(status_code=200, data={"databases": databases})
@patch('app.routes.get_current_user')
@patch('app.routes.check_authentication')
@patch('app.helpers.database_session.get_db')
def test_fetch_database_stats(
mock_get_db,
mock_check_authentication,
mock_get_current_user
):
# Mock current user
current_user = Mock()
current_user.id = 1
current_user.role = "administrator"
mock_get_current_user.return_value = current_user

mock_check_authentication.assert_called_once()
# Mock authentication check
mock_check_authentication.return_value = None

assert result.status_code == expected_response.status_code
assert type(result.data) == type(expected_response.data)
# Mock database session and data
db = Mock(spec=Session)
databases = [
Database(id=1, name="Test Database 1", database_flavour_name="mysql"),
Database(id=2, name="Test Database 2", database_flavour_name="mysql"),
Database(id=3, name="Test Database 3",
database_flavour_name="postgres")
]
db.query(Database).filter_by.return_value.all.return_value = databases
mock_get_db.return_value = db

# Perform the request
response = client.get(
"/databases/stats", headers={"Authorization": "Bearer dummy_access_token"})

# Validate the response
assert response.status_code == 200
data = response.json()
assert "data" in data
assert "databases" in data["data"]
assert "total_database_count" in data["data"]["databases"]
assert "dbs_stats_per_flavour" in data["data"]["databases"]
# Assuming 2 database flavours
assert len(data["data"]["databases"]["dbs_stats_per_flavour"]) == 2
assert "mysql_db_count" in data["data"]["databases"]["dbs_stats_per_flavour"]
assert "postgres_db_count" in data["data"]["databases"]["dbs_stats_per_flavour"]
assert data["data"]["databases"]["dbs_stats_per_flavour"]["mysql_db_count"] == 2


@patch('app.routes.get_current_user')
@patch('app.routes.check_authentication')
def test_fetch_database_stats( mock_check_authentication, mock_get_current_user):
access_token = "dummy_access_token"
@patch('app.helpers.database_session.get_db')
def test_create_database(
mock_get_db,
mock_check_authentication,
mock_get_current_user
):
# Mock current user
current_user = Mock()
current_user.id = 1
current_user.role = "administrator"
mock_get_current_user.return_value = current_user

# Mock authentication check
mock_check_authentication.return_value = None

# Mock database session
db = Mock(spec=Session)
database_flavours = [{'name': 'flavour1'}, {'name': 'flavour2'}]
databases = [Database(id=1, name="Test Database 1", database_flavour_name='flavour1'),
Database(id=2, name="Test Database 2", database_flavour_name='flavour2'),
Database(id=3, name="Test Database 3", database_flavour_name='flavour1')]
mock_get_db.return_value = db

db.query.return_value.filter_by.side_effect = lambda **kwargs: Mock(all=Mock(return_value=[db for db in databases if db.database_flavour_name == kwargs['database_flavour_name']]))

mock_get_current_user.return_value = current_user
# Perform the request
response = client.post(
"/databases",
headers={"Authorization": "Bearer dummy_access_token"},
json={"database_flavour_name": "mysql"}
)

result = fetch_database_stats(access_token=access_token, db=db)
# Validate the response
assert response.status_code == 200

expected_total_count = len(databases)
expected_db_stats_per_flavour = {'flavour1_db_count': 2, 'flavour2_db_count': 1}
expected_data = {'total_database_count': expected_total_count, 'dbs_stats_per_flavour': expected_db_stats_per_flavour}
expected_response = SimpleNamespace(status_code=200, data={'databases': expected_data})

mock_get_current_user.assert_called_once_with(access_token)
mock_check_authentication.assert_called_once()
@patch('app.routes.get_current_user')
@patch('app.routes.check_authentication')
@patch('app.helpers.database_session.get_db')
def test_single_database_not_found(
mock_get_db,
mock_check_authentication,
mock_get_current_user
):
# Mock current user
current_user = Mock()
current_user.id = 1
current_user.role = "administrator"
mock_get_current_user.return_value = current_user

assert result.status_code == expected_response.status_code
assert type(result.data) == type(expected_response.data)
# Mock authentication check
mock_check_authentication.return_value = None

# def test_index(test_client):
# response = test_client.get("/api/")
# assert response.status_code == 200
# assert response.json() == {"message": "Hello, world"}
# Mock database session
db = Mock(spec=Session)
mock_get_db.return_value = db

# Create a mock database
# Mock database query result
mock_database = Mock(spec=Database)
mock_database.id = uuid.uuid4()
mock_database.name = "Test Database"
mock_database.user = "test_user"
mock_database.password = "test_password"
mock_database.database_flavour_name = "mysql"
db.query(Database).filter.return_value.first.return_value = mock_database

# Perform the request
response = client.get(
f"/databases/{mock_database.id}",
headers={"Authorization": "Bearer dummy_access_token"}
)

# Validate the response
assert response.status_code == 404

0 comments on commit 22a432a

Please sign in to comment.