Skip to content

Commit

Permalink
migrate database to rocket db pools
Browse files Browse the repository at this point in the history
  • Loading branch information
huangcheng committed Nov 28, 2023
1 parent cc3c45b commit fee74b9
Show file tree
Hide file tree
Showing 17 changed files with 153 additions and 96 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
DATABASE_URL=mysql://startpage:startpage@127.0.0.1/startpage

# you can generate a secret with `openssl rand -hex 32`
JWT_SECRET=
JWT_EXPIRES_IN=1w
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Build
run: cargo build --verbose
- name: Run tests
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jsonwebtoken = { version = "9.1.0", default-features = false }
log = "0.4.20"
regex = "1.10.2"
rocket = { version = "0.5.0", features = ["json", "uuid"] }
rocket_db_pools = { version = "0.1.0", features = ["sqlx_mysql"] }
serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"
sqlx = { version = "0.7", features = [ "runtime-tokio", "mysql", "migrate", "uuid", "chrono" ] }
Expand Down
2 changes: 2 additions & 0 deletions Rocket.toml.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[default.databases.startpage]
url = "mysql://startpage:startpage@127.0.0.1/startpage"
12 changes: 5 additions & 7 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use dotenvy::dotenv;
use rocket::{self, routes};
use rocket_db_pools::Database;
use sqlx::MySqlPool;

use startpage::routes::upload::upload;
use startpage::routes::{auth, category, user};
use startpage::state::AppState;
use startpage::utils::calculate_expires;
use startpage::Db;

fn drop_rocket(meta: &log::Metadata) -> bool {
let name = meta.target();
Expand Down Expand Up @@ -40,26 +43,20 @@ async fn main() -> Result<(), rocket::Error> {

dotenv().expect("Failed to read .env file");

let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL");

let jwt_secret = std::env::var("JWT_SECRET").expect("JWT_SECRET");

let jwt_expire_in = std::env::var("JWT_EXPIRES_IN").expect("JWT_EXPIRES_IN");

let jwt_expiration = calculate_expires(&jwt_expire_in).expect("Failed to calculate expires");

let pool = MySqlPool::connect(&database_url)
.await
.expect("Failed to connect to database");

let state = AppState {
pool,
jwt_secret,
jwt_expiration,
};

let _rok = rocket::build()
.manage(state)
.attach(Db::init())
.mount(
"/api/user",
routes![user::me, user::update, user::update_password],
Expand All @@ -78,6 +75,7 @@ async fn main() -> Result<(), rocket::Error> {
category::delete_site,
],
)
.mount("/api/upload", routes![upload])
.launch()
.await?;

Expand Down
8 changes: 5 additions & 3 deletions src/handlers/auth.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
use bcrypt::verify;
use jsonwebtoken::{encode, EncodingKey, Header};
use log::error;
use rocket_db_pools::Connection;
use sqlx::query_as;

use crate::errors::ServiceError;
use crate::models;
use crate::request;
use crate::state::AppState;
use crate::Claims;
use crate::{models, Db};

pub async fn login(
user: &request::auth::User<'_>,
state: &AppState,
db: &mut Connection<Db>,
) -> Result<String, ServiceError> {
let record = query_as::<_, models::user::User>(
r#"SELECT username, nickname, password, avatar, email FROM user WHERE username = ?"#,
)
.bind(user.username)
.fetch_one(&state.pool)
.fetch_one(&mut ***db)
.await
.map_err(|e| {
error!("Failed to query user: {}", e);
Expand All @@ -28,7 +30,7 @@ pub async fn login(
let valid = verify(user.password, &record.password).map_err(|e| {
error!("Failed to verify password: {}", e);

ServiceError::InternalServerError
ServiceError::Unauthorized
})?;

if valid {
Expand Down
86 changes: 47 additions & 39 deletions src/handlers/category.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use rocket::State;
use rocket_db_pools::Connection;
use sqlx::{query, query_as, Row};

use crate::errors::ServiceError;
Expand All @@ -7,33 +7,46 @@ use crate::models::site::Site;
use crate::request::category::{CreateCategory, UpdateCategory};
use crate::request::site::{CreateSite, UpdateSite};
use crate::response;
use crate::state::AppState;
use crate::response::WithTotal;
use crate::Db;

pub async fn get_all_categories(
state: &State<AppState>,
) -> Result<Vec<response::category::Category>, ServiceError> {
page: i64,
size: i64,
db: &mut Connection<Db>,
) -> Result<WithTotal<response::category::Category>, ServiceError> {
let total = query(r#"SELECT COUNT(id) AS count FROM category"#)
.fetch_one(&mut ***db)
.await?
.get::<i64, &str>("count");

let categories = sqlx::query_as::<_, Category>(
r#"SELECT id, name, description, created_at, updated_at FROM category"#,
r#"SELECT id, name, description, created_at, updated_at FROM category LIMIT ? OFFSET ?"#,
)
.fetch_all(&state.pool)
.bind(size)
.bind(page * size)
.fetch_all(&mut ***db)
.await?;

Ok(categories
.into_iter()
.map(|category| category.into())
.collect())
Ok(WithTotal {
total,
data: categories
.into_iter()
.map(|category| category.into())
.collect(),
})
}

pub async fn update_category<'r>(
id: &'r str,
category: &'r UpdateCategory<'r>,
state: &State<AppState>,
db: &mut Connection<Db>,
) -> Result<Category, ServiceError> {
let record = query_as::<_, Category>(
r#"SELECT id, name, description, created_at, updated_at FROM category WHERE id = ?"#,
)
.bind(id)
.fetch_one(&state.pool)
.fetch_one(&mut ***db)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => ServiceError::BadRequest(String::from("Category not found")),
Expand Down Expand Up @@ -62,30 +75,30 @@ pub async fn update_category<'r>(
.bind(&record.name)
.bind(&record.description)
.bind(record.id)
.execute(&state.pool)
.execute(&mut ***db)
.await?;

Ok(record)
}

pub async fn add_category(
category: &CreateCategory<'_>,
state: &State<AppState>,
db: &mut Connection<Db>,
) -> Result<(), ServiceError> {
query(r#"INSERT INTO category (name, description) VALUES (?, ?)"#)
.bind(category.name)
.bind(category.description)
.execute(&state.pool)
.execute(&mut ***db)
.await?;

Ok(())
}

pub async fn delete_category(id: &str, state: &State<AppState>) -> Result<(), ServiceError> {
pub async fn delete_category(id: &str, db: &mut Connection<Db>) -> Result<(), ServiceError> {
let sites_count =
query(r#"SELECT COUNT(site_id) AS count FROM category_site WHERE category_id = ?"#)
.bind(id)
.fetch_one(&state.pool)
.fetch_one(&mut ***db)
.await?
.get::<i64, &str>("count");

Expand All @@ -95,7 +108,7 @@ pub async fn delete_category(id: &str, state: &State<AppState>) -> Result<(), Se

query(r#"DELETE FROM category WHERE id = ?"#)
.bind(id)
.execute(&state.pool)
.execute(&mut ***db)
.await?;

Ok(())
Expand All @@ -104,13 +117,13 @@ pub async fn delete_category(id: &str, state: &State<AppState>) -> Result<(), Se
pub async fn add_site(
category_id: &str,
site: &CreateSite<'_>,
state: &State<AppState>,
db: &mut Connection<Db>,
) -> Result<(), ServiceError> {
query_as::<_, Category>(
r#"SELECT id, name, description, created_at, updated_at FROM category WHERE id = ?"#,
)
.bind(category_id)
.fetch_one(&state.pool)
.fetch_one(&mut ***db)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => ServiceError::BadRequest(String::from("Category not found")),
Expand All @@ -122,33 +135,28 @@ pub async fn add_site(
.bind(site.url)
.bind(site.description)
.bind(site.icon)
.execute(&state.pool)
.execute(&mut ***db)
.await?
.last_insert_id();

query(r#"INSERT INTO category_site (category_id, site_id) VALUES (?, ?)"#)
.bind(category_id)
.bind(id)
.execute(&state.pool)
.execute(&mut ***db)
.await?;

Ok(())
}

pub async fn get_sites(
category_id: &str,
state: &State<AppState>,
db: &mut Connection<Db>,
) -> Result<Vec<response::site::Site>, ServiceError> {
let sites = query_as::<_, response::site::Site>(
r#"SELECT site.id, site.name, site.url, site.description, site.icon, site.created_at, site.updated_at, category.name AS category
FROM site
INNER JOIN category_site ON site.id = category_site.site_id
INNER JOIN category ON category.id = category_site.category_id
WHERE site.id IN (SELECT site_id FROM category_site WHERE category_id = ?)
"#,
r#"SELECT id, name, url, description, icon FROM site WHERE id IN (SELECT site_id FROM category_site WHERE category_id = ?)"#,
)
.bind(category_id)
.fetch_all(&state.pool)
.fetch_all(&mut ***db)
.await?;

Ok(sites)
Expand All @@ -158,13 +166,13 @@ pub async fn modify_site(
category_id: &str,
site_id: &str,
site: &UpdateSite<'_>,
state: &State<AppState>,
db: &mut Connection<Db>,
) -> Result<(), ServiceError> {
let record = query_as::<_, Site>(
r#"SELECT id, name, url, description, icon, created_at, updated_at FROM site WHERE id = ?"#,
)
.bind(site_id)
.fetch_one(&state.pool)
.fetch_one(&mut ***db)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => ServiceError::BadRequest(String::from("Site not found")),
Expand Down Expand Up @@ -207,7 +215,7 @@ pub async fn modify_site(
.bind(&record.description)
.bind(&record.icon)
.bind(record.id)
.execute(&state.pool)
.execute(&mut ***db)
.await?;

let category_id = match site.category_id {
Expand All @@ -217,7 +225,7 @@ pub async fn modify_site(

query(r#"SELECT id FROM category WHERE id = ?"#)
.bind(&category_id)
.fetch_one(&state.pool)
.fetch_one(&mut ***db)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => {
Expand All @@ -229,7 +237,7 @@ pub async fn modify_site(
query(r#"UPDATE category_site SET category_id = ? WHERE site_id = ?"#)
.bind(&category_id)
.bind(record.id)
.execute(&state.pool)
.execute(&mut ***db)
.await?;

Ok(())
Expand All @@ -238,14 +246,14 @@ pub async fn modify_site(
pub async fn delete_site(
category_id: &str,
site_id: &str,
state: &State<AppState>,
db: &mut Connection<Db>,
) -> Result<(), ServiceError> {
query(
r#"SELECT category_id, site_id FROM category_site WHERE category_id = ? AND site_id = ?"#,
)
.bind(category_id)
.bind(site_id)
.fetch_one(&state.pool)
.fetch_one(&mut ***db)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => ServiceError::BadRequest(String::from("Site not found")),
Expand All @@ -255,12 +263,12 @@ pub async fn delete_site(
query(r#"DELETE FROM category_site WHERE category_id = ? AND site_id = ?"#)
.bind(category_id)
.bind(site_id)
.execute(&state.pool)
.execute(&mut ***db)
.await?;

query(r#"DELETE FROM site WHERE id = ?"#)
.bind(site_id)
.execute(&state.pool)
.execute(&mut ***db)
.await?;

Ok(())
Expand Down
Loading

0 comments on commit fee74b9

Please sign in to comment.