diff --git a/lib/dune b/lib/dune index 63a629e..ccde7b8 100644 --- a/lib/dune +++ b/lib/dune @@ -27,6 +27,7 @@ eio.unix fmt fqueue + httpx jingoo logs mirage-clock-unix diff --git a/lib/guild.ml b/lib/guild.ml index 54f1594..12ce52e 100644 --- a/lib/guild.ml +++ b/lib/guild.ml @@ -84,10 +84,10 @@ let query_voice_provider env ~config ~provider ~text = match provider with | Config.Post endpoint -> Eio.Switch.run @@ fun sw -> - let resp = Discord.Httpx.post env ~sw ~body:(`Fixed text) endpoint in + let resp = Httpx.Http.post env ~sw ~body:(`Fixed text) endpoint in let status = fst resp |> Http.Response.status in if status |> Cohttp.Code.code_of_status |> Cohttp.Code.is_success then - Discord.Httpx.drain_resp_body resp + Httpx.Http.drain_resp_body resp else failwith (Printf.sprintf "Failed to get speech: %s" @@ -104,13 +104,13 @@ let query_voice_provider env ~config ~provider ~text = Uri.encoded_of_query [ ("text", [ text ]); ("key", [ key ]) ] in let resp = - Discord.Httpx.post env ~sw + Httpx.Http.post env ~sw ~headers:[ ("content-type", "application/x-www-form-urlencoded") ] ~body:(`Fixed body) endpoint in let status = fst resp |> Http.Response.status in if status |> Cohttp.Code.code_of_status |> Cohttp.Code.is_success then - Discord.Httpx.drain_resp_body resp + Httpx.Http.drain_resp_body resp else failwith (Printf.sprintf "Failed to get speech: %s" @@ -125,7 +125,7 @@ let query_voice_provider env ~config ~provider ~text = [ ("style_id", [ string_of_int style_id ]); ("text", [ text ]) ] in let resp = - Discord.Httpx.post env ~sw + Httpx.Http.post env ~sw ~headers:[ ("content-type", "application/x-www-form-urlencoded") ] (Uri.to_string u) in @@ -133,14 +133,14 @@ let query_voice_provider env ~config ~provider ~text = fst resp |> Http.Response.status |> Cohttp.Code.code_of_status |> Cohttp.Code.is_success |> not then failwith "query_voice_provider: Voicevox: Failed to get query.json"; - let query_json = Discord.Httpx.drain_resp_body resp in + let query_json = Httpx.Http.drain_resp_body resp in let u = Uri.with_path (Uri.of_string config.voicevox_endpoint) "/synthesis" in let u = Uri.with_query u [ ("style_id", [ string_of_int style_id ]) ] in let resp = - Discord.Httpx.post env ~sw + Httpx.Http.post env ~sw ~headers:[ ("content-type", "application/json") ] ~body:(`Fixed query_json) (Uri.to_string u) in @@ -148,7 +148,7 @@ let query_voice_provider env ~config ~provider ~text = fst resp |> Http.Response.status |> Cohttp.Code.code_of_status |> Cohttp.Code.is_success |> not then failwith "query_voice_provider: Voicevox: Failed to get speech"; - Discord.Httpx.drain_resp_body resp + Httpx.Http.drain_resp_body resp let format_discord_message (msg : Discord.Object.message) = (* Concat dummy to content if there are attachments *) diff --git a/lib_discord/discord.ml b/lib_discord/discord.ml index abb9409..2b96201 100644 --- a/lib_discord/discord.ml +++ b/lib_discord/discord.ml @@ -36,7 +36,6 @@ module Agent = Agent module Consumer = Consumer module Event = Event -module Httpx = Httpx module Intent = Intent module Object = Object module Rest = Rest diff --git a/lib_discord/dune b/lib_discord/dune index 5c55812..08cbf49 100644 --- a/lib_discord/dune +++ b/lib_discord/dune @@ -16,6 +16,7 @@ eio.core eio.unix fqueue + httpx ipaddr logs logs.fmt diff --git a/lib_discord/httpx.ml b/lib_discord/httpx.ml deleted file mode 100644 index a4025cb..0000000 --- a/lib_discord/httpx.ml +++ /dev/null @@ -1,33 +0,0 @@ -open Util - -let null_auth ?ip:_ ~host:_ _ = - Ok None (* Warning: use a real authenticator in your code! *) - -let https ~authenticator = - let tls_config = Tls.Config.client ~authenticator () in - fun uri raw -> - let host = - Uri.host uri - |> Option.map (fun x -> Domain_name.(host_exn (of_string_exn x))) - in - Tls_eio.client_of_flow ?host tls_config raw - -let request ?headers ?body ~meth env ~sw (url : string) = - let headers = headers |> Option.map Cohttp.Header.of_list in - let body = - body |> Option.map (function `Fixed src -> Cohttp_eio.Body.of_string src) - in - let client = - Cohttp_eio.Client.make - ~https:(Some (https ~authenticator:null_auth)) - (Eio.Stdenv.net env) - in - Cohttp_eio.Client.call ~sw ?headers ?body client meth (Uri.of_string url) - -let get = request ~meth:`GET -let post = request ~meth:`POST -let put = request ~meth:`PUT -let delete = request ~meth:`DELETE - -let drain_resp_body (_, body) = - Eio.Buf_read.(parse_exn take_all) body ~max_size:max_int diff --git a/lib_discord/rest.ml b/lib_discord/rest.ml index e550462..31167c6 100644 --- a/lib_discord/rest.ml +++ b/lib_discord/rest.ml @@ -15,8 +15,8 @@ let request ~meth ?body env ~token path = (body |> Option.fold ~none:"" ~some:Yojson.Safe.to_string)); let body = body |> Option.map (fun x -> `Fixed (Yojson.Safe.to_string x)) in Eio.Switch.run @@ fun sw -> - let resp = Httpx.request ~meth ~headers ?body env ~sw url in - let body = Httpx.drain_resp_body resp in + let resp = Httpx.Http.request ~meth ~headers ?body env ~sw url in + let body = Httpx.Http.drain_resp_body resp in let body = try body |> Yojson.Safe.from_string |> Option.some with _ -> None in diff --git a/lib_discord/ws.ml b/lib_discord/ws.ml index c46d86a..4284e1a 100644 --- a/lib_discord/ws.ml +++ b/lib_discord/ws.ml @@ -1,117 +1,4 @@ -include Websocket.Make (Cohttp_eio.Private.IO) - -type conn = { - id : string; - read_frame : unit -> Websocket.Frame.t; - write_frame : Websocket.Frame.t -> unit; -} - -let drain_handshake req ic oc nonce = - Request.write (fun _ -> ()) req oc; - let resp = - match Response.read ic with - | `Ok r -> r - | `Eof -> raise End_of_file - | `Invalid s -> failwith s - in - let status = Cohttp.Response.status resp in - let headers = Cohttp.Response.headers resp in - if Cohttp.Code.(is_error (code_of_status status)) then - failwith ("error status: " ^ Cohttp.Code.(string_of_status status)); - if Cohttp.Response.version resp <> `HTTP_1_1 then - failwith "invalid HTTP version"; - if status <> `Switching_protocols then failwith "wrong status"; - (match Cohttp.Header.get headers "upgrade" with - | Some a when String.lowercase_ascii a = "websocket" -> () - | _ -> failwith "wrong upgrade"); - if not (Websocket.upgrade_present headers) then - failwith "upgrade header not present"; - (match Cohttp.Header.get headers "sec-websocket-accept" with - | Some accept - when accept - = Websocket.b64_encoded_sha1sum (nonce ^ Websocket.websocket_uuid) -> - () - | _ -> failwith "wrong accept"); - () - -let connect' env sw url nonce extra_headers = - (* Make request *) - let headers = - Cohttp.Header.add_list extra_headers - [ - ("Upgrade", "websocket"); - ("Connection", "Upgrade"); - ("Sec-WebSocket-Key", nonce); - ("Sec-WebSocket-Version", "13"); - ] - in - let req = Cohttp.Request.make ~headers url in - - (* Make socket *) - let host = Uri.host url |> Option.get in - let service = Uri.scheme url |> Option.get in - let addr = - match Eio.Net.getaddrinfo_stream (Eio.Stdenv.net env) host ~service with - | [] -> failwith "getaddrinfo failed" - | addr :: _ -> addr - in - let socket = Eio.Net.connect ~sw (Eio.Stdenv.net env) addr in - let flow = - let authenticator = - let null_auth ?ip:_ ~host:_ _ = Ok None in - null_auth - in - let host = - Result.to_option - (Result.bind (Domain_name.of_string host) Domain_name.host) - in - Tls_eio.client_of_flow - Tls.Config.( - client ~version:(`TLS_1_0, `TLS_1_3) ~authenticator - ~ciphers:Ciphers.supported ()) - ?host socket - in - - (* Drain handshake *) - let ic = Eio.Buf_read.of_flow ~max_size:max_int flow in - Eio.Buf_write.with_flow flow (fun oc -> drain_handshake req ic oc nonce); - - (flow, ic) - -let connect ?(extra_headers = Cohttp.Header.init ()) ~sw env url = - let url = Uri.of_string url in - - let nonce = Base64.encode_exn (Csprng.random_string 16) in - let flow, ic = connect' env sw url nonce extra_headers in - - (* Start writer fiber. All writes must be done in this fiber, - because Eio.Flow.write is not thread-safe. - c.f.: https://github.com/ocaml-multicore/eio/blob/v0.11/lib_eio/flow.mli#L73-L74 - *) - let write_queue = Eio.Stream.create 10 in - (let rec writer () = - try - let frame = Eio.Stream.take write_queue in - let buf = Buffer.create 128 in - write_frame_to_buf ~mode:(Client Csprng.random_string) buf frame; - Eio.Buf_write.with_flow flow (fun oc -> - Eio.Buf_write.string oc (Buffer.contents buf)); - writer () - with Eio.Io _ -> () - in - Eio.Fiber.fork ~sw writer); - - let write_frame frame = Eio.Stream.add write_queue frame in - let read_frame () = - Eio.Buf_write.with_flow flow (fun oc -> - make_read_frame ~mode:(Client Csprng.random_string) ic oc ()) - in - - { id = Csprng.random_string 10; read_frame; write_frame } - -let id { id; _ } = id -let read { read_frame; _ } = read_frame () -let write { write_frame; _ } frame = write_frame frame +include Httpx.Ws module Process = struct type msg = diff --git a/lib_httpx/dune b/lib_httpx/dune new file mode 100644 index 0000000..332e330 --- /dev/null +++ b/lib_httpx/dune @@ -0,0 +1,17 @@ +(library + (name httpx) + (libraries + base64 + cohttp + cohttp-eio + cstruct + domain-name + eio + eio.core + eio.unix + mirage-crypto-rng + tls + tls-eio + uri + websocket + x509)) diff --git a/lib_httpx/httpx.ml b/lib_httpx/httpx.ml new file mode 100644 index 0000000..eb72d86 --- /dev/null +++ b/lib_httpx/httpx.ml @@ -0,0 +1,155 @@ +open struct + let random_string len = Mirage_crypto_rng.generate len |> Cstruct.to_string +end + +module Http = struct + let null_auth ?ip:_ ~host:_ _ = + Ok None (* Warning: use a real authenticator in your code! *) + + let https ~authenticator = + let tls_config = Tls.Config.client ~authenticator () in + fun uri raw -> + let host = + Uri.host uri + |> Option.map (fun x -> Domain_name.(host_exn (of_string_exn x))) + in + Tls_eio.client_of_flow ?host tls_config raw + + let request ?headers ?body ~meth env ~sw (url : string) = + let headers = headers |> Option.map Cohttp.Header.of_list in + let body = + body + |> Option.map (function `Fixed src -> Cohttp_eio.Body.of_string src) + in + let client = + Cohttp_eio.Client.make + ~https:(Some (https ~authenticator:null_auth)) + (Eio.Stdenv.net env) + in + Cohttp_eio.Client.call ~sw ?headers ?body client meth (Uri.of_string url) + + let get = request ~meth:`GET + let post = request ~meth:`POST + let put = request ~meth:`PUT + let delete = request ~meth:`DELETE + + let drain_resp_body (_, body) = + Eio.Buf_read.(parse_exn take_all) body ~max_size:max_int +end + +module Ws = struct + include Websocket.Make (Cohttp_eio.Private.IO) + + type conn = { + id : string; + read_frame : unit -> Websocket.Frame.t; + write_frame : Websocket.Frame.t -> unit; + } + + let drain_handshake req ic oc nonce = + Request.write (fun _ -> ()) req oc; + let resp = + match Response.read ic with + | `Ok r -> r + | `Eof -> raise End_of_file + | `Invalid s -> failwith s + in + let status = Cohttp.Response.status resp in + let headers = Cohttp.Response.headers resp in + if Cohttp.Code.(is_error (code_of_status status)) then + failwith ("error status: " ^ Cohttp.Code.(string_of_status status)); + if Cohttp.Response.version resp <> `HTTP_1_1 then + failwith "invalid HTTP version"; + if status <> `Switching_protocols then failwith "wrong status"; + (match Cohttp.Header.get headers "upgrade" with + | Some a when String.lowercase_ascii a = "websocket" -> () + | _ -> failwith "wrong upgrade"); + if not (Websocket.upgrade_present headers) then + failwith "upgrade header not present"; + (match Cohttp.Header.get headers "sec-websocket-accept" with + | Some accept + when accept + = Websocket.b64_encoded_sha1sum (nonce ^ Websocket.websocket_uuid) -> + () + | _ -> failwith "wrong accept"); + () + + let connect' env sw url nonce extra_headers = + (* Make request *) + let headers = + Cohttp.Header.add_list extra_headers + [ + ("Upgrade", "websocket"); + ("Connection", "Upgrade"); + ("Sec-WebSocket-Key", nonce); + ("Sec-WebSocket-Version", "13"); + ] + in + let req = Cohttp.Request.make ~headers url in + + (* Make socket *) + let host = Uri.host url |> Option.get in + let service = Uri.scheme url |> Option.get in + let addr = + match Eio.Net.getaddrinfo_stream (Eio.Stdenv.net env) host ~service with + | [] -> failwith "getaddrinfo failed" + | addr :: _ -> addr + in + let socket = Eio.Net.connect ~sw (Eio.Stdenv.net env) addr in + let flow = + let authenticator = + let null_auth ?ip:_ ~host:_ _ = Ok None in + null_auth + in + let host = + Result.to_option + (Result.bind (Domain_name.of_string host) Domain_name.host) + in + Tls_eio.client_of_flow + Tls.Config.( + client ~version:(`TLS_1_0, `TLS_1_3) ~authenticator + ~ciphers:Ciphers.supported ()) + ?host socket + in + + (* Drain handshake *) + let ic = Eio.Buf_read.of_flow ~max_size:max_int flow in + Eio.Buf_write.with_flow flow (fun oc -> drain_handshake req ic oc nonce); + + (flow, ic) + + let connect ?(extra_headers = Cohttp.Header.init ()) ~sw env url = + let url = Uri.of_string url in + + let nonce = Base64.encode_exn (random_string 16) in + let flow, ic = connect' env sw url nonce extra_headers in + + (* Start writer fiber. All writes must be done in this fiber, + because Eio.Flow.write is not thread-safe. + c.f.: https://github.com/ocaml-multicore/eio/blob/v0.11/lib_eio/flow.mli#L73-L74 + *) + let write_queue = Eio.Stream.create 10 in + (let rec writer () = + try + let frame = Eio.Stream.take write_queue in + let buf = Buffer.create 128 in + write_frame_to_buf ~mode:(Client random_string) buf frame; + Eio.Buf_write.with_flow flow (fun oc -> + Eio.Buf_write.string oc (Buffer.contents buf)); + writer () + with Eio.Io _ -> () + in + Eio.Fiber.fork ~sw writer); + + let write_frame frame = Eio.Stream.add write_queue frame in + let read_frame () = + Eio.Buf_write.with_flow flow (fun oc -> + make_read_frame ~mode:(Client random_string) ic oc ()) + in + + { id = random_string 10; read_frame; write_frame } + + let id { id; _ } = id + let read { read_frame; _ } = read_frame () + let write { write_frame; _ } frame = write_frame frame +end