Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Request<T> where T is a type that implements hyper::body::Body trait #1263

Merged
merged 12 commits into from
Sep 26, 2024
6 changes: 4 additions & 2 deletions juniper_hyper/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ All user visible changes to `juniper_hyper` crate will be documented in this fil

### BC Breaks

- Bumped up [MSRV] to 1.75. ([#1272])
- Bumped up [MSRV] to 1.79. ([#1263])
- Made `hyper::Request` in `graphql()` and `graphql_sync()` functions generic over `T: hyper::body::Body`. ([#1263], [#1102])

[#1272]: /../../pull/1272
[#1102]: /../../issues/1102
[#1263]: /../../pull/1263



Expand Down
2 changes: 1 addition & 1 deletion juniper_hyper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "juniper_hyper"
version = "0.9.0"
edition = "2021"
rust-version = "1.75"
rust-version = "1.79"
description = "`juniper` GraphQL integration with `hyper`."
license = "BSD-2-Clause"
authors = [
Expand Down
163 changes: 119 additions & 44 deletions juniper_hyper/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{error::Error, fmt, string::FromUtf8Error, sync::Arc};

use http_body_util::BodyExt as _;
use hyper::{
body,
body::Body,
header::{self, HeaderValue},
Method, Request, Response, StatusCode,
};
Expand All @@ -15,10 +15,10 @@ use juniper::{
use serde_json::error::Error as SerdeError;
use url::form_urlencoded;

pub async fn graphql_sync<CtxT, QueryT, MutationT, SubscriptionT, S>(
pub async fn graphql_sync<CtxT, QueryT, MutationT, SubscriptionT, S, T>(
root_node: Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
context: Arc<CtxT>,
req: Request<body::Incoming>,
req: Request<T>,
) -> Response<String>
where
QueryT: GraphQLType<S, Context = CtxT>,
Expand All @@ -29,17 +29,18 @@ where
SubscriptionT::TypeInfo: Sync,
CtxT: Sync,
S: ScalarValue + Send + Sync,
T: Body<Error: fmt::Display>,
{
match parse_req(req).await {
Ok(req) => execute_request_sync(root_node, context, req).await,
Err(resp) => resp,
}
}

pub async fn graphql<CtxT, QueryT, MutationT, SubscriptionT, S>(
pub async fn graphql<CtxT, QueryT, MutationT, SubscriptionT, S, T>(
root_node: Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
context: Arc<CtxT>,
req: Request<body::Incoming>,
req: Request<T>,
) -> Response<String>
where
QueryT: GraphQLTypeAsync<S, Context = CtxT>,
Expand All @@ -50,16 +51,19 @@ where
SubscriptionT::TypeInfo: Sync,
CtxT: Sync,
S: ScalarValue + Send + Sync,
T: Body<Error: fmt::Display>,
{
match parse_req(req).await {
Ok(req) => execute_request(root_node, context, req).await,
Err(resp) => resp,
}
}

async fn parse_req<S: ScalarValue>(
req: Request<body::Incoming>,
) -> Result<GraphQLBatchRequest<S>, Response<String>> {
async fn parse_req<S, T>(req: Request<T>) -> Result<GraphQLBatchRequest<S>, Response<String>>
where
S: ScalarValue,
T: Body<Error: fmt::Display>,
{
match *req.method() {
Method::GET => parse_get_req(req),
Method::POST => {
Expand All @@ -78,9 +82,11 @@ async fn parse_req<S: ScalarValue>(
.map_err(render_error)
}

fn parse_get_req<S: ScalarValue>(
req: Request<body::Incoming>,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError> {
fn parse_get_req<S, T>(req: Request<T>) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError<T>>
where
S: ScalarValue,
T: Body,
{
req.uri()
.query()
.map(|q| gql_request_from_get(q).map(GraphQLBatchRequest::Single))
Expand All @@ -91,9 +97,13 @@ fn parse_get_req<S: ScalarValue>(
})
}

async fn parse_post_json_req<S: ScalarValue>(
body: body::Incoming,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError> {
async fn parse_post_json_req<S, T>(
body: T,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError<T>>
where
S: ScalarValue,
T: Body,
{
let chunk = body
.collect()
.await
Expand All @@ -106,9 +116,13 @@ async fn parse_post_json_req<S: ScalarValue>(
.map_err(GraphQLRequestError::BodyJSONError)
}

async fn parse_post_graphql_req<S: ScalarValue>(
body: body::Incoming,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError> {
async fn parse_post_graphql_req<S, T>(
body: T,
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError<T>>
where
S: ScalarValue,
T: Body,
{
let chunk = body
.collect()
.await
Expand Down Expand Up @@ -143,7 +157,10 @@ pub async fn playground(
resp
}

fn render_error(err: GraphQLRequestError) -> Response<String> {
fn render_error<T>(err: GraphQLRequestError<T>) -> Response<String>
where
T: Body<Error: fmt::Display>,
{
let mut resp = new_response(StatusCode::BAD_REQUEST);
*resp.body_mut() = err.to_string();
resp
Expand Down Expand Up @@ -211,9 +228,12 @@ where
resp
}

fn gql_request_from_get<S>(input: &str) -> Result<JuniperGraphQLRequest<S>, GraphQLRequestError>
fn gql_request_from_get<S, T>(
input: &str,
) -> Result<JuniperGraphQLRequest<S>, GraphQLRequestError<T>>
where
S: ScalarValue,
T: Body,
{
let mut query = None;
let mut operation_name = None;
Expand Down Expand Up @@ -254,7 +274,7 @@ where
}
}

fn invalid_err(parameter_name: &str) -> GraphQLRequestError {
fn invalid_err<T: Body>(parameter_name: &str) -> GraphQLRequestError<T> {
GraphQLRequestError::Invalid(format!(
"`{parameter_name}` parameter is specified multiple times",
))
Expand All @@ -275,35 +295,57 @@ fn new_html_response(code: StatusCode) -> Response<String> {
resp
}

#[derive(Debug)]
enum GraphQLRequestError {
BodyHyper(hyper::Error),
enum GraphQLRequestError<T: Body> {
BodyHyper(T::Error),
BodyUtf8(FromUtf8Error),
BodyJSONError(SerdeError),
Variables(SerdeError),
Invalid(String),
}

impl fmt::Display for GraphQLRequestError {
// NOTE: Manual implementation instead of `#[derive(Debug)]` is used to omit imposing unnecessary
// `T: Debug` bound on the implementation.
impl<T> fmt::Debug for GraphQLRequestError<T>
where
T: Body<Error: fmt::Debug>,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
GraphQLRequestError::BodyHyper(err) => fmt::Display::fmt(err, f),
GraphQLRequestError::BodyUtf8(err) => fmt::Display::fmt(err, f),
GraphQLRequestError::BodyJSONError(err) => fmt::Display::fmt(err, f),
GraphQLRequestError::Variables(err) => fmt::Display::fmt(err, f),
GraphQLRequestError::Invalid(err) => fmt::Display::fmt(err, f),
Self::BodyHyper(e) => fmt::Debug::fmt(e, f),
Self::BodyUtf8(e) => fmt::Debug::fmt(e, f),
Self::BodyJSONError(e) => fmt::Debug::fmt(e, f),
Self::Variables(e) => fmt::Debug::fmt(e, f),
Self::Invalid(e) => fmt::Debug::fmt(e, f),
}
}
}

impl Error for GraphQLRequestError {
impl<T> fmt::Display for GraphQLRequestError<T>
where
T: Body<Error: fmt::Display>,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::BodyHyper(e) => fmt::Display::fmt(e, f),
Self::BodyUtf8(e) => fmt::Display::fmt(e, f),
Self::BodyJSONError(e) => fmt::Display::fmt(e, f),
Self::Variables(e) => fmt::Display::fmt(e, f),
Self::Invalid(e) => fmt::Display::fmt(e, f),
}
}
}

impl<T> Error for GraphQLRequestError<T>
where
T: Body<Error: Error + 'static>,
{
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
GraphQLRequestError::BodyHyper(err) => Some(err),
GraphQLRequestError::BodyUtf8(err) => Some(err),
GraphQLRequestError::BodyJSONError(err) => Some(err),
GraphQLRequestError::Variables(err) => Some(err),
GraphQLRequestError::Invalid(_) => None,
Self::BodyHyper(e) => Some(e),
Self::BodyUtf8(e) => Some(e),
Self::BodyJSONError(e) => Some(e),
Self::Variables(e) => Some(e),
Self::Invalid(_) => None,
}
}
}
Expand All @@ -314,7 +356,11 @@ mod tests {
convert::Infallible, error::Error, net::SocketAddr, panic, sync::Arc, time::Duration,
};

use hyper::{server::conn::http1, service::service_fn, Method, Response, StatusCode};
use http_body_util::BodyExt as _;
use hyper::{
body::Incoming, server::conn::http1, service::service_fn, Method, Request, Response,
StatusCode,
};
use hyper_util::rt::TokioIo;
use juniper::{
http::tests as http_tests,
Expand Down Expand Up @@ -376,8 +422,7 @@ mod tests {
}
}

async fn run_hyper_integration(is_sync: bool) {
let port = if is_sync { 3002 } else { 3001 };
async fn run_hyper_integration(port: u16, is_sync: bool, is_custom_type: bool) {
let addr = SocketAddr::from(([127, 0, 0, 1], port));

let db = Arc::new(Database::new());
Expand Down Expand Up @@ -405,7 +450,7 @@ mod tests {
if let Err(e) = http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| {
service_fn(move |req: Request<Incoming>| {
let root_node = root_node.clone();
let db = db.clone();
let matches = {
Expand All @@ -419,10 +464,30 @@ mod tests {
};
async move {
Ok::<_, Infallible>(if matches {
if is_sync {
super::graphql_sync(root_node, db, req).await
if is_custom_type {
let (parts, mut body) = req.into_parts();
let body = {
let mut buf = String::new();
if let Some(Ok(frame)) = body.frame().await {
if let Ok(bytes) = frame.into_data() {
buf = String::from_utf8_lossy(&bytes)
tyranron marked this conversation as resolved.
Show resolved Hide resolved
.to_string();
}
}
buf
};
let req = Request::from_parts(parts, body);
if is_sync {
super::graphql_sync(root_node, db, req).await
} else {
super::graphql(root_node, db, req).await
}
} else {
super::graphql(root_node, db, req).await
if is_sync {
super::graphql_sync(root_node, db, req).await
} else {
super::graphql(root_node, db, req).await
}
}
} else {
let mut resp = Response::new(String::new());
Expand Down Expand Up @@ -460,11 +525,21 @@ mod tests {

#[tokio::test]
async fn test_hyper_integration() {
run_hyper_integration(false).await
run_hyper_integration(3000, false, false).await
}

#[tokio::test]
async fn test_sync_hyper_integration() {
run_hyper_integration(true).await
run_hyper_integration(3001, true, false).await
}

#[tokio::test]
async fn test_custom_request_hyper_integration() {
run_hyper_integration(3002, false, false).await
}

#[tokio::test]
async fn test_custom_request_sync_hyper_integration() {
run_hyper_integration(3003, true, true).await
}
}
Loading