diff --git a/Cargo.lock b/Cargo.lock index 613046d..a973488 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -210,6 +210,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be6ea09c9b96cb5076af0de2e383bd2bc0c18f827cf1967bdd353e0b910d733" +dependencies = [ + "axum", + "axum-core", + "bytes", + "cookie 0.18.1", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", + "mime", + "pin-project-lite", + "serde", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.71" @@ -384,6 +407,50 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "cookie" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7efb37c3e1ccb1ff97164ad95ac1606e8ccd35b3fa0a7d99a304c7f4a428cc24" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "base64 0.22.0", + "hmac", + "percent-encoding", + "rand", + "sha2", + "subtle", + "time", + "version_check", +] + +[[package]] +name = "cookie_store" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "387461abbc748185c3a6e1673d826918b450b87ff22639429c694619a83b6cf6" +dependencies = [ + "cookie 0.17.0", + "idna 0.3.0", + "log", + "publicsuffix", + "serde", + "serde_derive", + "serde_json", + "time", + "url", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -883,12 +950,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "htmlescape" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9025058dae765dee5070ec375f591e2ba14638c63feff74f13805a72e523163" - [[package]] name = "http" version = "0.2.12" @@ -1018,6 +1079,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "idna" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "idna" version = "0.4.0" @@ -1607,6 +1678,22 @@ dependencies = [ "unarray", ] +[[package]] +name = "psl-types" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac" + +[[package]] +name = "publicsuffix" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96a8c1bda5ae1af7f99a2962e49df150414a43d62404644d98dd5c3a93d07457" +dependencies = [ + "idna 0.3.0", + "psl-types", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -1722,6 +1809,8 @@ checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ "base64 0.21.7", "bytes", + "cookie 0.17.0", + "cookie_store", "encoding_rs", "futures-core", "futures-util", @@ -3112,13 +3201,11 @@ dependencies = [ "askama", "askama_axum", "axum", + "axum-extra", "base64 0.22.0", "claims", "config", "fake", - "hex", - "hmac", - "htmlescape", "linkify", "once_cell", "proptest", @@ -3129,7 +3216,6 @@ dependencies = [ "serde", "serde-aux", "serde_json", - "sha2", "sqlx", "thiserror", "time", @@ -3141,7 +3227,6 @@ dependencies = [ "tracing-log 0.2.0", "tracing-subscriber", "unicode-segmentation", - "urlencoding", "uuid", "validator", "wiremock", diff --git a/Cargo.toml b/Cargo.toml index a8c6517..1c35eb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,19 +18,16 @@ argon2 = { version = "0.5.3", features = ["std"] } askama = { version = "0.12.1", features = ["with-axum"], default-features = false } askama_axum = { version = "0.4.0", default-features = false } axum = "0.7.4" +axum-extra = { version = "0.9.3", features = ["cookie-signed"] } base64 = "0.22.0" config = "0.14.0" -hex = "0.4.3" -hmac = { version = "0.12.1", features = ["std"] } -htmlescape = "0.3.1" once_cell = "1.19.0" rand = "0.8.5" regex = "1.10.3" -reqwest = { version = "0.11.24", features = ["json"], default-features = false } +reqwest = { version = "0.11.24", features = ["cookies", "json"], default-features = false } secrecy = { version = "0.8.0", features = ["serde"] } serde = { version = "1.0.196", features = ["derive"] } serde-aux = { version = "4.4.0", default-features = false } -sha2 = "0.10.8" sqlx = { version = "0.7.3", features = ["macros", "migrate", "postgres", "time", "runtime-tokio", "tls-native-tls", "uuid"], default-features = false } thiserror = "1.0.58" time = { version = "0.3.34", features = ["macros", "serde"] } @@ -42,7 +39,6 @@ tracing-bunyan-formatter = "0.3.9" tracing-log = "0.2.0" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } unicode-segmentation = "1.11.0" -urlencoding = "2.1.3" uuid = { version = "1.7.0", features = ["v4"] } validator = "0.16.1" diff --git a/src/app_state.rs b/src/app_state.rs index 10a32b9..da95e59 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -1,6 +1,6 @@ use crate::email_client::EmailClient; -use axum::http::Uri; -use secrecy::Secret; +use axum::{extract::FromRef, http::Uri}; +use axum_extra::extract::cookie::Key; use sqlx::PgPool; #[derive(Clone)] @@ -8,5 +8,11 @@ pub struct AppState { pub db_pool: PgPool, pub email_client: EmailClient, pub base_url: Uri, - pub hmac_secret: Secret, + pub hmac_secret: Key, +} + +impl FromRef for Key { + fn from_ref(state: &AppState) -> Self { + state.hmac_secret.clone() + } } diff --git a/src/routes/login/action.rs b/src/routes/login/action.rs index 74aa7cc..34973da 100644 --- a/src/routes/login/action.rs +++ b/src/routes/login/action.rs @@ -8,17 +8,20 @@ use axum::{ response::{IntoResponse, Redirect, Response}, Form, }; -use hmac::{Hmac, Mac}; -use secrecy::{ExposeSecret, Secret}; +use axum_extra::extract::{ + cookie::{Cookie, SameSite}, + SignedCookieJar, +}; +use secrecy::Secret; use serde::Deserialize; -use urlencoding::Encoded; #[tracing::instrument( - skip(app_state, form), + skip(app_state, form, jar), fields(username = tracing::field::Empty, user_id = tracing::field::Empty) )] pub(super) async fn login( State(app_state): State, + jar: SignedCookieJar, Form(form): Form, ) -> Result { tracing::Span::current().record("username", &tracing::field::display(&form.username)); @@ -32,17 +35,7 @@ pub(super) async fn login( ) .await .map_err(|e| match e { - AuthError::InvalidCredentials(_) => { - let error = format!("error={}", Encoded::new(e.to_string())); - let tag = format!("tag={:x}", { - let secret: &[u8] = app_state.hmac_secret.expose_secret().as_bytes(); - let mut mac = Hmac::::new_from_slice(secret).unwrap(); - mac.update(error.as_bytes()); - mac.finalize().into_bytes() - }); - - LoginError::AuthError(e.into(), Redirect::to(&format!("/login?{error}&{tag}"))) - } + AuthError::InvalidCredentials(_) => LoginError::AuthErrorWithResponse(e.into(), jar), AuthError::UnexpectedError(_) => LoginError::UnexpectedError(e.into()), })?; @@ -60,18 +53,37 @@ pub(super) struct FormData { #[derive(Debug, thiserror::Error)] pub(super) enum LoginError { #[error("Authentication failed")] - AuthError(#[source] anyhow::Error, Redirect), + AuthErrorWithResponse(#[source] anyhow::Error, SignedCookieJar), + #[error("Authentication failed")] + AuthError(#[source] anyhow::Error), #[error("Something went wrong")] UnexpectedError(#[from] anyhow::Error), } impl IntoResponse for LoginError { fn into_response(self) -> Response { - tracing::error!("{:#?}", self); - match self { - Self::AuthError(_, redirect) => redirect.into_response(), - Self::UnexpectedError(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + Self::AuthErrorWithResponse(e, jar) => { + let error = Self::AuthError(e); + tracing::error!("{:#?}", error); + + let jar = jar.add( + Cookie::build(("_flash", error.to_string())) + .http_only(true) + .same_site(SameSite::Strict), + ); + let redirect = Redirect::to("/login"); + + (jar, redirect).into_response() + } + Self::AuthError(_) => { + tracing::error!("{:#?}", self); + StatusCode::UNAUTHORIZED.into_response() + } + Self::UnexpectedError(_) => { + tracing::error!("{:#?}", self); + StatusCode::INTERNAL_SERVER_ERROR.into_response() + } } } } diff --git a/src/routes/login/form.rs b/src/routes/login/form.rs index 784988f..97437fd 100644 --- a/src/routes/login/form.rs +++ b/src/routes/login/form.rs @@ -1,63 +1,25 @@ -use crate::app_state::AppState; -use anyhow::Context; use askama_axum::Template; -use axum::extract::{Query, State}; -use hmac::{Hmac, Mac}; -use secrecy::{ExposeSecret, Secret}; -use serde::Deserialize; -use urlencoding::Encoded; +use axum_extra::extract::{cookie::Cookie, SignedCookieJar}; -#[tracing::instrument(skip(app_state, parameters))] -pub(super) async fn login_form( - State(app_state): State, - Query(parameters): Query, -) -> LoginForm<'static> { - let error_message = match parameters.error_message(&app_state.hmac_secret) { - Ok(raw_html) => raw_html, - Err(e) => { - tracing::warn!("Failed to get error message from query parameters: {:?}", e); - None - } - } - .map(|raw_html| htmlescape::encode_minimal(&raw_html)); +#[tracing::instrument(skip(jar))] +pub(super) async fn login_form(jar: SignedCookieJar) -> (SignedCookieJar, LoginForm<'static>) { + const FLASH: &str = "_flash"; - LoginForm { - title: "Login", - username_label: "Username", - username_placeholder: "Enter username", - password_label: "Password", - password_placeholder: "Enter password", - submit_label: "Login", - error_message, - action: "/login", - } -} - -#[derive(Deserialize)] -pub(super) struct Parameters { - error: Option, - tag: Option, -} - -impl Parameters { - fn error_message(self, hmac_secret: &Secret) -> Result, anyhow::Error> { - match (&self.error, self.tag) { - (Some(e), Some(t)) => { - let tag = hex::decode(t).context("Failed to decode hex hmac tag")?; - let error = format!("error={}", Encoded::new(e)); - - let mut mac = - Hmac::::new_from_slice(hmac_secret.expose_secret().as_bytes())?; - mac.update(error.as_bytes()); - mac.verify_slice(&tag)?; + let flash = jar.get(FLASH).map(|c| c.value().into()); - Ok(self.error) - } - (None, None) => Ok(None), - (Some(_), None) => Err(anyhow::anyhow!("Error message is missing hmac tag")), - (None, Some(_)) => Err(anyhow::anyhow!("Hmac tag is missing error message")), - } - } + ( + jar.remove(Cookie::from(FLASH)), + LoginForm { + title: "Login", + username_label: "Username", + username_placeholder: "Enter username", + password_label: "Password", + password_placeholder: "Enter password", + submit_label: "Login", + flash, + action: "/login", + }, + ) } #[derive(Template)] @@ -69,6 +31,6 @@ pub(super) struct LoginForm<'a> { password_label: &'a str, password_placeholder: &'a str, submit_label: &'a str, - error_message: Option, + flash: Option, action: &'a str, } diff --git a/src/startup.rs b/src/startup.rs index c587cfb..7b9d16d 100644 --- a/src/startup.rs +++ b/src/startup.rs @@ -7,7 +7,8 @@ use crate::{ telemetry::request_span, }; use axum::{http::Uri, serve::Serve, Router}; -use secrecy::Secret; +use axum_extra::extract::cookie::Key; +use secrecy::{ExposeSecret, Secret}; use sqlx::{postgres::PgPoolOptions, PgPool}; use std::{net::SocketAddr, str::FromStr}; use tokio::net::TcpListener; @@ -87,7 +88,7 @@ async fn run( db_pool, email_client, base_url: Uri::from_str(&base_url).expect("Failed to parse base url"), - hmac_secret, + hmac_secret: Key::from(hmac_secret.expose_secret().as_bytes()), }; let app = Router::new() diff --git a/templates/web/login_form.html b/templates/web/login_form.html index bba1153..44e7533 100644 --- a/templates/web/login_form.html +++ b/templates/web/login_form.html @@ -1,8 +1,8 @@ {% extends "base.html" %} {% block content %} -{%- if let Some(message) = error_message %} -

{{ message }}

+{%- if let Some(flash) = flash %} +

{{ flash }}

{%- endif %}
diff --git a/tests/api/helpers.rs b/tests/api/helpers.rs index 3440326..4951401 100644 --- a/tests/api/helpers.rs +++ b/tests/api/helpers.rs @@ -2,6 +2,7 @@ use argon2::{password_hash::SaltString, Algorithm, Argon2, Params, PasswordHashe use claims::assert_some_eq; use linkify::{LinkFinder, LinkKind}; use once_cell::sync::Lazy; +use reqwest::{redirect, Response}; use sqlx::{Connection, Executor, PgConnection, PgPool}; use std::{net::SocketAddr, str::FromStr}; use uuid::Uuid; @@ -52,6 +53,12 @@ impl TestApp { let test_user = TestUser::generate(); test_user.store(&db_pool).await; + let client = reqwest::Client::builder() + .redirect(redirect::Policy::none()) + .cookie_store(true) + .build() + .unwrap(); + tokio::spawn(app.run_until_stopped()); Self { @@ -59,7 +66,7 @@ impl TestApp { db_pool, email_server, test_user, - client: reqwest::Client::new(), + client, } } @@ -158,6 +165,29 @@ impl TestApp { ConfirmationLinks { html, plain_text } } + pub async fn post_login(&self, body: &Body) -> Response + where + Body: serde::Serialize, + { + self.client + .post(self.url("/login")) + .form(body) + .send() + .await + .expect("Failed to execute request") + } + + pub async fn get_login_html(&self) -> String { + self.client + .get(self.url("/login")) + .send() + .await + .expect("Failed to execute request") + .text() + .await + .unwrap() + } + fn url(&self, endpoint: &str) -> String { format!("http://{}{endpoint}", self.address) } diff --git a/tests/api/login.rs b/tests/api/login.rs new file mode 100644 index 0000000..fe6d226 --- /dev/null +++ b/tests/api/login.rs @@ -0,0 +1,40 @@ +use crate::helpers::TestApp; +use serde_json::json; +use uuid::Uuid; + +#[tokio::test] +async fn an_error_flash_message_is_sent_on_failure() { + // given + let app = TestApp::spawn().await; + let login_body = json!({ + "username": Uuid::new_v4().to_string(), + "password": Uuid::new_v4().to_string(), + }); + + // when + let response = app.post_login(&login_body).await; + + // then + assert_eq!(response.status(), 303); + assert_eq!(response.headers().get("Location").unwrap(), "/login"); +} + +#[tokio::test] +async fn an_error_flash_message_is_set_on_failure() { + // given + let app = TestApp::spawn().await; + let login_body = json!({ + "username": Uuid::new_v4().to_string(), + "password": Uuid::new_v4().to_string(), + }); + + // when + app.post_login(&login_body).await; + + // then + let html_page = app.get_login_html().await; + assert!(html_page.contains("

Authentication failed

")); + + let html_page = app.get_login_html().await; + assert!(!html_page.contains("

Authentication failed

")); +} diff --git a/tests/api/main.rs b/tests/api/main.rs index 129fe56..d7a12d3 100644 --- a/tests/api/main.rs +++ b/tests/api/main.rs @@ -1,5 +1,6 @@ mod health_check; mod helpers; +mod login; mod newsletters; mod subscriptions; mod subscriptions_confirm;