diff --git a/README.md b/README.md index 85d3d9779c..30bf3b1282 100644 --- a/README.md +++ b/README.md @@ -333,6 +333,14 @@ if (cli.send(requests, responses)) { } ``` +### Redirect + +```cpp +httplib::Client cli("yahoo.com"); +cli.follow_location(true); +auto ret = cli.Get("/"); +``` + OpenSSL Support --------------- diff --git a/httplib.h b/httplib.h index 26e1081928..f9d59a1832 100644 --- a/httplib.h +++ b/httplib.h @@ -113,6 +113,7 @@ inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { #define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 #define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 #define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 #define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (std::numeric_limits::max)() #define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) #define CPPHTTPLIB_THREAD_POOL_COUNT 8 @@ -146,8 +147,8 @@ typedef std::function ContentProvider; -typedef std::function +typedef std::function ContentReceiver; typedef std::function Progress; @@ -172,17 +173,21 @@ typedef std::pair Range; typedef std::vector Ranges; struct Request { - std::string version; std::string method; - std::string target; std::string path; Headers headers; std::string body; + + // for server + std::string version; + std::string target; Params params; MultipartFiles files; Ranges ranges; Match matches; + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; ContentReceiver content_receiver; Progress progress; @@ -491,77 +496,104 @@ class Client { virtual bool is_valid() const; - std::shared_ptr Get(const char *path, Progress progress = nullptr); + std::shared_ptr Get(const char *path); + + std::shared_ptr Get(const char *path, const Headers &headers); + + std::shared_ptr Get(const char *path, Progress progress); + std::shared_ptr Get(const char *path, const Headers &headers, - Progress progress = nullptr); + Progress progress); std::shared_ptr Get(const char *path, - ContentReceiver content_receiver, - Progress progress = nullptr); + ContentReceiver content_receiver); + + std::shared_ptr Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); + + std::shared_ptr + Get(const char *path, ContentReceiver content_receiver, Progress progress); + std::shared_ptr Get(const char *path, const Headers &headers, ContentReceiver content_receiver, - Progress progress = nullptr); + Progress progress); std::shared_ptr Head(const char *path); + std::shared_ptr Head(const char *path, const Headers &headers); std::shared_ptr Post(const char *path, const std::string &body, const char *content_type); + std::shared_ptr Post(const char *path, const Headers &headers, const std::string &body, const char *content_type); std::shared_ptr Post(const char *path, const Params ¶ms); + std::shared_ptr Post(const char *path, const Headers &headers, const Params ¶ms); std::shared_ptr Post(const char *path, const MultipartFormDataItems &items); + std::shared_ptr Post(const char *path, const Headers &headers, const MultipartFormDataItems &items); std::shared_ptr Put(const char *path, const std::string &body, const char *content_type); + std::shared_ptr Put(const char *path, const Headers &headers, const std::string &body, const char *content_type); std::shared_ptr Patch(const char *path, const std::string &body, const char *content_type); + std::shared_ptr Patch(const char *path, const Headers &headers, const std::string &body, const char *content_type); - std::shared_ptr Delete(const char *path, - const std::string &body = std::string(), - const char *content_type = nullptr); + std::shared_ptr Delete(const char *path); + + std::shared_ptr Delete(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Delete(const char *path, const Headers &headers); + std::shared_ptr Delete(const char *path, const Headers &headers, - const std::string &body = std::string(), - const char *content_type = nullptr); + const std::string &body, + const char *content_type); std::shared_ptr Options(const char *path); + std::shared_ptr Options(const char *path, const Headers &headers); - bool send(Request &req, Response &res); + bool send(const Request &req, Response &res); - bool send(std::vector &requests, std::vector& responses); + bool send(const std::vector &requests, + std::vector &responses); void set_keep_alive_max_count(size_t count); + void follow_location(bool on); + protected: - bool process_request(Stream &strm, Request &req, Response &res, - bool &connection_close); + bool process_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); const std::string host_; const int port_; time_t timeout_sec_; const std::string host_and_port_; size_t keep_alive_max_count_; + size_t follow_location_; private: socket_t create_client_socket() const; bool read_response_line(Stream &strm, Response &res); - void write_request(Stream &strm, Request &req); + void write_request(Stream &strm, const Request &req, bool last_connection); + bool redirect(const Request &req, Response &res); virtual bool process_and_close_socket( socket_t sock, size_t request_count, @@ -572,7 +604,8 @@ class Client { virtual bool is_ssl() const; }; -inline void Get(std::vector &requests, const char *path, const Headers &headers) { +inline void Get(std::vector &requests, const char *path, + const Headers &headers) { Request req; req.method = "GET"; req.path = path; @@ -584,7 +617,9 @@ inline void Get(std::vector &requests, const char *path) { Get(requests, path, Headers()); } -inline void Post(std::vector &requests, const char *path, const Headers &headers, const std::string &body, const char *content_type) { +inline void Post(std::vector &requests, const char *path, + const Headers &headers, const std::string &body, + const char *content_type) { Request req; req.method = "POST"; req.path = path; @@ -594,7 +629,8 @@ inline void Post(std::vector &requests, const char *path, const Headers requests.emplace_back(std::move(req)); } -inline void Post(std::vector &requests, const char *path, const std::string &body, const char *content_type) { +inline void Post(std::vector &requests, const char *path, + const std::string &body, const char *content_type) { Post(requests, path, Headers(), body, content_type); } @@ -1443,7 +1479,8 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, return ret; } -template inline int write_headers(Stream &strm, const T &info) { +template +inline int write_headers(Stream &strm, const T &info, const Headers &headers) { auto write_len = 0; for (const auto &x : info.headers) { auto len = @@ -1451,6 +1488,12 @@ template inline int write_headers(Stream &strm, const T &info) { if (len < 0) { return len; } write_len += len; } + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } auto len = strm.write("\r\n"); if (len < 0) { return len; } write_len += len; @@ -1458,7 +1501,7 @@ template inline int write_headers(Stream &strm, const T &info) { } inline ssize_t write_content(Stream &strm, ContentProvider content_provider, - size_t offset, size_t length) { + size_t offset, size_t length) { size_t begin_offset = offset; size_t end_offset = offset + length; while (offset < end_offset) { @@ -1476,7 +1519,7 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider, } inline ssize_t write_content_chunked(Stream &strm, - ContentProvider content_provider) { + ContentProvider content_provider) { size_t offset = 0; auto data_available = true; ssize_t total_written_length = 0; @@ -1503,6 +1546,24 @@ inline ssize_t write_content_chunked(Stream &strm, return total_written_length; } +template +inline bool redirect(T &cli, const Request &req, Response &res, + const std::string &path) { + Request new_req; + new_req.method = req.method; + new_req.path = path; + new_req.headers = req.headers; + new_req.body = req.body; + new_req.redirect_count = req.redirect_count - 1; + new_req.content_receiver = req.content_receiver; + new_req.progress = req.progress; + + Response new_res; + auto ret = cli.send(new_req, new_res); + if (ret) { res = new_res; } + return ret; +} + inline std::string encode_url(const std::string &s) { std::string result; @@ -1674,23 +1735,27 @@ inline bool parse_range_header(const std::string &s, Ranges &ranges) { if (std::regex_match(s, m, re)) { auto pos = m.position(1); auto len = m.length(1); - detail::split( - &s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { - static auto re = std::regex(R"(\s*(\d*)-(\d*))"); - std::cmatch m; - if (std::regex_match(b, e, m, re)) { - ssize_t first = -1; - if (!m.str(1).empty()) { first = static_cast(std::stoll(m.str(1))); } - - ssize_t last = -1; - if (!m.str(2).empty()) { last = static_cast(std::stoll(m.str(2))); } - - if (first != -1 && last != -1 && first > last) { - throw std::runtime_error("invalid range error"); - } - ranges.emplace_back(std::make_pair(first, last)); - } - }); + detail::split(&s[pos], &s[pos + len], ',', + [&](const char *b, const char *e) { + static auto re = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch m; + if (std::regex_match(b, e, m, re)) { + ssize_t first = -1; + if (!m.str(1).empty()) { + first = static_cast(std::stoll(m.str(1))); + } + + ssize_t last = -1; + if (!m.str(2).empty()) { + last = static_cast(std::stoll(m.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + throw std::runtime_error("invalid range error"); + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); return true; } return false; @@ -1742,8 +1807,7 @@ get_range_offset_and_length(const Request &req, size_t content_length, return std::make_pair(r.first, r.second - r.first + 1); } -inline std::string make_content_range_header_field(size_t offset, - size_t length, +inline std::string make_content_range_header_field(size_t offset, size_t length, size_t content_length) { std::string field = "bytes "; field += std::to_string(offset); @@ -1988,9 +2052,8 @@ inline void Response::set_chunked_content_provider( std::function provider, std::function resource_releaser) { content_provider_resource_length = 0; - content_provider = [provider](size_t offset, size_t, DataSink sink, Done done) { - provider(offset, sink, done); - }; + content_provider = [provider](size_t offset, size_t, DataSink sink, + Done done) { provider(offset, sink, done); }; content_provider_resource_releaser = resource_releaser; } @@ -2300,7 +2363,7 @@ inline bool Server::write_response(Stream &strm, bool last_connection, res.set_header("Content-Length", length); } - if (!detail::write_headers(strm, res)) { return false; } + if (!detail::write_headers(strm, res, Headers())) { return false; } // Body if (req.method != "HEAD") { @@ -2591,7 +2654,8 @@ inline bool Server::process_and_close_socket(socket_t sock) { inline Client::Client(const char *host, int port, time_t timeout_sec) : host_(host), port_(port), timeout_sec_(timeout_sec), host_and_port_(host_ + ":" + std::to_string(port_)), - keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT) {} + keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + follow_location_(false) {} inline Client::~Client() {} @@ -2635,20 +2699,27 @@ inline bool Client::read_response_line(Stream &strm, Response &res) { return true; } -inline bool Client::send(Request &req, Response &res) { +inline bool Client::send(const Request &req, Response &res) { if (req.path.empty()) { return false; } auto sock = create_client_socket(); if (sock == INVALID_SOCKET) { return false; } - return process_and_close_socket( - sock, 1, - [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { - return process_request(strm, req, res, connection_close); + auto ret = process_and_close_socket( + sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { + return process_request(strm, req, res, last_connection, + connection_close); }); + + if (ret && follow_location_ && (300 < res.status && res.status < 400)) { + ret = redirect(req, res); + } + + return ret; } -inline bool Client::send(std::vector &requests, std::vector& responses) { +inline bool Client::send(const std::vector &requests, + std::vector &responses) { size_t i = 0; while (i < requests.size()) { auto sock = create_client_socket(); @@ -2662,9 +2733,16 @@ inline bool Client::send(std::vector &requests, std::vector& i++; if (req.path.empty()) { return false; } - if (last_connection) { req.set_header("Connection", "close"); } - auto ret = process_request(strm, req, res, connection_close); + auto ret = process_request(strm, req, res, last_connection, + connection_close); + + if (ret && follow_location_ && + (300 < res.status && res.status < 400)) { + ret = redirect(req, res); + } + if (ret) { responses.emplace_back(std::move(res)); } + return ret; })) { return false; @@ -2674,7 +2752,48 @@ inline bool Client::send(std::vector &requests, std::vector& return true; } -inline void Client::write_request(Stream &strm, Request &req) { +inline bool Client::redirect(const Request &req, Response &res) { + if (req.redirect_count == 0) { return false; } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + std::regex re( + R"(^(?:([^:/?#]+):)?(?://([^/?#]*))?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + + auto scheme = is_ssl() ? "https" : "http"; + + std::smatch m; + if (regex_match(location, m, re)) { + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto next_path = m[3].str(); + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + if (next_scheme == scheme && next_host == host_) { + return detail::redirect(*this, req, res, next_path); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str()); + cli.follow_location(true); + return detail::redirect(cli, req, res, next_path); +#else + return false; +#endif + } else { + Client cli(next_host.c_str()); + cli.follow_location(true); + return detail::redirect(cli, req, res, next_path); + } + } + } + return false; +} + +inline void Client::write_request(Stream &strm, const Request &req, + bool last_connection) { BufferStream bstrm; // Request line @@ -2682,45 +2801,48 @@ inline void Client::write_request(Stream &strm, Request &req) { bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); - // Headers + // Additonal headers + Headers headers; + if (last_connection) { headers.emplace("Connection", "close"); } + if (!req.has_header("Host")) { if (is_ssl()) { if (port_ == 443) { - req.set_header("Host", host_); + headers.emplace("Host", host_); } else { - req.set_header("Host", host_and_port_); + headers.emplace("Host", host_and_port_); } } else { if (port_ == 80) { - req.set_header("Host", host_); + headers.emplace("Host", host_); } else { - req.set_header("Host", host_and_port_); + headers.emplace("Host", host_and_port_); } } } - if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } + if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } if (!req.has_header("User-Agent")) { - req.set_header("User-Agent", "cpp-httplib/0.2"); + headers.emplace("User-Agent", "cpp-httplib/0.2"); } if (req.body.empty()) { if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { - req.set_header("Content-Length", "0"); + headers.emplace("Content-Length", "0"); } } else { if (!req.has_header("Content-Type")) { - req.set_header("Content-Type", "text/plain"); + headers.emplace("Content-Type", "text/plain"); } if (!req.has_header("Content-Length")) { auto length = std::to_string(req.body.size()); - req.set_header("Content-Length", length); + headers.emplace("Content-Length", length); } } - detail::write_headers(bstrm, req); + detail::write_headers(bstrm, req, headers); // Body if (!req.body.empty()) { bstrm.write(req.body); } @@ -2730,10 +2852,11 @@ inline void Client::write_request(Stream &strm, Request &req) { strm.write(data.data(), data.size()); } -inline bool Client::process_request(Stream &strm, Request &req, Response &res, +inline bool Client::process_request(Stream &strm, const Request &req, + Response &res, bool last_connection, bool &connection_close) { // Send request - write_request(strm, req); + write_request(strm, req, last_connection); // Receive response and headers if (!read_response_line(strm, res) || @@ -2749,9 +2872,7 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res, // Body if (req.method != "HEAD") { detail::ContentReceiverCore out = [&](const char *buf, size_t n) { - if (res.body.size() + n > res.body.max_size()) { - return false; - } + if (res.body.size() + n > res.body.max_size()) { return false; } res.body.append(buf, n); return true; }; @@ -2788,11 +2909,22 @@ inline bool Client::process_and_close_socket( inline bool Client::is_ssl() const { return false; } +inline std::shared_ptr Client::Get(const char *path) { + Progress dummy; + return Get(path, Headers(), dummy); +} + inline std::shared_ptr Client::Get(const char *path, Progress progress) { return Get(path, Headers(), progress); } +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers) { + Progress dummy; + return Get(path, headers, dummy); +} + inline std::shared_ptr Client::Get(const char *path, const Headers &headers, Progress progress) { Request req; @@ -2805,12 +2937,25 @@ Client::Get(const char *path, const Headers &headers, Progress progress) { return send(req, *res) ? res : nullptr; } +inline std::shared_ptr Client::Get(const char *path, + ContentReceiver content_receiver) { + Progress dummy; + return Get(path, Headers(), content_receiver, dummy); +} + inline std::shared_ptr Client::Get(const char *path, ContentReceiver content_receiver, Progress progress) { return Get(path, Headers(), content_receiver, progress); } +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver) { + Progress dummy; + return Get(path, headers, content_receiver, dummy); +} + inline std::shared_ptr Client::Get(const char *path, const Headers &headers, ContentReceiver content_receiver, @@ -2968,12 +3113,21 @@ inline std::shared_ptr Client::Patch(const char *path, return send(req, *res) ? res : nullptr; } +inline std::shared_ptr Client::Delete(const char *path) { + return Delete(path, Headers(), std::string(), nullptr); +} + inline std::shared_ptr Client::Delete(const char *path, const std::string &body, const char *content_type) { return Delete(path, Headers(), body, content_type); } +inline std::shared_ptr Client::Delete(const char *path, + const Headers &headers) { + return Delete(path, headers, std::string(), nullptr); +} + inline std::shared_ptr Client::Delete(const char *path, const Headers &headers, const std::string &body, @@ -3011,6 +3165,8 @@ inline void Client::set_keep_alive_max_count(size_t count) { keep_alive_max_count_ = count; } +inline void Client::follow_location(bool on) { follow_location_ = on; } + /* * SSL Implementation */ diff --git a/test/test.cc b/test/test.cc index 5eb57dd273..a5f569b593 100644 --- a/test/test.cc +++ b/test/test.cc @@ -431,6 +431,78 @@ TEST(BaseAuthTest, FromHTTPWatch) { } } +TEST(AbsoluteRedirectTest, Redirect) { + auto host = "httpbin.org"; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + httplib::SSLClient cli(host); +#else + httplib::Client cli(host); +#endif + + cli.follow_location(true); + auto ret = cli.Get("/absolute-redirect/3"); + ASSERT_TRUE(ret != nullptr); +} + +TEST(RedirectTest, Redirect) { + auto host = "httpbin.org"; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + httplib::SSLClient cli(host); +#else + httplib::Client cli(host); +#endif + + cli.follow_location(true); + auto ret = cli.Get("/redirect/3"); + ASSERT_TRUE(ret != nullptr); +} + +TEST(RelativeRedirectTest, Redirect) { + auto host = "httpbin.org"; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + httplib::SSLClient cli(host); +#else + httplib::Client cli(host); +#endif + + cli.follow_location(true); + auto ret = cli.Get("/relative-redirect/3"); + ASSERT_TRUE(ret != nullptr); +} + +TEST(TooManyRedirectTest, Redirect) { + auto host = "httpbin.org"; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + httplib::SSLClient cli(host); +#else + httplib::Client cli(host); +#endif + + cli.follow_location(true); + auto ret = cli.Get("/redirect/21"); + ASSERT_TRUE(ret == nullptr); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +TEST(YahooRedirectTest, Redirect) { + httplib::Client cli("yahoo.com"); + cli.follow_location(true); + auto ret = cli.Get("/"); + ASSERT_TRUE(ret != nullptr); +} + +TEST(Https2HttpRedirectTest, Redirect) { + httplib::SSLClient cli("httpbin.org"); + cli.follow_location(true); + auto ret = cli.Get("/redirect-to?url=http%3A%2F%2Fwww.google.com&status_code=302"); + ASSERT_TRUE(ret != nullptr); +} +#endif + TEST(Server, BindAndListenSeparately) { Server svr; int port = svr.bind_to_any_port("localhost");