Skip to content

Commit

Permalink
chore: fmt and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
nikola-bozin-org committed May 5, 2024
1 parent 3de18c9 commit 910fefa
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 17 deletions.
2 changes: 0 additions & 2 deletions server/src/middlewares/rate_limiter/rate_limit_info.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::time::Duration;

use redis_macros::{FromRedisValue, ToRedisArgs};
use serde::{Deserialize, Serialize};

Expand Down
20 changes: 12 additions & 8 deletions server/src/middlewares/rate_limiter/rate_limit_mw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@ use axum::{
Extension,
};
use chrono::Local;
use tokio::sync::Mutex;

use crate::{constants::REQUESTS_AMOUNT_TIME_FRAME, state::AppState, RateLimiterRedisInteractor};

pub async fn rate_limit(
Extension(state): Extension<Arc<AppState>>,
ConnectInfo(ip_addr): ConnectInfo<SocketAddr>,
mut req: Request,
req: Request,
next: Next,
) -> Response {
println!("Rate limiter hit with ip: {}", ip_addr);
Expand All @@ -24,7 +23,7 @@ pub async fn rate_limit(

let requests_amount = state.rate_limiter_config.requests_amount;
let next_reset = Local::now() + REQUESTS_AMOUNT_TIME_FRAME;

if ip_data.is_none() {
state
.redis_rate_limiter_db
Expand All @@ -39,13 +38,18 @@ pub async fn rate_limit(
} else {
let ip_data = ip_data.unwrap();
if ip_data.limit == 0 {
if ip_data.next_reset<Local::now().timestamp(){
if ip_data.next_reset < Local::now().timestamp() {
state
.redis_rate_limiter_db
.set_data(ip_addr,&crate::RateLimitInfo {
limit: requests_amount,next_reset:next_reset.timestamp()
}).await;
}
.set_data(
ip_addr,
&crate::RateLimitInfo {
limit: requests_amount,
next_reset: next_reset.timestamp(),
},
)
.await;
}
drop(state); // drop the lock so the state can be used in next middleware.
return (StatusCode::TOO_MANY_REQUESTS, "Too many requests!").into_response();
}
Expand Down
26 changes: 26 additions & 0 deletions server/src/middlewares/rate_limiter/rate_limiter_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,29 @@ impl RateLimiter for RateLimiterConfig {
self.time_frame = limit;
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_new_rate_limiter_config() {
let limiter_config = RateLimiterConfig::new(10, Duration::from_secs(60));
assert_eq!(limiter_config.requests_amount, 10);
assert_eq!(limiter_config.time_frame, Duration::from_secs(60));
}

#[test]
fn test_set_requests_amount() {
let mut limiter_config = RateLimiterConfig::new(10, Duration::from_secs(60));
limiter_config.set_requests_amount(5);
assert_eq!(limiter_config.requests_amount, 5);
}

#[test]
fn test_set_limit() {
let mut limiter_config = RateLimiterConfig::new(10, Duration::from_secs(60));
limiter_config.set_limit(Duration::from_secs(30));
assert_eq!(limiter_config.time_frame, Duration::from_secs(30));
}
}
14 changes: 7 additions & 7 deletions server/src/middlewares/rate_limiter/redis_interactor.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use std::net::SocketAddr;

use redis::{aio::MultiplexedConnection, AsyncCommands, Client, Commands};
use redis::{aio::MultiplexedConnection, AsyncCommands, Client};

use super::{RateLimitInfo, Result};

pub trait RateLimiterRedisInteractor {
async fn new(redis_url: String) -> Result<Self>
where
Self: Sized;
async fn get_data(& self, ip_addr: SocketAddr) -> Option<RateLimitInfo>;
async fn set_data(& self, ip_addr: SocketAddr, rate_limit_info: &RateLimitInfo);
async fn get_data(&self, ip_addr: SocketAddr) -> Option<RateLimitInfo>;
async fn set_data(&self, ip_addr: SocketAddr, rate_limit_info: &RateLimitInfo);
}

#[derive(Clone, Debug)]
Expand All @@ -25,7 +25,7 @@ impl RateLimiterRedisInteractor for RedisRateLimiterDb {
Ok(Self { client, connection })
}

async fn get_data(& self, ip_addr: SocketAddr) -> Option<RateLimitInfo> {
async fn get_data(&self, ip_addr: SocketAddr) -> Option<RateLimitInfo> {
let key = ip_addr.to_string();
let mut connection = self.connection.clone();
connection
Expand All @@ -34,10 +34,10 @@ impl RateLimiterRedisInteractor for RedisRateLimiterDb {
.unwrap()
}

async fn set_data(& self, ip_addr: SocketAddr, rate_limit_info: &RateLimitInfo) {
async fn set_data(&self, ip_addr: SocketAddr, rate_limit_info: &RateLimitInfo) {
let key = ip_addr.to_string();
let mut connection = self.connection.clone();

connection
.set::<String, &RateLimitInfo, ()>(key, rate_limit_info)
.await
Expand Down Expand Up @@ -66,7 +66,7 @@ mod tests {

#[tokio::test]
async fn test_set_and_get_data() {
let mut db = setup_test_db().await;
let db = setup_test_db().await;
let test_ip = SocketAddr::from_str("127.0.0.1:8080").unwrap();
let rate_limit_info = RateLimitInfo {
limit: REQUESTS_AMOUNT_LIMIT,
Expand Down

0 comments on commit 910fefa

Please sign in to comment.