Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
Validate subscription token
Browse files Browse the repository at this point in the history
  • Loading branch information
0rzech committed Mar 12, 2024
1 parent b5a6862 commit 83f8cf9
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 35 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ axum = "0.7.4"
config = "0.14.0"
once_cell = "1.19.0"
rand = "0.8.5"
regex = "1.10.3"
reqwest = { version = "0.11.24", features = ["json"], default-features = false }
secrecy = { version = "0.8.0", features = ["serde"] }
serde = { version = "1.0.196", features = ["derive"] }
Expand Down
2 changes: 2 additions & 0 deletions src/domain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ mod subscriber_email;
mod subscriber_name;
mod subscription;
mod subscription_status;
mod subscription_token;

pub use new_subscriber::NewSubscriber;
pub use subscriber_email::SubscriberEmail;
pub use subscriber_name::SubscriberName;
pub use subscription::Subscription;
pub use subscription_status::SubscriptionStatus;
pub use subscription_token::{token_regex, SubscriptionToken};
3 changes: 1 addition & 2 deletions src/domain/subscriber_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ impl TryFrom<String> for SubscriberName {

#[cfg(test)]
mod tests {
use super::FORBIDDEN_CHARS;
use crate::domain::SubscriberName;
use super::{SubscriberName, FORBIDDEN_CHARS};
use claims::{assert_err, assert_ok};

#[test]
Expand Down
179 changes: 179 additions & 0 deletions src/domain/subscription_token.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
use once_cell::sync::Lazy;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use regex::Regex;
use serde::Deserialize;
use std::iter::repeat_with;

const TOKEN_CHARS: &str = r"[[:alnum:]]";
const TOKEN_LENGTH: usize = 25;

pub fn token_regex() -> String {
format!(r"{TOKEN_CHARS}{{{TOKEN_LENGTH}}}")
}

fn token_regex_anchored() -> String {
format!(r"^{}$", token_regex())
}

#[derive(Clone, Debug, Deserialize)]
pub struct SubscriptionToken(String);

impl SubscriptionToken {
pub fn generate() -> Self {
Self::generate_with_rng(&mut thread_rng())
}

fn generate_with_rng(rng: &mut impl Rng) -> Self {
let token = repeat_with(|| rng.sample(Alphanumeric))
.map(char::from)
.take(TOKEN_LENGTH)
.collect();

Self(token)
}

pub fn parse(s: String) -> Result<Self, String> {
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(&token_regex_anchored()).unwrap());

if RE.is_match(&s) {
Ok(Self(s))
} else {
Err(format!("Invalid subscription token: `{s}`"))
}
}
}

impl AsRef<str> for SubscriptionToken {
fn as_ref(&self) -> &str {
&self.0
}
}

impl TryFrom<String> for SubscriptionToken {
type Error = String;

fn try_from(s: String) -> Result<Self, Self::Error> {
Self::parse(s)
}
}

#[cfg(test)]
mod tests {
use super::{token_regex, SubscriptionToken};
use claims::{assert_err, assert_ok};
use helpers::{invalid_length_tokens, non_alnum_tokens, valid_tokens};
use proptest::prelude::proptest;

proptest! {
#[test]
fn generated_tokens_are_valid(token in valid_tokens()) {
// when
let result = SubscriptionToken::parse(token.0);

// then
assert_ok!(result);
}
}

proptest! {
#[test]
fn valid_tokens_are_parsed_successfully(token in token_regex().as_str()) {
// when
let result = SubscriptionToken::parse(token);

// then
assert_ok!(result);
}
}

#[test]
fn empty_string_is_rejected() {
// given
let token = "".to_string();

// when
let result = SubscriptionToken::parse(token);

// then
assert_err!(result);
}

proptest! {
#[test]
fn tokens_with_non_alphanumeric_characters_are_rejected(token in non_alnum_tokens().as_str()) {
// when
let result = SubscriptionToken::parse(token);

// then
assert_err!(result);
}
}

proptest! {
#[test]
fn tokens_with_invalid_length_are_rejected(token in invalid_length_tokens()) {
// when
let result = SubscriptionToken::parse(token);

// then
assert_err!(result);
}
}

mod helpers {
use crate::domain::SubscriptionToken;

use super::super::TOKEN_LENGTH;
use proptest::{
strategy::{NewTree, Strategy, ValueTree},
test_runner::TestRunner,
};

pub fn valid_tokens() -> impl Strategy<Value = SubscriptionToken> {
ValidTokenStrategy
}

#[derive(Debug)]
struct ValidTokenStrategy;

impl Strategy for ValidTokenStrategy {
type Tree = ValidTokenValueTree;
type Value = SubscriptionToken;

fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
Ok(ValidTokenValueTree(SubscriptionToken::generate_with_rng(
runner.rng(),
)))
}
}

struct ValidTokenValueTree(SubscriptionToken);

impl ValueTree for ValidTokenValueTree {
type Value = SubscriptionToken;

fn current(&self) -> Self::Value {
self.0.clone()
}

fn simplify(&mut self) -> bool {
false
}

fn complicate(&mut self) -> bool {
false
}
}

pub const FILTERED_LENGTHS: [usize; 2] = [0, TOKEN_LENGTH];

pub fn non_alnum_tokens() -> String {
format!(r"[[:^alnum:]]{{{TOKEN_LENGTH}}}")
}

pub fn invalid_length_tokens() -> impl Strategy<Value = String> {
let whence = format!("Invalid token length must not be any of {FILTERED_LENGTHS:?}");
"[[:alnum:]]*".prop_filter(whence, |v| !FILTERED_LENGTHS.contains(&v.len()))
}
}
}
29 changes: 12 additions & 17 deletions src/routes/subscriptions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::{
app_state::AppState,
domain::{NewSubscriber, SubscriberEmail, SubscriberName, Subscription, SubscriptionStatus},
domain::{
NewSubscriber, SubscriberEmail, SubscriberName, Subscription, SubscriptionStatus,
SubscriptionToken,
},
email_client::EmailClient,
};
use axum::{
Expand All @@ -9,11 +12,9 @@ use axum::{
routing::post,
Form, Router,
};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use reqwest::Error;
use serde::Deserialize;
use sqlx::{Executor, FromRow, Postgres, Transaction};
use std::iter::repeat_with;
use time::OffsetDateTime;
use uuid::Uuid;

Expand Down Expand Up @@ -63,7 +64,7 @@ async fn subscribe(State(app_state): State<AppState>, Form(form): Form<FormData>
}
};

let subscription_token = generate_subscription_token();
let subscription_token = SubscriptionToken::generate();

if store_token(&mut transaction, subscriber_id, &subscription_token)
.await
Expand Down Expand Up @@ -162,10 +163,12 @@ async fn send_confirmation_email(
email_client: &EmailClient,
new_subscriber: NewSubscriber,
base_url: &Uri,
subscription_token: &str,
subscription_token: &SubscriptionToken,
) -> Result<(), Error> {
let confirmation_link =
format!("{base_url}subscriptions/confirm?subscription_token={subscription_token}");
let confirmation_link = format!(
"{base_url}subscriptions/confirm?subscription_token={}",
subscription_token.as_ref()
);
let html_body = format!(
"Welcome to our newsletter!<br/>\
Click <a href=\"{confirmation_link}\">here</a> to confirm your subscription."
Expand Down Expand Up @@ -195,29 +198,21 @@ impl TryFrom<FormData> for NewSubscriber {
}
}

fn generate_subscription_token() -> String {
let mut rng = thread_rng();
repeat_with(|| rng.sample(Alphanumeric))
.map(char::from)
.take(25)
.collect()
}

#[tracing::instrument(
name = "Storing subscription token in the database",
skip(transaction, subscription_token)
)]
async fn store_token(
transaction: &mut Transaction<'_, Postgres>,
subscriber_id: Uuid,
subscription_token: &str,
subscription_token: &SubscriptionToken,
) -> Result<(), sqlx::Error> {
let query = sqlx::query!(
r#"
INSERT INTO subscription_tokens (subscription_token, subscriber_id)
VALUES ($1, $2)
"#,
subscription_token,
subscription_token.as_ref(),
subscriber_id
);

Expand Down
31 changes: 19 additions & 12 deletions src/routes/subscriptions_confirm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::{app_state::AppState, domain::SubscriptionStatus};
use crate::{
app_state::AppState,
domain::{SubscriptionStatus, SubscriptionToken},
};
use axum::{
extract::{Query, State},
http::StatusCode,
Expand Down Expand Up @@ -26,17 +29,21 @@ async fn confirm(
}
};

let subscriber_id = match get_subscriber_id_from_token(
&mut transaction,
&parameters.subscription_token,
)
.await
{
Ok(Some(id)) => id,
Ok(None) => return StatusCode::UNAUTHORIZED,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR,
let subscription_token = match SubscriptionToken::parse(parameters.subscription_token) {
Ok(token) => token,
Err(e) => {
tracing::error!(e);
return StatusCode::BAD_REQUEST;
}
};

let subscriber_id =
match get_subscriber_id_from_token(&mut transaction, &subscription_token).await {
Ok(Some(id)) => id,
Ok(None) => return StatusCode::UNAUTHORIZED,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR,
};

if confirm_subscriber(&mut transaction, subscriber_id)
.await
.is_err()
Expand Down Expand Up @@ -71,14 +78,14 @@ struct Parameters {
)]
async fn get_subscriber_id_from_token(
transaction: &mut Transaction<'_, Postgres>,
subscription_token: &str,
subscription_token: &SubscriptionToken,
) -> Result<Option<Uuid>, sqlx::Error> {
let query = sqlx::query!(
r#"
SELECT subscriber_id FROM subscription_tokens
WHERE subscription_token = $1
"#,
subscription_token,
subscription_token.as_ref(),
);

let result = transaction.fetch_optional(query).await.map_err(|e| {
Expand Down
14 changes: 13 additions & 1 deletion tests/api/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,26 @@ impl TestApp {
.expect(FAILED_TO_EXECUTE_REQUEST)
}

pub async fn confirm_subscription(&self) -> reqwest::Response {
pub async fn confirm_subscription_without_token(&self) -> reqwest::Response {
self.client
.get(self.url("/subscriptions/confirm"))
.send()
.await
.expect(FAILED_TO_EXECUTE_REQUEST)
}

pub async fn confirm_subscription(&self, token: &str) -> reqwest::Response {
self.client
.get(format!(
"{}?subscription_token={}",
self.url("/subscriptions/confirm"),
token
))
.send()
.await
.expect(FAILED_TO_EXECUTE_REQUEST)
}

pub async fn post_subscriptions(&self, body: String) -> reqwest::Response {
self.client
.post(self.url("/subscriptions"))
Expand Down
Loading

0 comments on commit 83f8cf9

Please sign in to comment.