diff --git a/Cargo.toml b/Cargo.toml index 8787370..b0e02ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,10 @@ perf = "deny" style = "deny" suspicious = "deny" +[[test]] +name = "lud02" +required-features = ["client", "server"] + [[test]] name = "lud03" required-features = ["client", "server"] diff --git a/src/client.rs b/src/client.rs index 5534c0c..116076d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -57,31 +57,42 @@ impl ChannelRequest<'_> { /// # Errors /// /// Returns errors on network or deserialization failures. - pub async fn callback_accept(self, remoteid: &str, private: bool) -> Result<(), &'static str> { + pub async fn callback_accept( + self, + remoteid: &str, + private: bool, + ) -> Result { let callback = self.core.callback_accept(remoteid, private); - self.client + let response = self + .client .get(callback) .send() .await .map_err(|_| "request failed")?; - Ok(()) + let text = response.text().await.map_err(|_| "body failed")?; + text.parse().map_err(|_| "parse failed") } /// # Errors /// /// Returns errors on network or deserialization failures. - pub async fn callback_cancel(self, remoteid: &str) -> Result<(), &'static str> { + pub async fn callback_cancel( + self, + remoteid: &str, + ) -> Result { let callback = self.core.callback_cancel(remoteid); - self.client + let response = self + .client .get(callback) .send() .await .map_err(|_| "request failed")?; - Ok(()) + let text = response.text().await.map_err(|_| "body failed")?; + text.parse().map_err(|_| "parse failed") } } diff --git a/src/core/channel_request.rs b/src/core/channel_request.rs index 1026c05..a14c121 100644 --- a/src/core/channel_request.rs +++ b/src/core/channel_request.rs @@ -2,9 +2,35 @@ pub const TAG: &str = "channelRequest"; #[derive(Clone, Debug)] pub struct ChannelRequest { - callback: url::Url, + pub callback: url::Url, pub uri: String, - k1: String, + pub k1: String, +} + +impl std::str::FromStr for ChannelRequest { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + let d: de::QueryResponse = + miniserde::json::from_str(s).map_err(|_| "deserialize failed")?; + + Ok(ChannelRequest { + callback: d.callback.0, + uri: d.uri, + k1: d.k1, + }) + } +} + +impl std::fmt::Display for ChannelRequest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&miniserde::json::to_string(&ser::QueryResponse { + tag: TAG, + callback: crate::serde::Url(self.callback.clone()), + uri: &self.uri, + k1: &self.k1, + })) + } } impl ChannelRequest { @@ -37,22 +63,61 @@ impl ChannelRequest { } } -impl std::str::FromStr for ChannelRequest { +#[derive(Debug)] +pub enum CallbackResponse { + Error(String), + Ok, +} + +impl std::str::FromStr for CallbackResponse { type Err = &'static str; fn from_str(s: &str) -> Result { - let d: serde::QueryResponse = - miniserde::json::from_str(s).map_err(|_| "deserialize failed")?; + let map = miniserde::json::from_str::>(s) + .map_err(|_| "bad json")?; + + match map.get("status").map(|s| s as &str) { + Some("OK") => Ok(CallbackResponse::Ok), + Some("ERROR") => Ok(CallbackResponse::Error( + map.get("reason").cloned().unwrap_or_default(), + )), + _ => Err("bad status field"), + } + } +} - Ok(ChannelRequest { - callback: d.callback.0, - uri: d.uri, - k1: d.k1, - }) +impl std::fmt::Display for CallbackResponse { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut map = std::collections::BTreeMap::new(); + + match self { + CallbackResponse::Error(reason) => { + map.insert("status", "ERROR"); + map.insert("reason", reason); + } + CallbackResponse::Ok => { + map.insert("status", "OK"); + } + } + + f.write_str(&miniserde::json::to_string(&map)) + } +} + +mod ser { + use crate::serde::Url; + use miniserde::Serialize; + + #[derive(Serialize)] + pub(super) struct QueryResponse<'a> { + pub tag: &'static str, + pub callback: Url, + pub uri: &'a str, + pub k1: &'a str, } } -mod serde { +mod de { use crate::serde::Url; use miniserde::Deserialize; diff --git a/src/server.rs b/src/server.rs index 5ccad0d..7e94b8f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,7 +2,9 @@ use crate::core; use axum::{extract::RawQuery, http::StatusCode, routing::get, Router}; use std::future::Future; -pub struct Server { +pub struct Server { + channel_query: CQ, + channel_callback: CC, withdraw_query: WQ, withdraw_callback: WC, pay_query: PQ, @@ -11,6 +13,8 @@ pub struct Server { impl Default for Server< + unimplemented::Handler0, + unimplemented::Handler1<(String, String), core::channel_request::CallbackResponse>, unimplemented::Handler0, unimplemented::Handler1, unimplemented::Handler0, @@ -19,6 +23,8 @@ impl Default { fn default() -> Self { Server { + channel_query: unimplemented::handler0, + channel_callback: unimplemented::handler1, withdraw_query: unimplemented::handler0, withdraw_callback: unimplemented::handler1, pay_query: unimplemented::handler0, @@ -27,13 +33,30 @@ impl Default } } -impl Server { +impl Server { + pub fn channel_request( + self, + channel_query: CQ2, + channel_callback: CC2, + ) -> Server { + Server { + channel_query, + channel_callback, + pay_query: self.pay_query, + pay_callback: self.pay_callback, + withdraw_query: self.withdraw_query, + withdraw_callback: self.withdraw_callback, + } + } + pub fn withdraw_request( self, withdraw_query: WQ2, withdraw_callback: WC2, - ) -> Server { + ) -> Server { Server { + channel_query: self.channel_query, + channel_callback: self.channel_callback, pay_query: self.pay_query, pay_callback: self.pay_callback, withdraw_query, @@ -45,8 +68,10 @@ impl Server { self, pay_query: PQ2, pay_callback: PC2, - ) -> Server { + ) -> Server { Server { + channel_query: self.channel_query, + channel_callback: self.channel_callback, pay_query, pay_callback, withdraw_query: self.withdraw_query, @@ -55,8 +80,15 @@ impl Server { } } -impl Server +impl + Server where + CQ: 'static + Send + Clone + Fn() -> CQFut, + CQFut: Send + Future>, + + CC: 'static + Send + Clone + Fn((String, String)) -> CCFut, + CCFut: Send + Future>, + WQ: 'static + Send + Clone + Fn() -> WQFut, WQFut: Send + Future>, @@ -71,11 +103,37 @@ where { pub fn build(self) -> Router<()> { Router::new() + .route( + "/lnurlc", + get(move || { + let cq = self.channel_query.clone(); + async move { cq().await.map(|a| a.to_string()) } + }), + ) + .route( + "/lnurlc/callback", + get(move |RawQuery(q): RawQuery| { + let cc = self.channel_callback.clone(); + async move { + let q = q.ok_or(StatusCode::BAD_REQUEST)?; + let qs = q + .split('&') + .filter_map(|s| s.split_once('=')) + .collect::>(); + + let k1 = qs.get("k1").ok_or(StatusCode::BAD_REQUEST)?; + let remoteid = qs.get("remoteid").ok_or(StatusCode::BAD_REQUEST)?; + cc((String::from(*k1), String::from(*remoteid))) + .await + .map(|a| a.to_string()) + } + }), + ) .route( "/lnurlw", get(move || { - let pq = self.withdraw_query.clone(); - async move { pq().await.map(|a| a.to_string()) } + let wq = self.withdraw_query.clone(); + async move { wq().await.map(|a| a.to_string()) } }), ) .route( diff --git a/tests/lud02.rs b/tests/lud02.rs new file mode 100644 index 0000000..0d579a4 --- /dev/null +++ b/tests/lud02.rs @@ -0,0 +1,71 @@ +#[tokio::test] +async fn test() { + let listener = tokio::net::TcpListener::bind("0.0.0.0:0") + .await + .expect("net"); + + let addr = listener.local_addr().expect("addr"); + + let query_url = format!("http://{addr}/lnurlc"); + let callback_url = url::Url::parse(&format!("http://{addr}/lnurlc/callback")).expect("url"); + + let router = lnurlkit::server::Server::default() + .channel_request( + move || { + let callback = callback_url.clone(); + async { + Ok(lnurlkit::core::channel_request::ChannelRequest { + uri: String::from("u@r:i"), + k1: String::from("caum"), + callback, + }) + } + }, + |(k1, remoteid)| async move { + Ok(if remoteid == "idremoto" { + lnurlkit::core::channel_request::CallbackResponse::Ok + } else { + lnurlkit::core::channel_request::CallbackResponse::Error(k1) + }) + }, + ) + .build(); + + tokio::spawn(async move { + axum::serve(listener, router).await.expect("serve"); + }); + + let client = lnurlkit::client::Client::default(); + + let lnurl = bech32::encode( + "lnurl", + bech32::ToBase32::to_base32(&query_url), + bech32::Variant::Bech32, + ) + .expect("lnurl"); + + let queried = client.query(&lnurl).await.expect("query"); + let lnurlkit::client::Query::ChannelRequest(cr) = queried else { + panic!("not pay request"); + }; + + assert_eq!(cr.core.uri, "u@r:i"); + + let response = cr + .clone() + .callback_cancel("idremoto") + .await + .expect("callback"); + + assert!(matches!( + response, + lnurlkit::core::channel_request::CallbackResponse::Ok + )); + + let response = cr.callback_cancel("iderrado").await.expect("callback"); + + assert!(matches!( + response, + lnurlkit::core::channel_request::CallbackResponse::Error(r) if r == "caum" + )); +}