Skip to content

Commit

Permalink
feat: first simple rate limiter middleware solution
Browse files Browse the repository at this point in the history
  • Loading branch information
nikola-bozin-org committed May 4, 2024
1 parent a30e8e3 commit 6f91931
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 12 deletions.
17 changes: 13 additions & 4 deletions server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ use std::{env, net::SocketAddr, sync::Arc};

use axum::{middleware, Extension, Router};
use middlewares::response_mapper;
use tokio::net::TcpListener;
use tokio::{net::TcpListener, sync::Mutex};

use crate::middlewares::*;

#[tokio::main]
async fn main() {
Expand All @@ -22,17 +23,25 @@ async fn main() {
let database_url = env::var("DATABASE_URL")
.unwrap_or_else(|_| panic!("Missing required environment variable: {}", "DATABASE_URL"));

let redis_url = env::var("REDIS_URL")
.unwrap_or_else(|_| panic!("Missing required environment variable: {}", "DATABSE_URL"));

let tfd = tracing_fast_dev::tfd();

tfd.info("wisdomia", "INITIALIZATION");

let redis_rate_limiter_db = RedisRateLimiterDb::new(redis_url).await.unwrap();

let db = connect(database_url.as_str()).await.unwrap();

sqlx::migrate!("../migrations").run(&db).await.unwrap();

let state = AppState { db: db.clone() };
let state = AppState {
db: db.clone(),
redis_rate_limiter_db,
};

let shared_state = Arc::new(state);
let shared_state = Arc::new(Mutex::new(state));

let listener = TcpListener::bind(format!("{}:{}", constants::HOST, constants::PORT))
.await
Expand All @@ -42,8 +51,8 @@ async fn main() {

let router = Router::new()
.nest("/api/v1", routes::routes())
.layer(Extension(shared_state))
.layer(middleware::from_fn(middlewares::rate_limit))
.layer(Extension(shared_state))
.layer(middleware::map_response(response_mapper));

axum::serve(
Expand Down
1 change: 1 addition & 0 deletions server/src/middlewares/rate_limiter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ pub use error::*;
pub use rate_limit_info::*;
pub use rate_limit_mw::*;
pub use rate_limiter_config::*;
pub use redis_interactor::*;
36 changes: 34 additions & 2 deletions server/src/middlewares/rate_limiter/rate_limit_mw.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,48 @@
use std::net::SocketAddr;
use std::{net::SocketAddr, sync::Arc};

use axum::{
extract::{ConnectInfo, Request},
http::StatusCode,
middleware::Next,
response::Response,
response::{IntoResponse, Response},
Extension,
};
use tokio::sync::Mutex;

use crate::{state::AppState, RateLimiterRedisInteractor};

pub async fn rate_limit(
Extension(state): Extension<Arc<Mutex<AppState>>>,
ConnectInfo(ip_addr): ConnectInfo<SocketAddr>,
mut req: Request,
next: Next,
) -> Response {
println!("Rate limiter hit with ip: {}", ip_addr);
let mut state = state.lock().await;
let ip_data = state.redis_rate_limiter_db.get_data(ip_addr).await;
dbg!(&ip_data);
if ip_data.is_none() {
state
.redis_rate_limiter_db
.set_data(ip_addr, &crate::RateLimitInfo { limit: 10 })
.await;
} else {
let ip_data = ip_data.unwrap();
if ip_data.limit == 0 {
drop(state); // drop the lock so the state can be used in next middleware.
return (StatusCode::TOO_MANY_REQUESTS, "Too many requests!").into_response();
} else {
state
.redis_rate_limiter_db
.set_data(
ip_addr,
&crate::RateLimitInfo {
limit: ip_data.limit - 1,
},
)
.await;
}
}
drop(state); // drop the lock so the state can be used in next middleware.
next.run(req).await
}
2 changes: 1 addition & 1 deletion server/src/middlewares/rate_limiter/redis_interactor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub trait RateLimiterRedisInteractor {
async fn set_data(&mut self, ip_addr: SocketAddr, rate_limit_info: &RateLimitInfo);
}

#[derive(Clone, Debug)]
pub struct RedisRateLimiterDb {
pub client: Client,
pub connection: MultiplexedConnection,
Expand Down Expand Up @@ -56,7 +57,6 @@ mod tests {
#[tokio::test]
async fn test_new() {
let _db = setup_test_db().await;
// If no panic and no error, assume successful connection and client creation
}

#[tokio::test]
Expand Down
5 changes: 3 additions & 2 deletions server/src/routes/wisdoms.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use axum::{http::StatusCode, response::IntoResponse, routing::get, Extension, Json, Router};
use tokio::sync::Mutex;

use crate::{db::_get_wisdoms, helpers::default_handle_error, state::AppState};

Expand All @@ -14,9 +15,9 @@ fn _routes() -> Router {
Router::new().route("/", get(get_wisdoms))
}

async fn get_wisdoms(Extension(state): Extension<Arc<AppState>>) -> impl IntoResponse {
async fn get_wisdoms(Extension(state): Extension<Arc<Mutex<AppState>>>) -> impl IntoResponse {
tracing_fast_dev::tfd().info("GET_WISDOM", "FUNCTION");
match _get_wisdoms(&state.db).await {
match _get_wisdoms(&state.lock().await.db).await {
Ok(wisdoms) => (StatusCode::OK, Json(json!({ "wisdoms": wisdoms }))).into_response(),
Err(e) => default_handle_error(e),
}
Expand Down
3 changes: 2 additions & 1 deletion server/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::db::Database;
use crate::{db::Database, RedisRateLimiterDb};

#[derive(Clone)]
pub struct AppState {
pub db: Database,
pub redis_rate_limiter_db: RedisRateLimiterDb,
}
5 changes: 3 additions & 2 deletions worker/src/wisdoms_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ async fn send_tg_message(bot: Bot, message: &str, chat_id: i64) -> Result<Messag
.await
}

//TODO: This code should BE fixed to not use so many match-es everywhere.
// TODO: This code should BE fixed to not use so many match-es everywhere.
// It does not look good...
pub async fn worker_thread(pool: PgPool) {
let _ = dotenv().ok();
let mut prev_base64_string = String::new();
Expand All @@ -37,7 +38,7 @@ pub async fn worker_thread(pool: PgPool) {
loop {
tokio::time::sleep(Duration::from_secs(5)).await;

//Its one dot (.) because we are running it from the root.
//Its one dot (.) because we are running the worker it from the root.
let base64_string = match fs::read_to_string("./encoded-wisdoms.b64") {
Ok(content) => content.replace(['\n', '\r'], ""),
Err(e) => {
Expand Down

0 comments on commit 6f91931

Please sign in to comment.