Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query/Form: Use serde_path_to_error to report fields that failed to parse #3081

Merged
merged 9 commits into from
Dec 20, 2024
2 changes: 2 additions & 0 deletions axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -10,8 +10,10 @@ and this project adheres to [Semantic Versioning].
- **breaking:** `axum::extract::ws::Message` now uses `Bytes` in place of `Vec<u8>`,
and a new `Utf8Bytes` type in place of `String`, for its variants ([#3078])
- **changed:** Upgraded `tokio-tungstenite` to 0.26 ([#3078])
- **changed:** Query/Form: Use `serde_path_to_error` to report fields that failed to parse ([#3081])

[#3078]: https://github.com/tokio-rs/axum/pull/3078
[#3081]: https://github.com/tokio-rs/axum/pull/3081

# 0.10.0

4 changes: 2 additions & 2 deletions axum-extra/Cargo.toml
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ cookie-private = ["cookie", "cookie?/private"]
cookie-signed = ["cookie", "cookie?/signed"]
cookie-key-expansion = ["cookie", "cookie?/key-expansion"]
erased-json = ["dep:serde_json", "dep:typed-json"]
form = ["dep:serde_html_form"]
form = ["dep:form_urlencoded", "dep:serde_html_form", "dep:serde_path_to_error"]
json-deserializer = ["dep:serde_json", "dep:serde_path_to_error"]
json-lines = [
"dep:serde_json",
@@ -36,7 +36,7 @@ json-lines = [
multipart = ["dep:multer", "dep:fastrand"]
protobuf = ["dep:prost"]
scheme = []
query = ["dep:serde_html_form"]
query = ["dep:form_urlencoded", "dep:serde_html_form", "dep:serde_path_to_error"]
tracing = ["axum-core/tracing", "axum/tracing"]
typed-header = ["dep:headers"]
typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]
45 changes: 43 additions & 2 deletions axum-extra/src/extract/form.rs
Original file line number Diff line number Diff line change
@@ -56,7 +56,9 @@ where
.await
.map_err(FormRejection::RawFormRejection)?;

serde_html_form::from_bytes::<T>(&bytes)
let deserializer = serde_html_form::Deserializer::new(form_urlencoded::parse(&bytes));

serde_path_to_error::deserialize::<_, T>(deserializer)
.map(Self)
.map_err(|err| FormRejection::FailedToDeserializeForm(Error::new(err)))
}
@@ -115,8 +117,10 @@ impl std::error::Error for FormRejection {
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{routing::post, Router};
use axum::routing::{on, post, MethodFilter};
use axum::Router;
use http::header::CONTENT_TYPE;
use mime::APPLICATION_WWW_FORM_URLENCODED;
use serde::Deserialize;

#[tokio::test]
@@ -143,4 +147,41 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "one,two");
}

#[tokio::test]
async fn deserialize_error_status_codes() {
#[allow(dead_code)]
#[derive(Deserialize)]
struct Payload {
a: i32,
}

let app = Router::new().route(
"/",
on(
MethodFilter::GET.or(MethodFilter::POST),
|_: Form<Payload>| async {},
),
);

let client = TestClient::new(app);

let res = client.get("/?a=false").await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(
res.text().await,
"Failed to deserialize form: a: invalid digit found in string"
);

let res = client
.post("/")
.header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
.body("a=false")
.await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(
res.text().await,
"Failed to deserialize form: a: invalid digit found in string"
);
}
}
36 changes: 32 additions & 4 deletions axum-extra/src/extract/query.rs
Original file line number Diff line number Diff line change
@@ -103,7 +103,9 @@ where

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = parts.uri.query().unwrap_or_default();
let value = serde_html_form::from_str(query)
let deserializer =
serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let value = serde_path_to_error::deserialize(deserializer)
.map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?;
Ok(Query(value))
}
@@ -121,7 +123,9 @@ where
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_html_form::from_str(query)
let deserializer =
serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let value = serde_path_to_error::deserialize(deserializer)
.map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?;
Ok(Some(Self(value)))
} else {
@@ -230,7 +234,9 @@ where

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_html_form::from_str(query).map_err(|err| {
let deserializer =
serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let value = serde_path_to_error::deserialize(deserializer).map_err(|err| {
OptionalQueryRejection::FailedToDeserializeQueryString(Error::new(err))
})?;
Ok(OptionalQuery(Some(value)))
@@ -302,7 +308,8 @@ impl std::error::Error for OptionalQueryRejection {
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{routing::post, Router};
use axum::routing::{get, post};
use axum::Router;
use http::header::CONTENT_TYPE;
use serde::Deserialize;

@@ -331,6 +338,27 @@ mod tests {
assert_eq!(res.text().await, "one,two");
}

#[tokio::test]
async fn correct_rejection_status_code() {
#[derive(Deserialize)]
#[allow(dead_code)]
struct Params {
n: i32,
}

async fn handler(_: Query<Params>) {}

let app = Router::new().route("/", get(handler));
let client = TestClient::new(app);

let res = client.get("/?n=hi").await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(
res.text().await,
"Failed to deserialize query string: n: invalid digit found in string"
);
}

#[tokio::test]
async fn optional_query_supports_multiple_values() {
#[derive(Deserialize)]
4 changes: 4 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- **changed:** Query/Form: Use `serde_path_to_error` to report fields that failed to parse ([#3081])

[#3081]: https://github.com/tokio-rs/axum/pull/3081

# 0.8.0

## rc.1
5 changes: 3 additions & 2 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
@@ -23,15 +23,15 @@ default = [
"tower-log",
"tracing",
]
form = ["dep:serde_urlencoded"]
form = ["dep:form_urlencoded", "dep:serde_urlencoded", "dep:serde_path_to_error"]
http1 = ["dep:hyper", "hyper?/http1", "hyper-util?/http1"]
http2 = ["dep:hyper", "hyper?/http2", "hyper-util?/http2"]
json = ["dep:serde_json", "dep:serde_path_to_error"]
macros = ["dep:axum-macros"]
matched-path = []
multipart = ["dep:multer"]
original-uri = []
query = ["dep:serde_urlencoded"]
query = ["dep:form_urlencoded", "dep:serde_urlencoded", "dep:serde_path_to_error"]
tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make", "tokio/macros"]
tower-log = ["tower/log"]
tracing = ["dep:tracing", "axum-core/tracing"]
@@ -68,6 +68,7 @@ tower-service = "0.3"
# optional dependencies
axum-macros = { path = "../axum-macros", version = "0.5.0-rc.1", optional = true }
base64 = { version = "0.22.1", optional = true }
form_urlencoded = { version = "1.1.0", optional = true }
jplatte marked this conversation as resolved.
Show resolved Hide resolved
hyper = { version = "1.1.0", optional = true }
hyper-util = { version = "0.1.3", features = ["tokio", "server", "service"], optional = true }
multer = { version = "3.0.0", optional = true }
14 changes: 11 additions & 3 deletions axum/src/extract/query.rs
Original file line number Diff line number Diff line change
@@ -87,7 +87,9 @@ where
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_urlencoded::from_str(query)
let deserializer =
serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let value = serde_path_to_error::deserialize(deserializer)
.map_err(FailedToDeserializeQueryString::from_err)?;
Ok(Some(Self(value)))
} else {
@@ -121,8 +123,10 @@ where
/// ```
pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
let query = value.query().unwrap_or_default();
let params =
serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?;
let deserializer =
serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let params = serde_path_to_error::deserialize(deserializer)
.map_err(FailedToDeserializeQueryString::from_err)?;
Ok(Query(params))
}
}
@@ -201,6 +205,10 @@ mod tests {

let res = client.get("/?n=hi").await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(
res.text().await,
"Failed to deserialize query string: n: invalid digit found in string"
);
}

#[test]
17 changes: 14 additions & 3 deletions axum/src/form.rs
Original file line number Diff line number Diff line change
@@ -84,14 +84,17 @@ where

match req.extract().await {
Ok(RawForm(bytes)) => {
let value =
serde_urlencoded::from_bytes(&bytes).map_err(|err| -> FormRejection {
let deserializer =
serde_urlencoded::Deserializer::new(form_urlencoded::parse(&bytes));
let value = serde_path_to_error::deserialize(deserializer).map_err(
|err| -> FormRejection {
if is_get_or_head {
FailedToDeserializeForm::from_err(err).into()
} else {
FailedToDeserializeFormBody::from_err(err).into()
}
})?;
},
)?;
Ok(Form(value))
}
Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),
@@ -252,12 +255,20 @@ mod tests {

let res = client.get("/?a=false").await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(
res.text().await,
"Failed to deserialize form: a: invalid digit found in string"
);

let res = client
.post("/")
.header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
.body("a=false")
.await;
assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
assert_eq!(
res.text().await,
"Failed to deserialize form body: a: invalid digit found in string"
);
}
}
Loading