diff --git a/.sqlx/query-33b11051e779866db9aeb86d28a59db07a94323ffdc59a5a2c1da694ebe9a65f.json b/.sqlx/query-33b11051e779866db9aeb86d28a59db07a94323ffdc59a5a2c1da694ebe9a65f.json new file mode 100644 index 0000000..9aef6c4 --- /dev/null +++ b/.sqlx/query-33b11051e779866db9aeb86d28a59db07a94323ffdc59a5a2c1da694ebe9a65f.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT username\n FROM users\n WHERE user_id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "username", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false + ] + }, + "hash": "33b11051e779866db9aeb86d28a59db07a94323ffdc59a5a2c1da694ebe9a65f" +} diff --git a/Cargo.lock b/Cargo.lock index bf92e01..21ade4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3024,6 +3024,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" dependencies = [ "getrandom", + "serde", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 39fbd61..39409f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ tracing-bunyan-formatter = "0.3.9" tracing-log = "0.2.0" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } unicode-segmentation = "1.11.0" -uuid = { version = "1.7.0", features = ["v4"] } +uuid = { version = "1.7.0", features = ["serde", "v4"] } validator = "0.16.1" [dev-dependencies] diff --git a/src/lib.rs b/src/lib.rs index 40f47c9..5dce688 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,5 +5,6 @@ pub mod domain; pub mod email_client; pub mod request_id; pub mod routes; +pub mod session_state; pub mod startup; pub mod telemetry; diff --git a/src/routes/admin/dashboard.rs b/src/routes/admin/dashboard.rs new file mode 100644 index 0000000..b1f2ab2 --- /dev/null +++ b/src/routes/admin/dashboard.rs @@ -0,0 +1,75 @@ +use crate::{app_state::AppState, session_state::TypedSession}; +use anyhow::Context; +use askama::Template; +use askama_axum::IntoResponse; +use axum::{ + extract::State, + http::StatusCode, + response::{Redirect, Response}, +}; +use sqlx::PgPool; +use uuid::Uuid; + +#[tracing::instrument(name = "Get admin dashboard", skip(app_state, session))] +pub(super) async fn admin_dashboard( + State(app_state): State, + session: TypedSession, +) -> Result { + let response = match session + .get_user_id() + .await + .map_err(|e| DashboardError::InvalidSession(e.into()))? + { + Some(user_id) => Dashboard { + title: "Admin Dashboard", + username: get_username(&app_state.db_pool, user_id).await?, + } + .into_response(), + None => Redirect::to("/login").into_response(), + }; + + Ok(response) +} + +#[tracing::instrument(skip(db_pool))] +async fn get_username(db_pool: &PgPool, user_id: Uuid) -> Result { + let row = sqlx::query!( + r#" + SELECT username + FROM users + WHERE user_id = $1 + "#, + user_id, + ) + .fetch_one(db_pool) + .await + .context("Failed to perform a query to retreive a username")?; + + Ok(row.username) +} + +#[derive(Template)] +#[template(path = "web/dashboard.html")] +struct Dashboard<'a> { + title: &'a str, + username: String, +} + +#[derive(Debug, thiserror::Error)] +pub enum DashboardError { + #[error("Invalid session")] + InvalidSession(#[source] anyhow::Error), + #[error("Something went wrong")] + UnexpectedError(#[from] anyhow::Error), +} + +impl IntoResponse for DashboardError { + fn into_response(self) -> Response { + tracing::error!("{:#?}", self); + + match self { + Self::InvalidSession(_) => StatusCode::UNAUTHORIZED.into_response(), + Self::UnexpectedError(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + } + } +} diff --git a/src/routes/admin/mod.rs b/src/routes/admin/mod.rs new file mode 100644 index 0000000..4a6618b --- /dev/null +++ b/src/routes/admin/mod.rs @@ -0,0 +1,9 @@ +use crate::app_state::AppState; +use axum::{routing::get, Router}; +use dashboard::admin_dashboard; + +mod dashboard; + +pub fn router() -> Router { + Router::new().route("/admin/dashboard", get(admin_dashboard)) +} diff --git a/src/routes/login/action.rs b/src/routes/login/action.rs index 34973da..8271ca1 100644 --- a/src/routes/login/action.rs +++ b/src/routes/login/action.rs @@ -1,6 +1,7 @@ use crate::{ app_state::AppState, authentication::{validate_credentials, AuthError, Credentials}, + session_state::TypedSession, }; use axum::{ extract::State, @@ -16,17 +17,18 @@ use secrecy::Secret; use serde::Deserialize; #[tracing::instrument( - skip(app_state, form, jar), + skip(app_state, form, session, jar), fields(username = tracing::field::Empty, user_id = tracing::field::Empty) )] pub(super) async fn login( State(app_state): State, + session: TypedSession, jar: SignedCookieJar, Form(form): Form, -) -> Result { +) -> Result { tracing::Span::current().record("username", &tracing::field::display(&form.username)); - let user_id = validate_credentials( + let user_id = match validate_credentials( &app_state.db_pool, Credentials { username: form.username, @@ -34,14 +36,33 @@ pub(super) async fn login( }, ) .await - .map_err(|e| match e { - AuthError::InvalidCredentials(_) => LoginError::AuthErrorWithResponse(e.into(), jar), - AuthError::UnexpectedError(_) => LoginError::UnexpectedError(e.into()), - })?; + { + Ok(user_id) => user_id, + Err(e) => match e { + AuthError::InvalidCredentials(_) => { + return Err(LoginErrorResponse::new_auth_with_redirect(e.into(), jar)); + } + AuthError::UnexpectedError(_) => { + return Err(LoginErrorResponse::new_unexpected(e.into())); + } + }, + }; tracing::Span::current().record("user_id", &tracing::field::display(&user_id)); - Ok(Redirect::to("/")) + if let Err(e) = session.cycle_id().await { + return Err(LoginErrorResponse::new_unexpected_with_redirect( + e.into(), + jar, + )); + } + + session + .insert_user_id(user_id) + .await + .map_err(|e| LoginErrorResponse::new_unexpected_with_redirect(e.into(), jar))?; + + Ok(Redirect::to("/admin/dashboard")) } #[derive(Deserialize)] @@ -50,40 +71,57 @@ pub(super) struct FormData { password: Secret, } -#[derive(Debug, thiserror::Error)] -pub(super) enum LoginError { - #[error("Authentication failed")] - AuthErrorWithResponse(#[source] anyhow::Error, SignedCookieJar), - #[error("Authentication failed")] - AuthError(#[source] anyhow::Error), - #[error("Something went wrong")] - UnexpectedError(#[from] anyhow::Error), +pub(super) struct LoginErrorResponse { + error: LoginError, + jar: Option, } -impl IntoResponse for LoginError { - fn into_response(self) -> Response { - match self { - Self::AuthErrorWithResponse(e, jar) => { - let error = Self::AuthError(e); - tracing::error!("{:#?}", error); +impl LoginErrorResponse { + fn new_unexpected(error: anyhow::Error) -> Self { + let error = LoginError::UnexpectedError(error); + Self { error, jar: None } + } - let jar = jar.add( - Cookie::build(("_flash", error.to_string())) - .http_only(true) - .same_site(SameSite::Strict), - ); - let redirect = Redirect::to("/login"); + fn new_auth_with_redirect(error: anyhow::Error, jar: SignedCookieJar) -> Self { + let error = LoginError::AuthError(error); + let jar = Some(jar.add(&error)); + Self { error, jar } + } - (jar, redirect).into_response() - } - Self::AuthError(_) => { - tracing::error!("{:#?}", self); - StatusCode::UNAUTHORIZED.into_response() - } - Self::UnexpectedError(_) => { - tracing::error!("{:#?}", self); + fn new_unexpected_with_redirect(error: anyhow::Error, jar: SignedCookieJar) -> Self { + let error = LoginError::UnexpectedError(error); + let jar = Some(jar.add(&error)); + Self { error, jar } + } +} + +impl IntoResponse for LoginErrorResponse { + fn into_response(self) -> Response { + tracing::error!("{:#?}", self.error); + + match (self.error, self.jar) { + (_, Some(jar)) => (jar, Redirect::to("/login")).into_response(), + (LoginError::AuthError(_), None) => StatusCode::UNAUTHORIZED.into_response(), + (LoginError::UnexpectedError(_), None) => { StatusCode::INTERNAL_SERVER_ERROR.into_response() } } } } + +#[derive(Debug, thiserror::Error)] +enum LoginError { + #[error("Authentication failed")] + AuthError(#[source] anyhow::Error), + #[error("Something went wrong")] + UnexpectedError(#[from] anyhow::Error), +} + +impl From<&LoginError> for Cookie<'static> { + fn from(value: &LoginError) -> Self { + Cookie::build(("_flash", value.to_string())) + .http_only(true) + .same_site(SameSite::Strict) + .build() + } +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs index b37042d..1f9e88a 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -1,3 +1,4 @@ +pub mod admin; pub mod health_check; pub mod home; pub mod login; diff --git a/src/session_state.rs b/src/session_state.rs new file mode 100644 index 0000000..dc76168 --- /dev/null +++ b/src/session_state.rs @@ -0,0 +1,38 @@ +use axum::{ + async_trait, + extract::FromRequestParts, + http::{request::Parts, StatusCode}, +}; +use tower_sessions::{session::Error, Session}; +use uuid::Uuid; + +pub struct TypedSession(Session); + +impl TypedSession { + const USER_ID_KEY: &'static str = "user_id"; + + pub async fn cycle_id(&self) -> Result<(), Error> { + self.0.cycle_id().await + } + + pub async fn insert_user_id(&self, user_id: Uuid) -> Result<(), Error> { + self.0.insert(Self::USER_ID_KEY, user_id).await + } + + pub async fn get_user_id(&self) -> Result, Error> { + self.0.get(Self::USER_ID_KEY).await + } +} + +#[async_trait] +impl FromRequestParts for TypedSession +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(req: &mut Parts, state: &S) -> Result { + let session = Session::from_request_parts(req, state).await?; + Ok(TypedSession(session)) + } +} diff --git a/src/startup.rs b/src/startup.rs index a15f3f4..cb7eb4e 100644 --- a/src/startup.rs +++ b/src/startup.rs @@ -3,7 +3,7 @@ use crate::{ configuration::{ApplicationSettings, DatabaseSettings, Settings}, email_client::EmailClient, request_id::RequestUuid, - routes::{health_check, home, login, newsletters, subscriptions, subscriptions_confirm}, + routes::{admin, health_check, home, login, newsletters, subscriptions, subscriptions_confirm}, telemetry::request_span, }; use anyhow::anyhow; @@ -132,6 +132,7 @@ async fn run( .merge(newsletters::router()) .merge(home::router()) .merge(login::router()) + .merge(admin::router()) .with_state(app_state) .layer( SessionManagerLayer::new(RedisStore::new(redis_pool)) diff --git a/templates/web/dashboard.html b/templates/web/dashboard.html new file mode 100644 index 0000000..fb1b499 --- /dev/null +++ b/templates/web/dashboard.html @@ -0,0 +1,5 @@ +{% extends "base.html" %} + +{% block content %} +

Welcome {{ username }}!

+{% endblock %} diff --git a/tests/api/admin_dashboard.rs b/tests/api/admin_dashboard.rs new file mode 100644 index 0000000..5328ea3 --- /dev/null +++ b/tests/api/admin_dashboard.rs @@ -0,0 +1,14 @@ +use crate::helpers::TestApp; + +#[tokio::test] +async fn login_is_required_to_access_admin_dashboard() { + // given + let app = TestApp::spawn().await; + + // when + let response = app.get_admin_dashboard().await; + + // then + assert_eq!(response.status(), 303); + assert_eq!(response.headers().get("Location").unwrap(), "/login"); +} diff --git a/tests/api/helpers.rs b/tests/api/helpers.rs index 24f30cf..ae25486 100644 --- a/tests/api/helpers.rs +++ b/tests/api/helpers.rs @@ -188,6 +188,18 @@ impl TestApp { .unwrap() } + pub async fn get_admin_dashboard(&self) -> Response { + self.client + .get(self.url("/admin/dashboard")) + .send() + .await + .expect("Failed to execute request") + } + + pub async fn get_admin_dashboard_html(&self) -> String { + self.get_admin_dashboard().await.text().await.unwrap() + } + fn url(&self, endpoint: &str) -> String { format!("http://{}{endpoint}", self.address) } diff --git a/tests/api/login.rs b/tests/api/login.rs index fe6d226..1300bc8 100644 --- a/tests/api/login.rs +++ b/tests/api/login.rs @@ -2,6 +2,27 @@ use crate::helpers::TestApp; use serde_json::json; use uuid::Uuid; +#[tokio::test] +async fn successful_login_redirects_to_admin_dashboard() { + // given + let app = TestApp::spawn().await; + let login_body = serde_json::json!({ + "username": &app.test_user.username, + "password": &app.test_user.password, + }); + + // when + let response = app.post_login(&login_body).await; + assert_eq!(response.status(), 303); + assert_eq!( + response.headers().get("Location").unwrap(), + "/admin/dashboard" + ); + + let html_page = app.get_admin_dashboard_html().await; + assert!(html_page.contains(&format!("Welcome {}", app.test_user.username))); +} + #[tokio::test] async fn an_error_flash_message_is_sent_on_failure() { // given diff --git a/tests/api/main.rs b/tests/api/main.rs index d7a12d3..26afd04 100644 --- a/tests/api/main.rs +++ b/tests/api/main.rs @@ -1,3 +1,4 @@ +mod admin_dashboard; mod health_check; mod helpers; mod login;