diff --git a/app/api/deps.py b/app/api/deps.py index 50f19f5..ebef08e 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -25,26 +25,26 @@ class RedisConnectionManager: """ def __init__(self, auth_settings: AuthSettings): - self.url: str = f"{auth_settings.REDIS_DATABASE_URI}" - self.pool: Redis | None = None # type: ignore + self.__url: str = f"{auth_settings.REDIS_DATABASE_URI}" + self._pool: Redis | None = None # type: ignore - async def start(self) -> None: + async def __start(self) -> None: """ Start the redis pool connection :return: None :rtype: NoneType """ - self.pool = Redis.from_url(self.url, decode_responses=True) - await self.pool.ping() + self._pool = Redis.from_url(self.__url, decode_responses=True) + await self._pool.ping() logger.info("Redis Database initialized") - async def stop(self) -> None: + async def __stop(self) -> None: """ Stops the redis connection :return: None :rtype: NoneType """ - await self.pool.close() # type: ignore + await self._pool.close() # type: ignore async def get_connection(self) -> Redis | None: # type: ignore """ @@ -52,7 +52,7 @@ async def get_connection(self) -> Redis | None: # type: ignore :return: The redis connection :rtype: Optional[Redis] """ - return self.pool + return self._pool @asynccontextmanager async def connection(self) -> AsyncGenerator[Redis, Any]: # type: ignore @@ -61,9 +61,9 @@ async def connection(self) -> AsyncGenerator[Redis, Any]: # type: ignore :return: Yields the generator object :rtype: AsyncGenerator[Redis, Any] """ - await self.start() - yield self.pool # type: ignore - await self.stop() + await self.__start() + yield self._pool # type: ignore + await self.__stop() async def get_redis_dep( diff --git a/app/api/oauth2_validation.py b/app/api/oauth2_validation.py index 27c7e91..51842e2 100644 --- a/app/api/oauth2_validation.py +++ b/app/api/oauth2_validation.py @@ -32,7 +32,7 @@ ) -async def authenticate_user( +async def _authenticate_user( token: str, auth_settings: AuthSettings, user_service: UserService, @@ -91,7 +91,7 @@ async def get_refresh_current_user( :return: Authenticated user information :rtype: UserAuth """ - return await authenticate_user( + return await _authenticate_user( refresh_token, auth_settings, user_service, redis ) @@ -123,4 +123,4 @@ async def get_current_user( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token is blacklisted", ) - return await authenticate_user(token, auth_settings, user_service, redis) + return await _authenticate_user(token, auth_settings, user_service, redis) diff --git a/app/api/redis_deps.py b/app/api/redis_deps.py index a6f3b88..7ad49c5 100644 --- a/app/api/redis_deps.py +++ b/app/api/redis_deps.py @@ -20,11 +20,11 @@ class RedisDependency: """ def __init__(self) -> None: - self._url: str = f"{auth_setting.REDIS_DATABASE_URI}" + self.__url: str = f"{auth_setting.REDIS_DATABASE_URI}" self._redis: Redis | None = None # type: ignore self.auth_settings: AuthSettings = auth_setting - async def init_redis(self) -> None: + async def __init_redis(self) -> None: """ Initializes the redis connection :return: None @@ -32,14 +32,14 @@ async def init_redis(self) -> None: """ try: self._redis = Redis.from_url( - self._url, + self.__url, decode_responses=True, ) except RedisError as exc: logger.error("Failed to establish Redis connection: %s", exc) raise - async def close_redis(self) -> None: + async def __close_redis(self) -> None: """ Closes the redis connection :return: None @@ -53,9 +53,9 @@ async def close_redis(self) -> None: raise async def __aenter__(self) -> Redis: # type: ignore - await self.init_redis() + await self.__init_redis() if self._redis: return self._redis async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - await self.close_redis() + await self.__close_redis() diff --git a/app/config/init_settings.py b/app/config/init_settings.py index 4e8d158..126f011 100644 --- a/app/config/init_settings.py +++ b/app/config/init_settings.py @@ -44,10 +44,7 @@ class InitSettings(BaseSettings): extra="allow", ) - ITERATIONS: PositiveInt = 100000 - KEY_BYTES_LENGTH: PositiveInt = 32 SALT_BYTES: PositiveInt = 16 - IV_BYTES: PositiveInt = 12 PUBLIC_EXPONENT: PositiveInt = 65537 RSA_KEY_BITS: PositiveInt = 2048 SALUTE: str = "Salute!" diff --git a/app/core/security/password.py b/app/core/security/password.py index 3a3f228..ec9a09d 100644 --- a/app/core/security/password.py +++ b/app/core/security/password.py @@ -15,6 +15,18 @@ ) +def _raise_custom_error(error_message: str) -> None: + """ + Raise an exception + :param error_message: The error message to display + :type error_message: str + :return: None + :rtype: NoneType + """ + logger.error(error_message) + raise SecurityException(error_message) + + def get_password_hash(password: str) -> str: """ Hash a password using the bcrypt algorithm @@ -24,7 +36,7 @@ def get_password_hash(password: str) -> str: :rtype: str """ if not password: - raise_custom_error("Password cannot be empty or None") + _raise_custom_error("Password cannot be empty or None") return crypt_context.hash(password) @@ -39,19 +51,7 @@ def verify_password(hashed_password: str, plain_password: str) -> bool: :rtype: bool """ if not plain_password: - raise_custom_error("Plain password cannot be empty or None") + _raise_custom_error("Plain password cannot be empty or None") if not hashed_password: - raise_custom_error("Hashed password cannot be empty or None") + _raise_custom_error("Hashed password cannot be empty or None") return crypt_context.verify(plain_password, hashed_password) - - -def raise_custom_error(error_message: str) -> None: - """ - Raise an exception - :param error_message: The error message to display - :type error_message: str - :return: None - :rtype: NoneType - """ - logger.error(error_message) - raise SecurityException(error_message) diff --git a/app/middlewares/blacklist_token.py b/app/middlewares/blacklist_token.py index ae03725..08a0854 100644 --- a/app/middlewares/blacklist_token.py +++ b/app/middlewares/blacklist_token.py @@ -20,7 +20,7 @@ ] -def extract_token(request: Request) -> str | None: +def __extract_token(request: Request) -> str | None: """ Extract token from the Authorization headers of the request :param request: The upcoming request instance @@ -34,7 +34,7 @@ def extract_token(request: Request) -> str | None: return None -async def check_blacklist(token: str, request: Request) -> None: +async def __check_blacklist(token: str, request: Request) -> None: """ Check if a token is blacklisted from the upcoming request :param token: The token to check @@ -56,7 +56,7 @@ async def check_blacklist(token: str, request: Request) -> None: ) -async def process_request(request: Request) -> None: +async def _process_request(request: Request) -> None: """ Process request for the blacklist middleware :param request: The upcoming request instance @@ -65,8 +65,8 @@ async def process_request(request: Request) -> None: :rtype: NoneType """ token: str | None - if token := extract_token(request): - await check_blacklist(token, request) + if token := __extract_token(request): + await __check_blacklist(token, request) async def blacklist_middleware( @@ -82,6 +82,6 @@ async def blacklist_middleware( :rtype: Response """ if not any(request.url.path.startswith(route) for route in SKIP_ROUTES): - await process_request(request) + await _process_request(request) response: Response = await call_next(request) return response diff --git a/app/middlewares/rate_limiter.py b/app/middlewares/rate_limiter.py index 1330a0e..9b39f41 100644 --- a/app/middlewares/rate_limiter.py +++ b/app/middlewares/rate_limiter.py @@ -24,7 +24,7 @@ def __init__(self, app: FastAPI): self.app: FastAPI = app @staticmethod - async def handle_rate_limit_exceeded( + async def __handle_rate_limit_exceeded( rate_limiter: RateLimiter, rate_limiter_service: RateLimiterService, request: Request, @@ -63,7 +63,7 @@ async def handle_rate_limit_exceeded( headers=headers, ) - async def enforce_rate_limit( + async def __enforce_rate_limit( self, rate_limiter: RateLimiter, rate_limiter_service: RateLimiterService, @@ -83,11 +83,11 @@ async def enforce_rate_limit( await rate_limiter_service.add_request() request_count: int = await rate_limiter_service.get_request_count() if request_count > request.app.state.auth_settings.MAX_REQUESTS: - await self.handle_rate_limit_exceeded( + await self.__handle_rate_limit_exceeded( rate_limiter, rate_limiter_service, request ) - async def process_request(self, request: Request) -> None: + async def _process_request(self, request: Request) -> None: """ Process a backend request from the middleware :param request: The upcoming request instance @@ -111,7 +111,7 @@ async def process_request(self, request: Request) -> None: request.app.state.auth_settings.MAX_REQUESTS, rate_limiter, ) - await self.enforce_rate_limit( + await self.__enforce_rate_limit( rate_limiter, rate_limiter_service, request ) @@ -123,5 +123,5 @@ async def __call__( ) -> None: if scope["type"] == "http": request = Request(scope, receive=receive) - await self.process_request(request) + await self._process_request(request) await self.app(scope, receive, send) diff --git a/app/middlewares/security_headers.py b/app/middlewares/security_headers.py index 7b5019d..90ae2e9 100644 --- a/app/middlewares/security_headers.py +++ b/app/middlewares/security_headers.py @@ -28,14 +28,14 @@ async def dispatch( :rtype: Response """ response: Response = await call_next(request) - self.add_security_headers( + self._add_security_headers( response, request.app.state.auth_settings.STRICT_TRANSPORT_SECURITY_MAX_AGE, ) return response @staticmethod - def add_security_headers(response: Response, max_age: PositiveInt) -> None: + def _add_security_headers(response: Response, max_age: PositiveInt) -> None: """ Adds security headers to the response. :param max_age: The maximum age for the strict transport security diff --git a/app/services/infrastructure/cached_user.py b/app/services/infrastructure/cached_user.py index fcf24a0..56209f4 100644 --- a/app/services/infrastructure/cached_user.py +++ b/app/services/infrastructure/cached_user.py @@ -26,7 +26,7 @@ def __init__( redis: Redis, # type: ignore ): self._redis: Redis = redis # type: ignore - self._cache_seconds: PositiveInt = auth_setting.CACHE_SECONDS + self.__cache_seconds: PositiveInt = auth_setting.CACHE_SECONDS async def get_model_from_cache(self, key: UUID4) -> User | None: """ @@ -81,5 +81,5 @@ async def set_to_cache( :rtype: NoneType """ await self._redis.setex( - str(key), self._cache_seconds, json.dumps(custom_serializer(value)) + str(key), self.__cache_seconds, json.dumps(custom_serializer(value)) ) diff --git a/app/services/infrastructure/encryption.py b/app/services/infrastructure/encryption.py index 9fcbb95..7b03cb0 100644 --- a/app/services/infrastructure/encryption.py +++ b/app/services/infrastructure/encryption.py @@ -2,12 +2,9 @@ A module for encryption in the app.services.infrastructure package. """ -import base64 import logging -from os import urandom -from typing import Any +from functools import lru_cache -import aiofiles from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends.openssl.backend import Backend from cryptography.hazmat.primitives import hashes, serialization @@ -16,20 +13,8 @@ RSAPrivateKey, RSAPublicKey, ) -from cryptography.hazmat.primitives.asymmetric.types import ( - PrivateKeyTypes, - PublicKeyTypes, -) -from cryptography.hazmat.primitives.ciphers import ( - AEADDecryptionContext, - AEADEncryptionContext, - Cipher, - algorithms, - modes, -) -from pydantic import FilePath, PositiveInt -from app.config.config import init_setting, setting +from app.config.config import init_setting logger: logging.Logger = logging.getLogger(__name__) @@ -41,162 +26,16 @@ class EncryptionService: def __init__(self, backend: Backend = default_backend()) -> None: self._backend: Backend = backend - self._iterations: PositiveInt = init_setting.ITERATIONS - - def _load_public_key(self, public_key_pem: str) -> PublicKeyTypes: - """ - Load a public key from a PEM-formatted string. - :param public_key_pem: The public key in PEM format. - :type public_key_pem: str - :return: The loaded public key. - :rtype: PublicKeyTypes - """ - try: - return serialization.load_pem_public_key( - public_key_pem.encode(), backend=self._backend - ) - except Exception as e: - logger.error(f"Error loading public key: {e}") - raise - - def _load_private_key(self, private_key_pem: str) -> PrivateKeyTypes: - """ - Load a private key from a PEM-formatted string. - :param private_key_pem: The private key in PEM format. - :type private_key_pem: str - :return: The loaded private key. - :rtype: PrivateKeyTypes - """ - try: - return serialization.load_pem_private_key( - private_key_pem.encode(), None, backend=self._backend - ) - except Exception as e: - logger.error(f"Error loading private key: {e}") - raise - - @staticmethod - def _get_padding_scheme() -> padding.OAEP: - """ - Get the padding scheme used for RSA encryption and decryption. - :return: The padding scheme. - :rtype: padding.OAEP - """ - return padding.OAEP( + self._padding_scheme = padding.OAEP( mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None, ) + self.__public_exponent = init_setting.PUBLIC_EXPONENT + self.__rsa_key_bits = init_setting.RSA_KEY_BITS + self.__encoding = init_setting.ENCODING - def encrypt_aes_key_with_rsa( - self, public_key_pem: str, aes_key: bytes - ) -> bytes: - """ - Encrypt an AES key using a public RSA key. - :param public_key_pem: The public RSA key in PEM format. - :type public_key_pem: str - :param aes_key: The AES key to encrypt. - :type aes_key: bytes - :return: The encrypted AES key. - :rtype: bytes - """ - public_key: PublicKeyTypes = self._load_public_key(public_key_pem) - try: - return public_key.encrypt( # type: ignore - aes_key, self._get_padding_scheme() - ) - except Exception as e: - logger.error(f"Error encrypting AES key: {e}") - raise - - def decrypt_aes_key_with_rsa( - self, private_key_pem: str, encrypted_aes_key: bytes - ) -> bytes: - """ - Decrypt an AES key using a private RSA key. - :param private_key_pem: The private RSA key in PEM format. - :type private_key_pem: str - :param encrypted_aes_key: The encrypted AES key. - :type encrypted_aes_key: bytes - :return: The decrypted AES key. - :rtype: bytes - """ - private_key: PrivateKeyTypes = self._load_private_key(private_key_pem) - try: - return private_key.decrypt( # type: ignore - encrypted_aes_key, self._get_padding_scheme() - ) - except Exception as e: - logger.error(f"Error decrypting AES key: {e}") - raise - - def encrypt_data(self, data: str, public_key_pem: str) -> dict[str, bytes]: - """ - Encrypt data with a public key (PEM). - :param data: The data to encrypt - :type data: str - :param public_key_pem: The public key to encrypt - :type public_key_pem: str - :return: The encrypted data - :rtype: dict[str, bytes] - """ - aes_key: bytes = urandom(init_setting.KEY_BYTES_LENGTH) - encrypted_aes_key: bytes = self.encrypt_aes_key_with_rsa( - public_key_pem, aes_key - ) - iv: bytes = urandom(init_setting.IV_BYTES) - cipher: Cipher[modes.GCM] = Cipher( - algorithms.AES(aes_key), modes.GCM(iv), backend=self._backend - ) - encryptor: AEADEncryptionContext = cipher.encryptor() - encrypted_data_bytes: bytes = ( - encryptor.update(data.encode()) + encryptor.finalize() - ) - return { - "encrypted_data": encrypted_data_bytes, - "encrypted_aes_key": encrypted_aes_key, - "iv": iv, - "tag": encryptor.tag, - } - - def decrypt_data( - self, - encrypted_data: bytes, - private_key_pem: str, - encrypted_aes_key: bytes, - iv: bytes, - tag: bytes, - ) -> str: - """ - Decrypt data with a password. - :param encrypted_data: The encrypted data - :type encrypted_data: bytes - :param private_key_pem: The private key - :type private_key_pem: str - :param encrypted_aes_key: The encrypted AES key - :type encrypted_aes_key: bytes - :param iv: The initialization vector - :type iv: bytes - :param tag: The decryption tag - :type tag: bytes - :return: The decrypted data - :rtype: str - """ - aes_key: bytes = self.decrypt_aes_key_with_rsa( - private_key_pem, encrypted_aes_key - ) - cipher: Cipher[modes.GCM] = Cipher( - algorithms.AES(aes_key), - modes.GCM(iv, tag), - backend=self._backend, - ) - decryptor: AEADDecryptionContext = cipher.decryptor() - decrypted_data_bytes: bytes = ( - decryptor.update(encrypted_data) + decryptor.finalize() - ) - return decrypted_data_bytes.decode() - - def generate_key_pair(self) -> tuple[str, str]: + def _generate_key_pair(self) -> tuple[str, str]: """ Generate a pair of RSA keys. :return: Tuple containing public and private RSA keys in PEM format. @@ -204,8 +43,8 @@ def generate_key_pair(self) -> tuple[str, str]: """ # Used ONLY ONCE to generate the local keys on disk private_key: RSAPrivateKey = rsa.generate_private_key( - public_exponent=init_setting.PUBLIC_EXPONENT, - key_size=init_setting.RSA_KEY_BITS, + public_exponent=self.__public_exponent, + key_size=self.__rsa_key_bits, backend=self._backend, ) public_key: RSAPublicKey = private_key.public_key() @@ -213,11 +52,11 @@ def generate_key_pair(self) -> tuple[str, str]: encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), - ).decode(init_setting.ENCODING) + ).decode(self.__encoding) public_pem: str = public_key.public_bytes( encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo, - ).decode(init_setting.ENCODING) + ).decode(self.__encoding) return public_pem, private_pem def save_key_pair( @@ -232,81 +71,15 @@ def save_key_pair( :return: None :rtype: NoneType """ - public_pem, private_pem = self.generate_key_pair() + public_pem, private_pem = self._generate_key_pair() with open(public_key_path, "w") as public_file: public_file.write(public_pem) with open(private_key_path, "w") as private_file: private_file.write(private_pem) print("Public and Private keys have been generated and saved.") - @staticmethod - def serialize_encrypted_info(encrypted_info: dict[str, bytes]) -> str: - """ - Serialize the encrypted information - :param encrypted_info: The encrypted information - :type encrypted_info: dict[str, bytes] - :return: The combined data - :rtype: str - """ - combined: bytes = ( - encrypted_info["encrypted_aes_key"] - + encrypted_info["iv"] - + encrypted_info["encrypted_data"] - + encrypted_info["tag"] - ) - return base64.b64encode(combined).decode() - - @staticmethod - def deserialize_encrypted_info(serialized_info: str) -> dict[str, bytes]: - """ - Deserialize the encrypted information - :param serialized_info: The serialized info - :type serialized_info: str - :return: The dictionary with encrypted data as bytes - :rtype: dict[str, bytes] - """ - decoded: bytes = base64.b64decode(serialized_info) - rsa_key_size_bytes: PositiveInt = init_setting.RSA_KEY_BITS // 8 - iv_size_bytes: PositiveInt = init_setting.IV_BYTES - tag_size_bytes: PositiveInt = init_setting.SALT_BYTES - encrypted_aes_key: bytes = decoded[:rsa_key_size_bytes] - iv: bytes = decoded[ - rsa_key_size_bytes : rsa_key_size_bytes + iv_size_bytes - ] - tag: bytes = decoded[-tag_size_bytes:] - encrypted_data: bytes = decoded[ - rsa_key_size_bytes + iv_size_bytes : -tag_size_bytes - ] - return { - "encrypted_aes_key": encrypted_aes_key, - "iv": iv, - "encrypted_data": encrypted_data, - "tag": tag, - } - - async def encrypt_and_serialize( - self, - data: Any, - key_path: FilePath = setting.PUBLIC_KEY_PATH, - ) -> tuple[bytes, str]: - """ - Encrypts and serializes the given data - :param data: The data to encrypt - :type data: Any - :param key_path: The path for the key - :type key_path: FilePath - :return: The encrypted data and the serialized data - :rtype: - """ - async with aiofiles.open(key_path, mode="r") as key_file: - public_key: str = await key_file.read() - encrypted_info: dict[str, bytes] = self.encrypt_data( - str(data), public_key - ) - serialized_info: str = self.serialize_encrypted_info(encrypted_info) - return encrypted_info["encrypted_data"], serialized_info - +@lru_cache def get_encryption_service() -> EncryptionService: """ Get the encryption service instance diff --git a/app/services/infrastructure/ip_blacklist.py b/app/services/infrastructure/ip_blacklist.py index 07f6ff4..f913c6b 100644 --- a/app/services/infrastructure/ip_blacklist.py +++ b/app/services/infrastructure/ip_blacklist.py @@ -29,7 +29,18 @@ def __init__( blacklist_expiration_seconds: PositiveInt, ): self._redis: Redis = redis # type: ignore - self._expiration_seconds: PositiveInt = blacklist_expiration_seconds + self.__expiration_seconds: PositiveInt = blacklist_expiration_seconds + + @staticmethod + def _get_redis_key(ip: IPvAnyAddress) -> str: + """ + Generate the Redis key for the given IP address. + :param ip: The IP address. + :type ip: IPvAnyAddress + :return: The Redis key. + :rtype: str + """ + return f"blacklist:{ip}" @handle_redis_exceptions async def is_ip_blacklisted(self, ip: IPvAnyAddress) -> bool: @@ -40,7 +51,7 @@ async def is_ip_blacklisted(self, ip: IPvAnyAddress) -> bool: :return: True if blacklisted, False otherwise. :rtype: bool """ - return bool(await self._redis.get(self.get_redis_key(ip))) + return bool(await self._redis.get(self._get_redis_key(ip))) @handle_redis_exceptions async def blacklist_ip(self, ip: IPvAnyAddress) -> None: @@ -52,22 +63,11 @@ async def blacklist_ip(self, ip: IPvAnyAddress) -> None: :rtype: NoneType """ await self._redis.setex( - self.get_redis_key(ip), - self._expiration_seconds, + self._get_redis_key(ip), + self.__expiration_seconds, f"Blacklisted at {datetime.now(UTC).isoformat()}", ) - @staticmethod - def get_redis_key(ip: IPvAnyAddress) -> str: - """ - Generate the Redis key for the given IP address. - :param ip: The IP address. - :type ip: IPvAnyAddress - :return: The Redis key. - :rtype: str - """ - return f"blacklist:{ip}" - def get_ip_blacklist_service( redis: Annotated[Redis, Depends(get_redis_dep)], # type: ignore diff --git a/app/services/infrastructure/rate_limiter.py b/app/services/infrastructure/rate_limiter.py index bc91111..6f6060b 100644 --- a/app/services/infrastructure/rate_limiter.py +++ b/app/services/infrastructure/rate_limiter.py @@ -25,11 +25,11 @@ def __init__( rate_limiter: RateLimiter, ): self._redis: Redis = redis # type: ignore - self._rate_limit_duration: PositiveInt = rate_limit_duration - self._max_requests: PositiveInt = max_requests + self.__rate_limit_duration: PositiveInt = rate_limit_duration + self.__max_requests: PositiveInt = max_requests self._rate_limiter: RateLimiter = rate_limiter - def get_rate_limit_key(self) -> str: + def _get_rate_limit_key(self) -> str: """ Returns the rate limit key :return: The key to store on Redis based on the model instance @@ -48,9 +48,9 @@ async def add_request(self) -> None: :return: None :rtype: NoneType """ - rate_limit_key: str = self.get_rate_limit_key() + rate_limit_key: str = self._get_rate_limit_key() min_timestamp: datetime = datetime.now() - timedelta( - seconds=self._rate_limit_duration + seconds=self.__rate_limit_duration ) now_timestamp: float = datetime.now().timestamp() await self._redis.zremrangebyscore( @@ -67,7 +67,7 @@ async def get_request_count(self) -> int: :return: The number of requests in the current window :rtype: int """ - rate_limit_key: str = self.get_rate_limit_key() + rate_limit_key: str = self._get_rate_limit_key() return await self._redis.zcard(rate_limit_key) async def get_remaining_requests(self) -> int: @@ -77,7 +77,7 @@ async def get_remaining_requests(self) -> int: :rtype: int """ request_count: int = await self.get_request_count() - return self._max_requests - request_count + return self.__max_requests - request_count @handle_redis_exceptions async def get_reset_time(self) -> datetime: @@ -86,7 +86,7 @@ async def get_reset_time(self) -> datetime: :return: The reset time available :rtype: datetime """ - rate_limit_key: str = self.get_rate_limit_key() + rate_limit_key: str = self._get_rate_limit_key() oldest_request: list[tuple[Any, float]] = await self._redis.zrange( rate_limit_key, 0, 0, withscores=True ) @@ -95,4 +95,4 @@ async def get_reset_time(self) -> datetime: if oldest_request else datetime.now() ) - return oldest_timestamp + timedelta(seconds=self._rate_limit_duration) + return oldest_timestamp + timedelta(seconds=self.__rate_limit_duration) diff --git a/app/services/infrastructure/token.py b/app/services/infrastructure/token.py index 2e07400..6934ccf 100644 --- a/app/services/infrastructure/token.py +++ b/app/services/infrastructure/token.py @@ -27,10 +27,10 @@ def __init__( auth_settings: AuthSettings, ): self._redis: Redis = redis # type: ignore - self._refresh_token_expire_minutes: PositiveInt = ( + self.__refresh_token_expire_minutes: PositiveInt = ( auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES ) - self._blacklist_expiration_seconds: PositiveInt = ( + self.__blacklist_expiration_seconds: PositiveInt = ( PositiveInt( PositiveInt(auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES) + 1 ) @@ -50,7 +50,7 @@ async def create_token(self, token: Token) -> bool: try: inserted: bool = await self._redis.setex( token.key, - self._refresh_token_expire_minutes, + self.__refresh_token_expire_minutes, token.user_info, ) except RedisError as r_exc: @@ -58,23 +58,6 @@ async def create_token(self, token: Token) -> bool: raise r_exc return inserted - @handle_redis_exceptions - @benchmark - async def get_token(self, key: str) -> str | None: - """ - Read token from the authentication database - :param key: The key to search for - :type key: str - :return: The refresh token - :rtype: str - """ - try: - value: str = str(await self._redis.get(key)) - except RedisError as r_exc: - logger.error("Error at getting token. %s", r_exc) - raise r_exc - return value - @handle_redis_exceptions @benchmark async def blacklist_token(self, token_key: str) -> bool: @@ -89,7 +72,7 @@ async def blacklist_token(self, token_key: str) -> bool: try: blacklisted: bool = await self._redis.setex( f"blacklist:{token_key}", - self._blacklist_expiration_seconds, + self.__blacklist_expiration_seconds, "true", ) except RedisError as r_exc: @@ -109,7 +92,7 @@ async def is_token_blacklisted(self, token_key: str) -> bool: """ try: blacklisted: str | None = await self._redis.get( - f"blacklist" f":{token_key}" + f"blacklist:{token_key}" ) except RedisError as r_exc: logger.error("Error at checking if token is blacklisted. %s", r_exc) diff --git a/app/services/infrastructure/user.py b/app/services/infrastructure/user.py index e86d1fa..d98a5ae 100644 --- a/app/services/infrastructure/user.py +++ b/app/services/infrastructure/user.py @@ -11,7 +11,6 @@ from redis.asyncio import Redis from app.api.deps import get_redis_dep -from app.config.config import auth_setting from app.crud.specification import ( EmailSpecification, IdSpecification, @@ -48,7 +47,6 @@ def __init__( ): self._user_repo: UserRepository = user_repo self._redis: Redis = redis # type: ignore - self._cache_seconds: PositiveInt = auth_setting.CACHE_SECONDS async def get_user_by_id(self, user_id: UUID4) -> UserResponse | None: """