diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index a58aec6c1e..0a1f7523e6 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -10,8 +10,10 @@ and this project adheres to [Semantic Versioning]. - **breaking:** `axum::extract::ws::Message` now uses `Bytes` in place of `Vec`, and a new `Utf8Bytes` type in place of `String`, for its variants ([#3078]) - **changed:** Upgraded `tokio-tungstenite` to 0.26 ([#3078]) +- **changed:** Query/Form: Use `serde_path_to_error` to report fields that failed to parse ([#3081]) [#3078]: https://github.com/tokio-rs/axum/pull/3078 +[#3081]: https://github.com/tokio-rs/axum/pull/3081 # 0.10.0 diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index d8f5435fd3..384544568e 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -23,7 +23,7 @@ cookie-private = ["cookie", "cookie?/private"] cookie-signed = ["cookie", "cookie?/signed"] cookie-key-expansion = ["cookie", "cookie?/key-expansion"] erased-json = ["dep:serde_json", "dep:typed-json"] -form = ["dep:serde_html_form"] +form = ["dep:form_urlencoded", "dep:serde_html_form", "dep:serde_path_to_error"] json-deserializer = ["dep:serde_json", "dep:serde_path_to_error"] json-lines = [ "dep:serde_json", @@ -36,7 +36,7 @@ json-lines = [ multipart = ["dep:multer", "dep:fastrand"] protobuf = ["dep:prost"] scheme = [] -query = ["dep:serde_html_form"] +query = ["dep:form_urlencoded", "dep:serde_html_form", "dep:serde_path_to_error"] tracing = ["axum-core/tracing", "axum/tracing"] typed-header = ["dep:headers"] typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"] diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index a7ca9305aa..8d2d30f91c 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -56,7 +56,9 @@ where .await .map_err(FormRejection::RawFormRejection)?; - serde_html_form::from_bytes::(&bytes) + let deserializer = serde_html_form::Deserializer::new(form_urlencoded::parse(&bytes)); + + serde_path_to_error::deserialize::<_, T>(deserializer) .map(Self) .map_err(|err| FormRejection::FailedToDeserializeForm(Error::new(err))) } @@ -115,8 +117,10 @@ impl std::error::Error for FormRejection { mod tests { use super::*; use crate::test_helpers::*; - use axum::{routing::post, Router}; + use axum::routing::{on, post, MethodFilter}; + use axum::Router; use http::header::CONTENT_TYPE; + use mime::APPLICATION_WWW_FORM_URLENCODED; use serde::Deserialize; #[tokio::test] @@ -143,4 +147,41 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "one,two"); } + + #[tokio::test] + async fn deserialize_error_status_codes() { + #[allow(dead_code)] + #[derive(Deserialize)] + struct Payload { + a: i32, + } + + let app = Router::new().route( + "/", + on( + MethodFilter::GET.or(MethodFilter::POST), + |_: Form| async {}, + ), + ); + + let client = TestClient::new(app); + + let res = client.get("/?a=false").await; + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!( + res.text().await, + "Failed to deserialize form: a: invalid digit found in string" + ); + + let res = client + .post("/") + .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref()) + .body("a=false") + .await; + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!( + res.text().await, + "Failed to deserialize form: a: invalid digit found in string" + ); + } } diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index 6e50456e2f..489fc1c7d4 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -103,7 +103,9 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let query = parts.uri.query().unwrap_or_default(); - let value = serde_html_form::from_str(query) + let deserializer = + serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes())); + let value = serde_path_to_error::deserialize(deserializer) .map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?; Ok(Query(value)) } @@ -121,7 +123,9 @@ where _state: &S, ) -> Result, Self::Rejection> { if let Some(query) = parts.uri.query() { - let value = serde_html_form::from_str(query) + let deserializer = + serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes())); + let value = serde_path_to_error::deserialize(deserializer) .map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?; Ok(Some(Self(value))) } else { @@ -230,7 +234,9 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(query) = parts.uri.query() { - let value = serde_html_form::from_str(query).map_err(|err| { + let deserializer = + serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes())); + let value = serde_path_to_error::deserialize(deserializer).map_err(|err| { OptionalQueryRejection::FailedToDeserializeQueryString(Error::new(err)) })?; Ok(OptionalQuery(Some(value))) @@ -302,7 +308,8 @@ impl std::error::Error for OptionalQueryRejection { mod tests { use super::*; use crate::test_helpers::*; - use axum::{routing::post, Router}; + use axum::routing::{get, post}; + use axum::Router; use http::header::CONTENT_TYPE; use serde::Deserialize; @@ -331,6 +338,27 @@ mod tests { assert_eq!(res.text().await, "one,two"); } + #[tokio::test] + async fn correct_rejection_status_code() { + #[derive(Deserialize)] + #[allow(dead_code)] + struct Params { + n: i32, + } + + async fn handler(_: Query) {} + + let app = Router::new().route("/", get(handler)); + let client = TestClient::new(app); + + let res = client.get("/?n=hi").await; + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!( + res.text().await, + "Failed to deserialize query string: n: invalid digit found in string" + ); + } + #[tokio::test] async fn optional_query_supports_multiple_values() { #[derive(Deserialize)] diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index e896deb317..5301ebf51a 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **changed:** Query/Form: Use `serde_path_to_error` to report fields that failed to parse ([#3081]) + +[#3081]: https://github.com/tokio-rs/axum/pull/3081 + # 0.8.0 ## rc.1 diff --git a/axum/Cargo.toml b/axum/Cargo.toml index fb88399ec5..833d935266 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -23,7 +23,7 @@ default = [ "tower-log", "tracing", ] -form = ["dep:serde_urlencoded"] +form = ["dep:form_urlencoded", "dep:serde_urlencoded", "dep:serde_path_to_error"] http1 = ["dep:hyper", "hyper?/http1", "hyper-util?/http1"] http2 = ["dep:hyper", "hyper?/http2", "hyper-util?/http2"] json = ["dep:serde_json", "dep:serde_path_to_error"] @@ -31,7 +31,7 @@ macros = ["dep:axum-macros"] matched-path = [] multipart = ["dep:multer"] original-uri = [] -query = ["dep:serde_urlencoded"] +query = ["dep:form_urlencoded", "dep:serde_urlencoded", "dep:serde_path_to_error"] tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make", "tokio/macros"] tower-log = ["tower/log"] tracing = ["dep:tracing", "axum-core/tracing"] @@ -68,6 +68,7 @@ tower-service = "0.3" # optional dependencies axum-macros = { path = "../axum-macros", version = "0.5.0-rc.1", optional = true } base64 = { version = "0.22.1", optional = true } +form_urlencoded = { version = "1.1.0", optional = true } hyper = { version = "1.1.0", optional = true } hyper-util = { version = "0.1.3", features = ["tokio", "server", "service"], optional = true } multer = { version = "3.0.0", optional = true } diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index 14473aab04..64221afabb 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -87,7 +87,9 @@ where _state: &S, ) -> Result, Self::Rejection> { if let Some(query) = parts.uri.query() { - let value = serde_urlencoded::from_str(query) + let deserializer = + serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes())); + let value = serde_path_to_error::deserialize(deserializer) .map_err(FailedToDeserializeQueryString::from_err)?; Ok(Some(Self(value))) } else { @@ -121,8 +123,10 @@ where /// ``` pub fn try_from_uri(value: &Uri) -> Result { let query = value.query().unwrap_or_default(); - let params = - serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?; + let deserializer = + serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes())); + let params = serde_path_to_error::deserialize(deserializer) + .map_err(FailedToDeserializeQueryString::from_err)?; Ok(Query(params)) } } @@ -201,6 +205,10 @@ mod tests { let res = client.get("/?n=hi").await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!( + res.text().await, + "Failed to deserialize query string: n: invalid digit found in string" + ); } #[test] diff --git a/axum/src/form.rs b/axum/src/form.rs index f754c4c1b8..fd7c033815 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -84,14 +84,17 @@ where match req.extract().await { Ok(RawForm(bytes)) => { - let value = - serde_urlencoded::from_bytes(&bytes).map_err(|err| -> FormRejection { + let deserializer = + serde_urlencoded::Deserializer::new(form_urlencoded::parse(&bytes)); + let value = serde_path_to_error::deserialize(deserializer).map_err( + |err| -> FormRejection { if is_get_or_head { FailedToDeserializeForm::from_err(err).into() } else { FailedToDeserializeFormBody::from_err(err).into() } - })?; + }, + )?; Ok(Form(value)) } Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)), @@ -252,6 +255,10 @@ mod tests { let res = client.get("/?a=false").await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert_eq!( + res.text().await, + "Failed to deserialize form: a: invalid digit found in string" + ); let res = client .post("/") @@ -259,5 +266,9 @@ mod tests { .body("a=false") .await; assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); + assert_eq!( + res.text().await, + "Failed to deserialize form body: a: invalid digit found in string" + ); } }