Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
Reject unauthenticated admin users
Browse files Browse the repository at this point in the history
  • Loading branch information
0rzech committed Apr 6, 2024
1 parent b7de87a commit e62103c
Show file tree
Hide file tree
Showing 15 changed files with 277 additions and 38 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ tracing-bunyan-formatter = "0.3.9"
tracing-log = "0.2.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
unicode-segmentation = "1.11.0"
uuid = { version = "1.7.0", features = ["v4"] }
uuid = { version = "1.7.0", features = ["serde", "v4"] }
validator = "0.16.1"

[dev-dependencies]
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ pub mod domain;
pub mod email_client;
pub mod request_id;
pub mod routes;
pub mod session_state;
pub mod startup;
pub mod telemetry;
75 changes: 75 additions & 0 deletions src/routes/admin/dashboard.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use crate::{app_state::AppState, session_state::TypedSession};
use anyhow::Context;
use askama::Template;
use askama_axum::IntoResponse;
use axum::{
extract::State,
http::StatusCode,
response::{Redirect, Response},
};
use sqlx::PgPool;
use uuid::Uuid;

#[tracing::instrument(name = "Get admin dashboard", skip(app_state, session))]
pub(super) async fn admin_dashboard(
State(app_state): State<AppState>,
session: TypedSession,
) -> Result<Response, DashboardError> {
let response = match session
.get_user_id()
.await
.map_err(|e| DashboardError::InvalidSession(e.into()))?
{
Some(user_id) => Dashboard {
title: "Admin Dashboard",
username: get_username(&app_state.db_pool, user_id).await?,
}
.into_response(),
None => Redirect::to("/login").into_response(),
};

Ok(response)
}

#[tracing::instrument(skip(db_pool))]
async fn get_username(db_pool: &PgPool, user_id: Uuid) -> Result<String, anyhow::Error> {
let row = sqlx::query!(
r#"
SELECT username
FROM users
WHERE user_id = $1
"#,
user_id,
)
.fetch_one(db_pool)
.await
.context("Failed to perform a query to retreive a username")?;

Ok(row.username)
}

#[derive(Template)]
#[template(path = "web/dashboard.html")]
struct Dashboard<'a> {
title: &'a str,
username: String,
}

#[derive(Debug, thiserror::Error)]
pub enum DashboardError {
#[error("Invalid session")]
InvalidSession(#[source] anyhow::Error),
#[error("Something went wrong")]
UnexpectedError(#[from] anyhow::Error),
}

impl IntoResponse for DashboardError {
fn into_response(self) -> Response {
tracing::error!("{:#?}", self);

match self {
Self::InvalidSession(_) => StatusCode::UNAUTHORIZED.into_response(),
Self::UnexpectedError(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
}
}
}
9 changes: 9 additions & 0 deletions src/routes/admin/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use crate::app_state::AppState;
use axum::{routing::get, Router};
use dashboard::admin_dashboard;

mod dashboard;

pub fn router() -> Router<AppState> {
Router::new().route("/admin/dashboard", get(admin_dashboard))
}
110 changes: 74 additions & 36 deletions src/routes/login/action.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
app_state::AppState,
authentication::{validate_credentials, AuthError, Credentials},
session_state::TypedSession,
};
use axum::{
extract::State,
Expand All @@ -16,32 +17,52 @@ use secrecy::Secret;
use serde::Deserialize;

#[tracing::instrument(
skip(app_state, form, jar),
skip(app_state, form, session, jar),
fields(username = tracing::field::Empty, user_id = tracing::field::Empty)
)]
pub(super) async fn login(
State(app_state): State<AppState>,
session: TypedSession,
jar: SignedCookieJar,
Form(form): Form<FormData>,
) -> Result<Redirect, LoginError> {
) -> Result<Redirect, LoginErrorResponse> {
tracing::Span::current().record("username", &tracing::field::display(&form.username));

let user_id = validate_credentials(
let user_id = match validate_credentials(
&app_state.db_pool,
Credentials {
username: form.username,
password: form.password,
},
)
.await
.map_err(|e| match e {
AuthError::InvalidCredentials(_) => LoginError::AuthErrorWithResponse(e.into(), jar),
AuthError::UnexpectedError(_) => LoginError::UnexpectedError(e.into()),
})?;
{
Ok(user_id) => user_id,
Err(e) => match e {
AuthError::InvalidCredentials(_) => {
return Err(LoginErrorResponse::new_auth_with_redirect(e.into(), jar));
}
AuthError::UnexpectedError(_) => {
return Err(LoginErrorResponse::new_unexpected(e.into()));
}
},
};

tracing::Span::current().record("user_id", &tracing::field::display(&user_id));

Ok(Redirect::to("/"))
if let Err(e) = session.cycle_id().await {
return Err(LoginErrorResponse::new_unexpected_with_redirect(
e.into(),
jar,
));
}

session
.insert_user_id(user_id)
.await
.map_err(|e| LoginErrorResponse::new_unexpected_with_redirect(e.into(), jar))?;

Ok(Redirect::to("/admin/dashboard"))
}

#[derive(Deserialize)]
Expand All @@ -50,40 +71,57 @@ pub(super) struct FormData {
password: Secret<String>,
}

#[derive(Debug, thiserror::Error)]
pub(super) enum LoginError {
#[error("Authentication failed")]
AuthErrorWithResponse(#[source] anyhow::Error, SignedCookieJar),
#[error("Authentication failed")]
AuthError(#[source] anyhow::Error),
#[error("Something went wrong")]
UnexpectedError(#[from] anyhow::Error),
pub(super) struct LoginErrorResponse {
error: LoginError,
jar: Option<SignedCookieJar>,
}

impl IntoResponse for LoginError {
fn into_response(self) -> Response {
match self {
Self::AuthErrorWithResponse(e, jar) => {
let error = Self::AuthError(e);
tracing::error!("{:#?}", error);
impl LoginErrorResponse {
fn new_unexpected(error: anyhow::Error) -> Self {
let error = LoginError::UnexpectedError(error);
Self { error, jar: None }
}

let jar = jar.add(
Cookie::build(("_flash", error.to_string()))
.http_only(true)
.same_site(SameSite::Strict),
);
let redirect = Redirect::to("/login");
fn new_auth_with_redirect(error: anyhow::Error, jar: SignedCookieJar) -> Self {
let error = LoginError::AuthError(error);
let jar = Some(jar.add(&error));
Self { error, jar }
}

(jar, redirect).into_response()
}
Self::AuthError(_) => {
tracing::error!("{:#?}", self);
StatusCode::UNAUTHORIZED.into_response()
}
Self::UnexpectedError(_) => {
tracing::error!("{:#?}", self);
fn new_unexpected_with_redirect(error: anyhow::Error, jar: SignedCookieJar) -> Self {
let error = LoginError::UnexpectedError(error);
let jar = Some(jar.add(&error));
Self { error, jar }
}
}

impl IntoResponse for LoginErrorResponse {
fn into_response(self) -> Response {
tracing::error!("{:#?}", self.error);

match (self.error, self.jar) {
(_, Some(jar)) => (jar, Redirect::to("/login")).into_response(),
(LoginError::AuthError(_), None) => StatusCode::UNAUTHORIZED.into_response(),
(LoginError::UnexpectedError(_), None) => {
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
}

#[derive(Debug, thiserror::Error)]
enum LoginError {
#[error("Authentication failed")]
AuthError(#[source] anyhow::Error),
#[error("Something went wrong")]
UnexpectedError(#[from] anyhow::Error),
}

impl From<&LoginError> for Cookie<'static> {
fn from(value: &LoginError) -> Self {
Cookie::build(("_flash", value.to_string()))
.http_only(true)
.same_site(SameSite::Strict)
.build()
}
}
1 change: 1 addition & 0 deletions src/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod admin;
pub mod health_check;
pub mod home;
pub mod login;
Expand Down
38 changes: 38 additions & 0 deletions src/session_state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, StatusCode},
};
use tower_sessions::{session::Error, Session};
use uuid::Uuid;

pub struct TypedSession(Session);

impl TypedSession {
const USER_ID_KEY: &'static str = "user_id";

pub async fn cycle_id(&self) -> Result<(), Error> {
self.0.cycle_id().await
}

pub async fn insert_user_id(&self, user_id: Uuid) -> Result<(), Error> {
self.0.insert(Self::USER_ID_KEY, user_id).await
}

pub async fn get_user_id(&self) -> Result<Option<Uuid>, Error> {
self.0.get(Self::USER_ID_KEY).await
}
}

#[async_trait]
impl<S> FromRequestParts<S> for TypedSession
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);

async fn from_request_parts(req: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let session = Session::from_request_parts(req, state).await?;
Ok(TypedSession(session))
}
}
3 changes: 2 additions & 1 deletion src/startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
configuration::{ApplicationSettings, DatabaseSettings, Settings},
email_client::EmailClient,
request_id::RequestUuid,
routes::{health_check, home, login, newsletters, subscriptions, subscriptions_confirm},
routes::{admin, health_check, home, login, newsletters, subscriptions, subscriptions_confirm},
telemetry::request_span,
};
use anyhow::anyhow;
Expand Down Expand Up @@ -132,6 +132,7 @@ async fn run(
.merge(newsletters::router())
.merge(home::router())
.merge(login::router())
.merge(admin::router())
.with_state(app_state)
.layer(
SessionManagerLayer::new(RedisStore::new(redis_pool))
Expand Down
5 changes: 5 additions & 0 deletions templates/web/dashboard.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{% extends "base.html" %}

{% block content %}
<p>Welcome {{ username }}!</p>
{% endblock %}
14 changes: 14 additions & 0 deletions tests/api/admin_dashboard.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use crate::helpers::TestApp;

#[tokio::test]
async fn login_is_required_to_access_admin_dashboard() {
// given
let app = TestApp::spawn().await;

// when
let response = app.get_admin_dashboard().await;

// then
assert_eq!(response.status(), 303);
assert_eq!(response.headers().get("Location").unwrap(), "/login");
}
12 changes: 12 additions & 0 deletions tests/api/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,18 @@ impl TestApp {
.unwrap()
}

pub async fn get_admin_dashboard(&self) -> Response {
self.client
.get(self.url("/admin/dashboard"))
.send()
.await
.expect("Failed to execute request")
}

pub async fn get_admin_dashboard_html(&self) -> String {
self.get_admin_dashboard().await.text().await.unwrap()
}

fn url(&self, endpoint: &str) -> String {
format!("http://{}{endpoint}", self.address)
}
Expand Down
Loading

0 comments on commit e62103c

Please sign in to comment.