Skip to content

Commit

Permalink
feat(channel): add client/server and a test for withdraw (lud02)
Browse files Browse the repository at this point in the history
  • Loading branch information
lsunsi committed Dec 4, 2023
1 parent 6cbaff3 commit 825b57e
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 25 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ perf = "deny"
style = "deny"
suspicious = "deny"

[[test]]
name = "lud02"
required-features = ["client", "server"]

[[test]]
name = "lud03"
required-features = ["client", "server"]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ This library works as a toolkit so you can serve and make your LNURL requests wi
## Current support

- [LUD-01](https://github.com/lnurl/luds/blob/luds/01.md): ✅ core ✅ client ✅ server ⚠️ tests
- [LUD-02](https://github.com/lnurl/luds/blob/luds/02.md): ✅ core ⚠️ client 🆘 server 🆘 tests
- [LUD-02](https://github.com/lnurl/luds/blob/luds/02.md): ✅ core client server ⚠️ tests
- [LUD-03](https://github.com/lnurl/luds/blob/luds/03.md): ✅ core ✅ client ✅ server ⚠️ tests
- [LUD-04](https://github.com/lnurl/luds/blob/luds/04.md): 🆘 core 🆘 client 🆘 server 🆘 tests
- [LUD-05](https://github.com/lnurl/luds/blob/luds/05.md): 🆘 core 🆘 client 🆘 server 🆘 tests
Expand Down
23 changes: 17 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,42 @@ impl ChannelRequest<'_> {
/// # Errors
///
/// Returns errors on network or deserialization failures.
pub async fn callback_accept(self, remoteid: &str, private: bool) -> Result<(), &'static str> {
pub async fn callback_accept(
self,
remoteid: &str,
private: bool,
) -> Result<core::channel_request::CallbackResponse, &'static str> {
let callback = self.core.callback_accept(remoteid, private);

self.client
let response = self
.client
.get(callback)
.send()
.await
.map_err(|_| "request failed")?;

Ok(())
let text = response.text().await.map_err(|_| "body failed")?;
text.parse().map_err(|_| "parse failed")
}

/// # Errors
///
/// Returns errors on network or deserialization failures.
pub async fn callback_cancel(self, remoteid: &str) -> Result<(), &'static str> {
pub async fn callback_cancel(
self,
remoteid: &str,
) -> Result<core::channel_request::CallbackResponse, &'static str> {
let callback = self.core.callback_cancel(remoteid);

self.client
let response = self
.client
.get(callback)
.send()
.await
.map_err(|_| "request failed")?;

Ok(())
let text = response.text().await.map_err(|_| "body failed")?;
text.parse().map_err(|_| "parse failed")
}
}

Expand Down
87 changes: 76 additions & 11 deletions src/core/channel_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,35 @@ pub const TAG: &str = "channelRequest";

#[derive(Clone, Debug)]
pub struct ChannelRequest {
callback: url::Url,
pub callback: url::Url,
pub uri: String,
k1: String,
pub k1: String,
}

impl std::str::FromStr for ChannelRequest {
type Err = &'static str;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let d: de::QueryResponse =
miniserde::json::from_str(s).map_err(|_| "deserialize failed")?;

Ok(ChannelRequest {
callback: d.callback.0,
uri: d.uri,
k1: d.k1,
})
}
}

impl std::fmt::Display for ChannelRequest {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&miniserde::json::to_string(&ser::QueryResponse {
tag: TAG,
callback: crate::serde::Url(self.callback.clone()),
uri: &self.uri,
k1: &self.k1,
}))
}
}

impl ChannelRequest {
Expand Down Expand Up @@ -37,22 +63,61 @@ impl ChannelRequest {
}
}

impl std::str::FromStr for ChannelRequest {
#[derive(Debug)]
pub enum CallbackResponse {
Error(String),
Ok,
}

impl std::str::FromStr for CallbackResponse {
type Err = &'static str;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let d: serde::QueryResponse =
miniserde::json::from_str(s).map_err(|_| "deserialize failed")?;
let map = miniserde::json::from_str::<std::collections::BTreeMap<String, String>>(s)
.map_err(|_| "bad json")?;

match map.get("status").map(|s| s as &str) {
Some("OK") => Ok(CallbackResponse::Ok),
Some("ERROR") => Ok(CallbackResponse::Error(
map.get("reason").cloned().unwrap_or_default(),
)),
_ => Err("bad status field"),
}
}
}

Ok(ChannelRequest {
callback: d.callback.0,
uri: d.uri,
k1: d.k1,
})
impl std::fmt::Display for CallbackResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut map = std::collections::BTreeMap::new();

match self {
CallbackResponse::Error(reason) => {
map.insert("status", "ERROR");
map.insert("reason", reason);
}
CallbackResponse::Ok => {
map.insert("status", "OK");
}
}

f.write_str(&miniserde::json::to_string(&map))
}
}

mod ser {
use crate::serde::Url;
use miniserde::Serialize;

#[derive(Serialize)]
pub(super) struct QueryResponse<'a> {
pub tag: &'static str,
pub callback: Url,
pub uri: &'a str,
pub k1: &'a str,
}
}

mod serde {
mod de {
use crate::serde::Url;
use miniserde::Deserialize;

Expand Down
72 changes: 65 additions & 7 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::core;
use axum::{extract::RawQuery, http::StatusCode, routing::get, Router};
use std::future::Future;

pub struct Server<WQ, WC, PQ, PC> {
pub struct Server<CQ, CC, WQ, WC, PQ, PC> {
channel_query: CQ,
channel_callback: CC,
withdraw_query: WQ,
withdraw_callback: WC,
pay_query: PQ,
Expand All @@ -11,6 +13,8 @@ pub struct Server<WQ, WC, PQ, PC> {

impl Default
for Server<
unimplemented::Handler0<core::channel_request::ChannelRequest>,
unimplemented::Handler1<(String, String), core::channel_request::CallbackResponse>,
unimplemented::Handler0<core::withdraw_request::WithdrawRequest>,
unimplemented::Handler1<String, core::withdraw_request::CallbackResponse>,
unimplemented::Handler0<core::pay_request::PayRequest>,
Expand All @@ -19,6 +23,8 @@ impl Default
{
fn default() -> Self {
Server {
channel_query: unimplemented::handler0,
channel_callback: unimplemented::handler1,
withdraw_query: unimplemented::handler0,
withdraw_callback: unimplemented::handler1,
pay_query: unimplemented::handler0,
Expand All @@ -27,13 +33,30 @@ impl Default
}
}

impl<WQ, WC, PQ, PC> Server<WQ, WC, PQ, PC> {
impl<CQ, CC, WQ, WC, PQ, PC> Server<CQ, CC, WQ, WC, PQ, PC> {
pub fn channel_request<CQ2, CC2>(
self,
channel_query: CQ2,
channel_callback: CC2,
) -> Server<CQ2, CC2, WQ, WC, PQ, PC> {
Server {
channel_query,
channel_callback,
pay_query: self.pay_query,
pay_callback: self.pay_callback,
withdraw_query: self.withdraw_query,
withdraw_callback: self.withdraw_callback,
}
}

pub fn withdraw_request<WQ2, WC2>(
self,
withdraw_query: WQ2,
withdraw_callback: WC2,
) -> Server<WQ2, WC2, PQ, PC> {
) -> Server<CQ, CC, WQ2, WC2, PQ, PC> {
Server {
channel_query: self.channel_query,
channel_callback: self.channel_callback,
pay_query: self.pay_query,
pay_callback: self.pay_callback,
withdraw_query,
Expand All @@ -45,8 +68,10 @@ impl<WQ, WC, PQ, PC> Server<WQ, WC, PQ, PC> {
self,
pay_query: PQ2,
pay_callback: PC2,
) -> Server<WQ, WC, PQ2, PC2> {
) -> Server<CQ, CC, WQ, WC, PQ2, PC2> {
Server {
channel_query: self.channel_query,
channel_callback: self.channel_callback,
pay_query,
pay_callback,
withdraw_query: self.withdraw_query,
Expand All @@ -55,8 +80,15 @@ impl<WQ, WC, PQ, PC> Server<WQ, WC, PQ, PC> {
}
}

impl<WQ, WQFut, WC, WCFut, PQ, PQFut, PC, PCFut> Server<WQ, WC, PQ, PC>
impl<CQ, CQFut, CC, CCFut, WQ, WQFut, WC, WCFut, PQ, PQFut, PC, PCFut>
Server<CQ, CC, WQ, WC, PQ, PC>
where
CQ: 'static + Send + Clone + Fn() -> CQFut,
CQFut: Send + Future<Output = Result<core::channel_request::ChannelRequest, StatusCode>>,

CC: 'static + Send + Clone + Fn((String, String)) -> CCFut,
CCFut: Send + Future<Output = Result<core::channel_request::CallbackResponse, StatusCode>>,

WQ: 'static + Send + Clone + Fn() -> WQFut,
WQFut: Send + Future<Output = Result<core::withdraw_request::WithdrawRequest, StatusCode>>,

Expand All @@ -71,11 +103,37 @@ where
{
pub fn build(self) -> Router<()> {
Router::new()
.route(
"/lnurlc",
get(move || {
let cq = self.channel_query.clone();
async move { cq().await.map(|a| a.to_string()) }
}),
)
.route(
"/lnurlc/callback",
get(move |RawQuery(q): RawQuery| {
let cc = self.channel_callback.clone();
async move {
let q = q.ok_or(StatusCode::BAD_REQUEST)?;
let qs = q
.split('&')
.filter_map(|s| s.split_once('='))
.collect::<std::collections::BTreeMap<_, _>>();

let k1 = qs.get("k1").ok_or(StatusCode::BAD_REQUEST)?;
let remoteid = qs.get("remoteid").ok_or(StatusCode::BAD_REQUEST)?;
cc((String::from(*k1), String::from(*remoteid)))
.await
.map(|a| a.to_string())
}
}),
)
.route(
"/lnurlw",
get(move || {
let pq = self.withdraw_query.clone();
async move { pq().await.map(|a| a.to_string()) }
let wq = self.withdraw_query.clone();
async move { wq().await.map(|a| a.to_string()) }
}),
)
.route(
Expand Down
71 changes: 71 additions & 0 deletions tests/lud02.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#[tokio::test]
async fn test() {
let listener = tokio::net::TcpListener::bind("0.0.0.0:0")
.await
.expect("net");

let addr = listener.local_addr().expect("addr");

let query_url = format!("http://{addr}/lnurlc");
let callback_url = url::Url::parse(&format!("http://{addr}/lnurlc/callback")).expect("url");

let router = lnurlkit::server::Server::default()
.channel_request(
move || {
let callback = callback_url.clone();
async {
Ok(lnurlkit::core::channel_request::ChannelRequest {
uri: String::from("u@r:i"),
k1: String::from("caum"),
callback,
})
}
},
|(k1, remoteid)| async move {
Ok(if remoteid == "idremoto" {
lnurlkit::core::channel_request::CallbackResponse::Ok
} else {
lnurlkit::core::channel_request::CallbackResponse::Error(k1)
})
},
)
.build();

tokio::spawn(async move {
axum::serve(listener, router).await.expect("serve");
});

let client = lnurlkit::client::Client::default();

let lnurl = bech32::encode(
"lnurl",
bech32::ToBase32::to_base32(&query_url),
bech32::Variant::Bech32,
)
.expect("lnurl");

let queried = client.query(&lnurl).await.expect("query");
let lnurlkit::client::Query::ChannelRequest(cr) = queried else {
panic!("not pay request");
};

assert_eq!(cr.core.uri, "u@r:i");

let response = cr
.clone()
.callback_cancel("idremoto")
.await
.expect("callback");

assert!(matches!(
response,
lnurlkit::core::channel_request::CallbackResponse::Ok
));

let response = cr.callback_cancel("iderrado").await.expect("callback");

assert!(matches!(
response,
lnurlkit::core::channel_request::CallbackResponse::Error(r) if r == "caum"
));
}

0 comments on commit 825b57e

Please sign in to comment.