Skip to content

Commit

Permalink
Add with_headers() method to RemoteHttpPlugin (#15651)
Browse files Browse the repository at this point in the history
# Objective

- fulfill the needs presented in this issue, which requires the ability
to set custom HTTP headers for responses in the Bevy Remote Protocol
server. #15551

## Solution

- Created a `Headers` struct to store custom HTTP headers as key-value
pairs.
- Added a `headers` field to the `RemoteHttpPlugin` struct.
- Implemented a `with_headers` method in `RemoteHttpPlugin` to allow
users to set custom headers.
- Passed the headers into the processing chain.

## Testing

- I added cors_headers in example/remote/server.rs and tested it with a
static html
[file](https://github.com/spacemen0/bevy/blob/test_file/test.html)

---
  • Loading branch information
spacemen0 authored Oct 6, 2024
1 parent d9190e4 commit 7c4a068
Showing 1 changed file with 98 additions and 4 deletions.
102 changes: 98 additions & 4 deletions crates/bevy_remote/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ use bevy_ecs::system::{Res, Resource};
use bevy_tasks::IoTaskPool;
use core::net::{IpAddr, Ipv4Addr};
use http_body_util::{BodyExt as _, Full};
pub use hyper::header::{HeaderName, HeaderValue};
use hyper::{
body::{Bytes, Incoming},
header::HeaderValue,
server::conn::http1,
service, Request, Response,
};
use serde_json::Value;
use smol_hyper::rt::{FuturesIo, SmolTimer};
use std::collections::HashMap;
use std::net::TcpListener;
use std::net::TcpStream;

Expand All @@ -36,6 +37,37 @@ pub const DEFAULT_PORT: u16 = 15702;
/// The default host address that Bevy will use for its server.
pub const DEFAULT_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

/// A struct that holds a collection of HTTP headers.
///
/// This struct is used to store a set of HTTP headers as key-value pairs, where the keys are
/// of type [`HeaderName`] and the values are of type [`HeaderValue`].
///
#[derive(Debug, Resource, Clone)]
pub struct Headers {
headers: HashMap<HeaderName, HeaderValue>,
}

impl Headers {
/// Create a new instance of `Headers`.
pub fn new() -> Self {
Self {
headers: HashMap::new(),
}
}

/// Add a key value pair to the `Headers` instance.
pub fn add(mut self, key: HeaderName, value: HeaderValue) -> Self {
self.headers.insert(key, value);
self
}
}

impl Default for Headers {
fn default() -> Self {
Self::new()
}
}

/// Add this plugin to your [`App`] to allow remote connections over HTTP to inspect and modify entities.
/// It requires the [`RemotePlugin`](super::RemotePlugin).
///
Expand All @@ -44,18 +76,40 @@ pub const DEFAULT_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
/// The defaults are:
/// - [`DEFAULT_ADDR`] : 127.0.0.1.
/// - [`DEFAULT_PORT`] : 15702.
///
/// /// # Example
///
/// ```ignore
///
/// // Create CORS headers
/// let cors_headers = Headers::new()
/// .add(HeaderName::from_static("Access-Control-Allow-Origin"), HeaderValue::from_static("*"))
/// .add(HeaderName::from_static("Access-Control-Allow-Headers"), HeaderValue::from_static("Content-Type, Authorization"));
///
/// // Create the Bevy app and add the RemoteHttpPlugin with CORS headers
/// fn main() {
/// App::new()
/// .add_plugins(DefaultPlugins)
/// .add_plugins(RemoteHttpPlugin::default()
/// .with_headers(cors_headers))
/// .run();
/// }
/// ```
pub struct RemoteHttpPlugin {
/// The address that Bevy will bind to.
address: IpAddr,
/// The port that Bevy will listen on.
port: u16,
/// The headers that Bevy will include in its HTTP responses
headers: Headers,
}

impl Default for RemoteHttpPlugin {
fn default() -> Self {
Self {
address: DEFAULT_ADDR,
port: DEFAULT_PORT,
headers: Headers::new(),
}
}
}
Expand All @@ -64,6 +118,7 @@ impl Plugin for RemoteHttpPlugin {
fn build(&self, app: &mut App) {
app.insert_resource(HostAddress(self.address))
.insert_resource(HostPort(self.port))
.insert_resource(HostHeaders(self.headers.clone()))
.add_systems(Startup, start_http_server);
}
}
Expand All @@ -75,13 +130,34 @@ impl RemoteHttpPlugin {
self.address = address.into();
self
}

/// Set the remote port that the server will listen on.
#[must_use]
pub fn with_port(mut self, port: u16) -> Self {
self.port = port;
self
}
/// Set the extra headers that the response will include.
#[must_use]
pub fn with_headers(mut self, headers: Headers) -> Self {
self.headers = headers;
self
}
/// Add a single header to the response headers.
#[must_use]
pub fn with_header(
mut self,
name: impl TryInto<HeaderName>,
value: impl TryInto<HeaderValue>,
) -> Self {
let Ok(header_name) = name.try_into() else {
panic!("Invalid header name")
};
let Ok(header_value) = value.try_into() else {
panic!("Invalid header value")
};
self.headers = self.headers.add(header_name, header_value);
self
}
}

/// A resource containing the IP address that Bevy will host on.
Expand All @@ -98,17 +174,24 @@ pub struct HostAddress(pub IpAddr);
#[derive(Debug, Resource)]
pub struct HostPort(pub u16);

/// A resource containing the headers that Bevy will include in its HTTP responses.
///
#[derive(Debug, Resource)]
struct HostHeaders(pub Headers);

/// A system that starts up the Bevy Remote Protocol HTTP server.
fn start_http_server(
request_sender: Res<BrpSender>,
address: Res<HostAddress>,
remote_port: Res<HostPort>,
headers: Res<HostHeaders>,
) {
IoTaskPool::get()
.spawn(server_main(
address.0,
remote_port.0,
request_sender.clone(),
headers.0.clone(),
))
.detach();
}
Expand All @@ -118,25 +201,29 @@ async fn server_main(
address: IpAddr,
port: u16,
request_sender: Sender<BrpMessage>,
headers: Headers,
) -> AnyhowResult<()> {
listen(
Async::<TcpListener>::bind((address, port))?,
&request_sender,
&headers,
)
.await
}

async fn listen(
listener: Async<TcpListener>,
request_sender: &Sender<BrpMessage>,
headers: &Headers,
) -> AnyhowResult<()> {
loop {
let (client, _) = listener.accept().await?;

let request_sender = request_sender.clone();
let headers = headers.clone();
IoTaskPool::get()
.spawn(async move {
let _ = handle_client(client, request_sender).await;
let _ = handle_client(client, request_sender, headers).await;
})
.detach();
}
Expand All @@ -145,12 +232,15 @@ async fn listen(
async fn handle_client(
client: Async<TcpStream>,
request_sender: Sender<BrpMessage>,
headers: Headers,
) -> AnyhowResult<()> {
http1::Builder::new()
.timer(SmolTimer::new())
.serve_connection(
FuturesIo::new(client),
service::service_fn(|request| process_request_batch(request, &request_sender)),
service::service_fn(|request| {
process_request_batch(request, &request_sender, &headers)
}),
)
.await?;

Expand All @@ -162,6 +252,7 @@ async fn handle_client(
async fn process_request_batch(
request: Request<Incoming>,
request_sender: &Sender<BrpMessage>,
headers: &Headers,
) -> AnyhowResult<Response<Full<Bytes>>> {
let batch_bytes = request.into_body().collect().await?.to_bytes();
let batch: Result<BrpBatch, _> = serde_json::from_slice(&batch_bytes);
Expand Down Expand Up @@ -198,6 +289,9 @@ async fn process_request_batch(
hyper::header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
for (key, value) in &headers.headers {
response.headers_mut().insert(key, value.clone());
}
Ok(response)
}

Expand Down

0 comments on commit 7c4a068

Please sign in to comment.