Skip to content

Commit

Permalink
add macros for generating payload structs, use them in generated task…
Browse files Browse the repository at this point in the history
…s controller
  • Loading branch information
hdoordt committed Oct 23, 2024
1 parent e6ad499 commit f46d9d9
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 26 deletions.
2 changes: 1 addition & 1 deletion blueprint/db/src/entities/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct Task {
/// ```
/// let task_changeset: TaskChangeset = Faker.fake();
/// ```
#[derive(Deserialize, Validate, Clone)]
#[derive(Debug, Deserialize, Validate, Clone)]
#[cfg_attr(feature = "test-helpers", derive(Serialize, Dummy))]
pub struct TaskChangeset {
/// The description must be at least 1 character long.
Expand Down
129 changes: 126 additions & 3 deletions blueprint/macros/src/lib.rs.liquid
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! The {{crate_name}}-macros crate contains the `test`{%- unless template_type == "minimal" %} and `db_test`{%- endunless %} macro{%- unless template_type == "minimal" -%} s{% endunless -%}.

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, Fields, Ident, ItemFn, ItemStruct, Type};

#[allow(clippy::test_attr_in_doctest)]
/// Used to mark an application test.
Expand Down Expand Up @@ -110,4 +110,127 @@ pub fn db_test(_: TokenStream, item: TokenStream) -> TokenStream {

TokenStream::from(output)
}
{%- endunless %}
{%- endunless %}

#[proc_macro_attribute]
pub fn request_payload(_: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemStruct);
let PayloadStructInfo {
outer_ty,
inner_ty,
inner_ty_lit_str,
} = PayloadStructInfo::from_input(&input);

TokenStream::from(quote! {
#[derive(::serde::Deserialize)]
#[serde(try_from = #inner_ty_lit_str)]
#input

impl TryFrom<#inner_ty> for #outer_ty {
type Error = ::validator::ValidationErrors;

fn try_from(inner: #inner_ty) -> Result<Self, Self::Error> {
::validator::Validate::validate(&inner)?;
Ok(Self(inner))
}
}

impl From<#outer_ty> for #inner_ty {
fn from(#outer_ty(inner): #outer_ty) -> Self {
inner
}
}
})
}

#[proc_macro_attribute]
pub fn batch_request_payload(_: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemStruct);
let PayloadStructInfo {
outer_ty,
inner_ty,
inner_ty_lit_str,
} = PayloadStructInfo::from_input(&input);

TokenStream::from(quote! {
#[derive(::serde::Deserialize)]
#[serde(try_from = #inner_ty_lit_str)]
#input

impl TryFrom<#inner_ty> for #outer_ty {
type Error = ::validator::ValidationErrors;

fn try_from(inner: #inner_ty) -> Result<Self, Self::Error> {
let cap = inner.len();

inner
.into_iter()
.try_fold(Vec::with_capacity(cap), |mut v, item| {
::validator::Validate::validate(&item)?;
v.push(item);
Ok(v)
})
.map(Self)
}
}

impl From<#outer_ty> for #inner_ty {
fn from(#outer_ty(inner): #outer_ty) -> Self {
inner
}
}
})
}

#[proc_macro_attribute]
pub fn response_payload(_: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemStruct);
let PayloadStructInfo {
outer_ty,
inner_ty,
inner_ty_lit_str,
} = PayloadStructInfo::from_input(&input);

TokenStream::from(quote! {
#[derive(::serde::Serialize)]
#[serde(try_from = #inner_ty_lit_str)]
#input

impl From<#inner_ty> for #outer_ty {
fn from(inner: #inner_ty) -> Self {
Self(inner)
}
}
})
}

struct PayloadStructInfo<'input> {
outer_ty: &'input Ident,
inner_ty: &'input Type,
inner_ty_lit_str: String,
}

impl<'input> PayloadStructInfo<'input> {
fn from_input(input: &'input ItemStruct) -> Self {
fn error() -> ! {
panic!("Macro can only be applied to tuple structs with a single field")
}

let outer_ty = &input.ident;

let Fields::Unnamed(fields) = &input.fields else {
error()
};
let mut fields = fields.unnamed.iter();
let Some(field) = fields.next() else { error() };
let None = fields.next() else { error() };

let inner_ty = &field.ty;
let inner_ty_lit_str = inner_ty.clone().to_token_stream().to_string();
Self {
outer_ty,
inner_ty,
inner_ty_lit_str,
}
}
}
5 changes: 3 additions & 2 deletions blueprint/web/Cargo.toml.liquid
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ publish = false
doctest = false

[features]
test-helpers = ["dep:serde_json", "dep:tower", "dep:hyper", "dep:{{project-name}}-macros"]
test-helpers = ["dep:serde_json", "dep:tower", "dep:hyper"]

[dependencies]
anyhow = "1.0"
Expand All @@ -30,7 +30,8 @@ uuid = { version = "1.6", features = ["serde"] }
serde_json = { version = "1.0", optional = true }
tower = { version = "0.5", features = ["util"], optional = true }
hyper = { version = "1.0", features = ["full"], optional = true }
{{project-name}}-macros = { path = "../macros", optional = true }
{{project-name}}-macros = { path = "../macros" }
validator = { version = "0.18.1", features = ["derive"] }

[dev-dependencies]
fake = "2.9"
Expand Down
73 changes: 53 additions & 20 deletions blueprint/web/src/controllers/tasks.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
use crate::{state::AppState, internal_error};
use crate::{internal_error, state::AppState};
use axum::{extract::Path, extract::State, http::StatusCode, Json};
use {{crate_name}}_db::{entities::tasks, transaction, Error};
use {{crate_name}}_db::{
entities::tasks::{self},
transaction, Error,
};
use payloads::*;
use tracing::info;
use uuid::Uuid;

/// Creates a task in the database.
///
/// This function creates a task in the database (see [`{{crate_name}}_db::entities::tasks::create`]) based on a [`{{crate_name}}_db::entities::tasks::TaskChangeset`] (sent as JSON). If the task is created successfully, a 201 response is returned with the created [`{{crate_name}}_db::entities::tasks::Task`]'s JSON representation in the response body. If the changeset is invalid, a 422 response is returned.
/// This function creates a task in the database (see [`getest_db::entities::tasks::create`]) based on a [`getest_db::entities::tasks::TaskChangeset`] (sent as JSON). If the task is created successfully, a 201 response is returned with the created [`getest_db::entities::tasks::Task`]'s JSON representation in the response body. If the changeset is invalid, a 422 response is returned.
pub async fn create(
State(app_state): State<AppState>,
Json(task): Json<tasks::TaskChangeset>,
) -> Result<(StatusCode, Json<tasks::Task>), (StatusCode, String)> {
match tasks::create(task, &app_state.db_pool).await {
Ok(task) => Ok((StatusCode::CREATED, Json(task))),
Json(payload): Json<CreateRequestPayload>,
) -> Result<(StatusCode, Json<CreateResponsePayload>), (StatusCode, String)> {
match tasks::create(payload.into(), &app_state.db_pool).await {
Ok(task) => Ok((StatusCode::CREATED, Json(task.into()))),
Err(Error::ValidationError(e)) => {
info!(err.msg = %e, err.details = ?e, "Validation failed");
Err((StatusCode::UNPROCESSABLE_ENTITY, e.to_string()))
Expand All @@ -23,17 +27,17 @@ pub async fn create(

/// Creates multiple tasks in the database.
///
/// This function creates multiple tasks in the database (see [`{{crate_name}}_db::entities::tasks::create`]) based on [`{{crate_name}}_db::entities::tasks::TaskChangeset`]s (sent as JSON). If all tasks are created successfully, a 201 response is returned with the created [`{{crate_name}}_db::entities::tasks::Task`]s' JSON representation in the response body. If any of the passed changesets is invalid, a 422 response is returned.
/// This function creates multiple tasks in the database (see [`getest_db::entities::tasks::create`]) based on [`getest_db::entities::tasks::TaskChangeset`]s (sent as JSON). If all tasks are created successfully, a 201 response is returned with the created [`getest_db::entities::tasks::Task`]s' JSON representation in the response body. If any of the passed changesets is invalid, a 422 response is returned.
///
/// This function creates all tasks in a transaction so that either all are created successfully or none is.
pub async fn create_batch(
State(app_state): State<AppState>,
Json(tasks): Json<Vec<tasks::TaskChangeset>>,
) -> Result<(StatusCode, Json<Vec<tasks::Task>>), (StatusCode, String)> {
Json(payload): Json<CreateBatchRequestPayload>,
) -> Result<(StatusCode, Json<CreateBatchResponsePayload>), (StatusCode, String)> {
match transaction(&app_state.db_pool).await {
Ok(mut tx) => {
let mut results: Vec<tasks::Task> = vec![];
for task in tasks {
for task in Vec::<_>::from(payload) {
match tasks::create(task, &mut *tx).await {
Ok(task) => results.push(task),
Err(Error::ValidationError(e)) => {
Expand All @@ -45,7 +49,7 @@ pub async fn create_batch(
}

match tx.commit().await {
Ok(_) => Ok((StatusCode::CREATED, Json(results))),
Ok(_) => Ok((StatusCode::CREATED, Json(results.into()))),
Err(e) => Err((internal_error(e), "".into())),
}
}
Expand All @@ -55,7 +59,7 @@ pub async fn create_batch(

/// Reads and responds with all the tasks currently present in the database.
///
/// This function reads all [`{{crate_name}}_db::entities::tasks::Task`]s from the database (see [`{{crate_name}}_db::entities::tasks::load_all`]) and responds with their JSON representations.
/// This function reads all [`getest_db::entities::tasks::Task`]s from the database (see [`getest_db::entities::tasks::load_all`]) and responds with their JSON representations.
pub async fn read_all(
State(app_state): State<AppState>,
) -> Result<Json<Vec<tasks::Task>>, StatusCode> {
Expand All @@ -70,7 +74,7 @@ pub async fn read_all(

/// Reads and responds with a task identified by its ID.
///
/// This function reads one [`{{crate_name}}_db::entities::tasks::Task`] identified by its ID from the database (see [`{{crate_name}}_db::entities::tasks::load`]) and responds with its JSON representations. If no task is found for the ID, a 404 response is returned.
/// This function reads one [`getest_db::entities::tasks::Task`] identified by its ID from the database (see [`getest_db::entities::tasks::load`]) and responds with its JSON representations. If no task is found for the ID, a 404 response is returned.
pub async fn read_one(
State(app_state): State<AppState>,
Path(id): Path<Uuid>,
Expand All @@ -84,14 +88,14 @@ pub async fn read_one(

/// Updates a task in the database.
///
/// This function updates a task identified by its ID in the database (see [`{{crate_name}}_db::entities::tasks::update`]) with the data from the passed [`{{crate_name}}_db::entities::tasks::TaskChangeset`] (sent as JSON). If the task is updated successfully, a 200 response is returned with the created [`{{crate_name}}_db::entities::tasks::Task`]'s JSON representation in the response body. If the changeset is invalid, a 422 response is returned.
/// This function updates a task identified by its ID in the database (see [`getest_db::entities::tasks::update`]) with the data from the passed [`getest_db::entities::tasks::TaskChangeset`] (sent as JSON). If the task is updated successfully, a 200 response is returned with the created [`getest_db::entities::tasks::Task`]'s JSON representation in the response body. If the changeset is invalid, a 422 response is returned.
pub async fn update(
State(app_state): State<AppState>,
Path(id): Path<Uuid>,
Json(task): Json<tasks::TaskChangeset>,
) -> Result<Json<tasks::Task>, (StatusCode, String)> {
match tasks::update(id, task, &app_state.db_pool).await {
Ok(task) => Ok(Json(task)),
Json(payload): Json<UpdateRequestPayload>,
) -> Result<Json<UpdateResponsePayload>, (StatusCode, String)> {
match tasks::update(id, payload.into(), &app_state.db_pool).await {
Ok(task) => Ok(Json(task.into())),
Err(Error::NoRecordFound) => Err((StatusCode::NOT_FOUND, "".into())),
Err(Error::ValidationError(e)) => {
info!(err.msg = %e, err.details = ?e, "Validation failed");
Expand All @@ -103,7 +107,7 @@ pub async fn update(

/// Deletes a task identified by its ID from the database.
///
/// This function deletes one [`{{crate_name}}_db::entities::tasks::Task`] identified by the entity's id from the database (see [`{{crate_name}}_db::entities::tasks::delete`]) and responds with a 204 status code and empty response body. If no task is found for the ID, a 404 response is returned.
/// This function deletes one [`getest_db::entities::tasks::Task`] identified by the entity's id from the database (see [`getest_db::entities::tasks::delete`]) and responds with a 204 status code and empty response body. If no task is found for the ID, a 404 response is returned.
pub async fn delete(
State(app_state): State<AppState>,
Path(id): Path<Uuid>,
Expand All @@ -114,3 +118,32 @@ pub async fn delete(
Err(e) => Err((internal_error(e), "".into())),
}
}

mod payloads {
use {{crate_name}}_db::entities::tasks::{Task, TaskChangeset};
use {{crate_name}}_macros::{batch_request_payload, request_payload, response_payload};

#[derive(Debug)]
#[request_payload]
pub struct CreateRequestPayload(TaskChangeset);

#[derive(Debug)]
#[response_payload]
pub struct CreateResponsePayload(Task);

#[derive(Debug)]
#[batch_request_payload]
pub struct CreateBatchRequestPayload(Vec<TaskChangeset>);

#[derive(Debug)]
#[response_payload]
pub struct CreateBatchResponsePayload(Vec<Task>);

#[derive(Debug)]
#[request_payload]
pub struct UpdateRequestPayload(TaskChangeset);

#[derive(Debug)]
#[response_payload]
pub struct UpdateResponsePayload(Task);
}

0 comments on commit f46d9d9

Please sign in to comment.