Skip to content

Commit

Permalink
add redis whitelist for jwt
Browse files Browse the repository at this point in the history
  • Loading branch information
huangcheng committed Dec 6, 2023
1 parent d3a816d commit 0880062
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 52 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ derive_more = "0.99.17"
fern = "0.6.2"
jsonwebtoken = { version = "9.1.0", default-features = false }
log = "0.4.20"
redis = { version = "0.24.0", features = ["tokio-comp"] }
regex = "1.10.2"
rocket = { version = "0.5.0", features = ["json", "uuid"] }
rocket_db_pools = { version = "0.1.0", features = ["sqlx_mysql"] }
rocket_db_pools = { version = "0.1.0", features = ["sqlx_mysql", "deadpool_redis"] }
serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"
sha2 = "0.10.8"
Expand Down
3 changes: 3 additions & 0 deletions Rocket.toml.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ upload_url = "/upload"
[default.databases.startpage]
url = "mysql://startpage:startpage@127.0.0.1/startpage"

[default.databases.cache]
url = "redis://localhost:6379/0"

[default.jwt]
# you can generate a secret with `openssl rand -hex 32`
secret = ""
Expand Down
9 changes: 5 additions & 4 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use startpage::config::Config;
use startpage::routes::upload::upload;
use startpage::routes::{auth, category, site, user};
use startpage::state::AppState;
use startpage::utils::calculate_expires;
use startpage::Db;
use startpage::utils::parse_duration;
use startpage::{MySQLDb, RedisDb};

fn drop_rocket(meta: &log::Metadata) -> bool {
let name = meta.target();
Expand Down Expand Up @@ -54,7 +54,7 @@ async fn main() -> Result<(), rocket::Error> {
.expect("Failed to extract app config");

let jwt_expiration =
calculate_expires(&config.jwt.expires_in).expect("Failed to parse duration");
parse_duration(&config.jwt.expires_in).expect("Failed to parse jwt expiration");

let upload_url = figment
.extract::<Config>()
Expand All @@ -70,7 +70,8 @@ async fn main() -> Result<(), rocket::Error> {

let _rok = rocket::custom(figment)
.manage(state)
.attach(Db::init())
.attach(MySQLDb::init())
.attach(RedisDb::init())
.mount(
"/api/user",
routes![user::me, user::update, user::update_password],
Expand Down
27 changes: 23 additions & 4 deletions src/handlers/auth.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
use bcrypt::verify;
use jsonwebtoken::{encode, EncodingKey, Header};
use log::error;
use rocket::futures::TryFutureExt;
use rocket_db_pools::deadpool_redis::redis::AsyncCommands;
use rocket_db_pools::Connection;
use sqlx::query_as;

use crate::config::Config;
use crate::errors::ServiceError;
use crate::request;
use crate::state::AppState;
use crate::utils::calculate_expires;
use crate::Claims;
use crate::{models, Db};
use crate::{models, MySQLDb, RedisDb};

pub async fn login(
user: &request::auth::User<'_>,
state: &AppState,
config: &Config,
db: &mut Connection<Db>,
db: &mut Connection<MySQLDb>,
cache: &mut Connection<RedisDb>,
) -> Result<String, ServiceError> {
let record = query_as::<_, models::user::User>(
r#"SELECT username, nickname, password, avatar, email FROM user WHERE username = ?"#,
Expand All @@ -37,9 +41,9 @@ pub async fn login(

if valid {
let claims = Claims {
sub: record.username,
sub: record.username.clone(),
company: String::from("StartPage"),
exp: state.jwt_expiration as usize,
exp: calculate_expires(&config.jwt.expires_in)? as usize,
};

let token = encode(
Expand All @@ -53,6 +57,21 @@ pub async fn login(
ServiceError::InternalServerError
})?;

let username = record.username.clone();

cache
.pset_ex(
&username,
token.clone(),
state.jwt_expiration.num_milliseconds() as usize,
)
.map_err(|e| {
error!("Failed to set token: {}", e);

ServiceError::InternalServerError
})
.await?;

Ok(token)
} else {
Err(ServiceError::BadRequest(String::from(
Expand Down
12 changes: 6 additions & 6 deletions src/handlers/category.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use crate::models::category::Category;
use crate::request::category::{CreateCategory, UpdateCategory};
use crate::response;
use crate::response::WithTotal;
use crate::Db;
use crate::MySQLDb;

pub async fn get_categories(
page: i64,
size: i64,
search: Option<&str>,
upload_url: &str,
db: &mut Connection<Db>,
db: &mut Connection<MySQLDb>,
) -> Result<WithTotal<response::category::Category>, ServiceError> {
let total = match search {
Some(search) => query(
Expand Down Expand Up @@ -76,7 +76,7 @@ pub async fn get_categories(
pub async fn update_category<'r>(
id: &'r str,
category: &'r UpdateCategory<'r>,
db: &mut Connection<Db>,
db: &mut Connection<MySQLDb>,
) -> Result<(), ServiceError> {
let record = query_as::<_, Category>(
r#"SELECT id, name, description, icon, created_at, updated_at FROM category WHERE id = ?"#,
Expand Down Expand Up @@ -135,7 +135,7 @@ pub async fn update_category<'r>(

pub async fn add_category(
category: &CreateCategory<'_>,
db: &mut Connection<Db>,
db: &mut Connection<MySQLDb>,
) -> Result<(), ServiceError> {
let id = query(r#"SELECT id FROM category WHERE name = ?"#)
.bind(category.name)
Expand All @@ -159,7 +159,7 @@ pub async fn add_category(
Ok(())
}

pub async fn delete_category(id: &str, db: &mut Connection<Db>) -> Result<(), ServiceError> {
pub async fn delete_category(id: &str, db: &mut Connection<MySQLDb>) -> Result<(), ServiceError> {
let sites_count =
query(r#"SELECT COUNT(site_id) AS count FROM category_site WHERE category_id = ?"#)
.bind(id)
Expand All @@ -183,7 +183,7 @@ pub async fn get_sites(
category_id: &str,
search: Option<&str>,
upload_url: &str,
db: &mut Connection<Db>,
db: &mut Connection<MySQLDb>,
) -> Result<Vec<response::site::Site>, ServiceError> {
let sites = match search {
Some(search) => query_as::<_, response::site::Site>(
Expand Down
13 changes: 8 additions & 5 deletions src/handlers/site.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ use crate::models::site::Site;
use crate::request::site::{CreateSite, UpdateSite};
use crate::response::site::SiteWithCategory;
use crate::response::WithTotal;
use crate::Db;
use crate::MySQLDb;

pub async fn get_sites(
page: i64,
size: i64,
search: Option<&str>,
upload_url: &str,
db: &mut Connection<Db>,
db: &mut Connection<MySQLDb>,
) -> Result<WithTotal<SiteWithCategory>, ServiceError> {
let count = match search {
Some(search) => {
Expand Down Expand Up @@ -82,7 +82,10 @@ pub async fn get_sites(
})
}

pub async fn add_site(site: &CreateSite<'_>, db: &mut Connection<Db>) -> Result<(), ServiceError> {
pub async fn add_site(
site: &CreateSite<'_>,
db: &mut Connection<MySQLDb>,
) -> Result<(), ServiceError> {
query_as::<_, Category>(
r#"SELECT id, name, description, icon, created_at, updated_at FROM category WHERE id = ?"#,
)
Expand Down Expand Up @@ -115,7 +118,7 @@ pub async fn add_site(site: &CreateSite<'_>, db: &mut Connection<Db>) -> Result<
pub async fn update_site(
site_id: &str,
site: &UpdateSite<'_>,
db: &mut Connection<Db>,
db: &mut Connection<MySQLDb>,
) -> Result<(), ServiceError> {
let record = query_as::<_, Site>(
r#"SELECT id, name, url, description, icon, created_at, updated_at FROM site WHERE id = ?"#,
Expand Down Expand Up @@ -201,7 +204,7 @@ pub async fn update_site(
Ok(())
}

pub async fn delete_site(id: &str, db: &mut Connection<Db>) -> Result<(), ServiceError> {
pub async fn delete_site(id: &str, db: &mut Connection<MySQLDb>) -> Result<(), ServiceError> {
query(r#"DELETE FROM category_site WHERE site_id = ?"#)
.bind(id)
.execute(&mut ***db)
Expand Down
8 changes: 4 additions & 4 deletions src/handlers/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ use sqlx::{query, query_as};
use crate::errors::ServiceError;
use crate::request::user::{UpdatePassword, UpdateUser};
use crate::response;
use crate::{models, Db};
use crate::{models, MySQLDb};

pub async fn get_user(
username: &str,
upload_url: &str,
db: &mut Connection<Db>,
db: &mut Connection<MySQLDb>,
) -> Result<response::user::User, ServiceError> {
let user = query_as::<_, models::user::User>("SELECT * FROM user WHERE username = ?")
.bind(username)
Expand Down Expand Up @@ -39,7 +39,7 @@ pub async fn get_user(
pub async fn update_user(
name: &'_ str,
user: &UpdateUser<'_>,
db: &mut Connection<Db>,
db: &mut Connection<MySQLDb>,
) -> Result<(), ServiceError> {
let record = query_as::<_, models::user::User>(
"SELECT username, password, email, avatar, nickname FROM user WHERE username = ?",
Expand Down Expand Up @@ -83,7 +83,7 @@ pub async fn update_user(
pub async fn update_user_password(
name: &'_ str,
user: &UpdatePassword<'_>,
db: &mut Connection<Db>,
db: &mut Connection<MySQLDb>,
) -> Result<(), ServiceError> {
let record = query_as::<_, models::user::User>(
"SELECT username, password, email, avatar, nickname FROM user WHERE username = ?",
Expand Down
7 changes: 6 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use rocket_db_pools::deadpool_redis;
use rocket_db_pools::Database;
use serde::{Deserialize, Serialize};
use sqlx::MySqlPool;
Expand All @@ -11,7 +12,11 @@ struct Claims {

#[derive(Database)]
#[database("startpage")]
pub struct Db(MySqlPool);
pub struct MySQLDb(MySqlPool);

#[derive(Database)]
#[database("cache")]
pub struct RedisDb(deadpool_redis::Pool);

pub mod errors;
pub mod handlers;
Expand Down
31 changes: 26 additions & 5 deletions src/middlewares.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use rocket::http::Status;
use rocket::request::{FromRequest, Outcome};
use rocket_db_pools::deadpool_redis::redis::AsyncCommands;

use crate::config::Config;
use crate::Claims;
use crate::{Claims, RedisDb};

pub struct JwtMiddleware {
pub username: String,
Expand All @@ -12,6 +13,7 @@ pub struct JwtMiddleware {
#[derive(Debug)]
pub enum JwtError {
ConfigError,
CacheError,
MissingToken,
InvalidToken,
ExpiredToken,
Expand All @@ -37,7 +39,7 @@ impl<'r> FromRequest<'r> for JwtMiddleware {
None => return Outcome::Error((Status::Unauthorized, JwtError::MissingToken)),
};

let token = match decode::<Claims>(
let token_data = match decode::<Claims>(
token,
&DecodingKey::from_secret(config.jwt.secret.as_bytes()),
&Validation::new(Algorithm::HS256),
Expand All @@ -46,8 +48,27 @@ impl<'r> FromRequest<'r> for JwtMiddleware {
Err(_) => return Outcome::Error((Status::Unauthorized, JwtError::InvalidToken)),
};

Outcome::Success(JwtMiddleware {
username: token.claims.sub,
})
let username = token_data.claims.sub.clone();

let is_in_white_list: &Option<bool> = request
.local_cache_async(async {
let redis = request.guard::<&RedisDb>().await.succeeded()?;
let mut connection = redis.get().await.ok()?;

let value = connection.get::<_, String>(username.clone()).await.ok()?;

Some(value == token)
})
.await;

if is_in_white_list.is_none() {
return Outcome::Error((Status::Unauthorized, JwtError::CacheError));
}

if *is_in_white_list == Some(false) {
return Outcome::Error((Status::Unauthorized, JwtError::InvalidToken));
}

Outcome::Success(JwtMiddleware { username })
}
}
20 changes: 14 additions & 6 deletions src/routes/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@ use rocket::http::Status;
use rocket::post;
use rocket::serde::json::Json;
use rocket::State;
use rocket_db_pools::deadpool_redis::redis::AsyncCommands;
use rocket_db_pools::Connection;

use crate::config::Config;
use crate::middlewares::JwtMiddleware;
use crate::request;
use crate::response;
use crate::response::auth::Logout;
use crate::state::AppState;
use crate::{handlers, Db};
use crate::{handlers, request, response, MySQLDb, RedisDb};

#[post("/login", format = "json", data = "<user>")]
pub async fn login(
user: Json<request::auth::User<'_>>,
state: &State<AppState>,
config: &State<Config>,
mut db: Connection<Db>,
mut db: Connection<MySQLDb>,
mut cache: Connection<RedisDb>,
) -> Result<response::auth::JwtToken, Status> {
let token = handlers::auth::login(user.deref(), state, config, &mut db)
let token = handlers::auth::login(user.deref(), state, config, &mut db, &mut cache)
.await
.map_err(|e| {
error!("{}", e);
Expand All @@ -34,6 +34,14 @@ pub async fn login(
}

#[post("/logout")]
pub async fn logout(_jwt: JwtMiddleware) -> Result<Logout, Status> {
pub async fn logout(_jwt: JwtMiddleware, mut cache: Connection<RedisDb>) -> Result<Logout, Status> {
let username = _jwt.username.clone();

cache.del(username).await.map_err(|e| {
error!("{}", e);

Status::InternalServerError
})?;

Ok(Logout)
}
Loading

0 comments on commit 0880062

Please sign in to comment.