diff --git a/include/oas_validator.hpp b/include/oas_validator.hpp index eec0eac..650c633 100644 --- a/include/oas_validator.hpp +++ b/include/oas_validator.hpp @@ -18,6 +18,7 @@ #include #include #include +#include class ValidatorInitExc; ///< Forward declaration for the custom exception class. class OASValidatorImp; ///< Forward declaration for the implementation class. @@ -56,13 +57,29 @@ class OASValidator public: /** - * @brief Constructor that takes the path to the OAS specification file. + * @brief Constructor that takes the path to the OAS specification file and an optional method mapping. + * * @param oas_specs File path to the OAS specification in JSON format or JSON string containing the OAS * specification. * - * @note The OAS specification can be provided as a file path or as a JSON string. + * @param method_map An optional unordered_map where each key is an HTTP method and the value is an unordered_set + * of methods that can be treated as the key method. This allows certain HTTP methods to be treated as others. + * + * For example: + * @code + * std::unordered_map> method_map = { + * {"OPTIONS", {"GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"}}, + * {"HEAD", {"GET"}} // Treat HEAD request as GET + * }; + * OASValidator validator(oas_specs, method_map); + * @endcode + * + * @note The OAS specification can be provided as a file path or as a JSON string. If the method map is provided, + * it allows certain HTTP methods to be treated as others. For instance, with the mapping {"HEAD", {"GET"}}, + * a HEAD request can be validated as the GET request, if HEAD method is not defined. */ - explicit OASValidator(const std::string& oas_specs); + explicit OASValidator(const std::string& oas_specs, + const std::unordered_map>& method_map = {}); /** * @brief Copy constructor. diff --git a/include/oas_validator_imp.hpp b/include/oas_validator_imp.hpp index 312db9b..eb07707 100644 --- a/include/oas_validator_imp.hpp +++ b/include/oas_validator_imp.hpp @@ -15,7 +15,8 @@ class OASValidatorImp { public: - explicit OASValidatorImp(const std::string& oas_specs, bool head_mapped_get = false); + explicit OASValidatorImp(const std::string& oas_specs, + const std::unordered_map>& method_map = {}); ValidationError ValidateRoute(const std::string& method, const std::string& http_path, std::string& error_msg); ValidationError ValidateBody(const std::string& method, const std::string& http_path, const std::string& json_body, std::string& error_msg); @@ -45,17 +46,17 @@ class OASValidatorImp PathTrie path_trie{}; }; - bool head_mapped_get_; + const std::unordered_map> method_map_; std::array(HttpMethod::COUNT)> oas_validators_{}; MethodValidator method_validator_{}; - ValidationError GetValidators(const std::string& method, const std::string& http_path, ValidatorsStore*& validators, - std::string& error_msg, std::unordered_map* param_idxs = nullptr, - std::string* query = nullptr); ValidationError GetValidators(const std::string& method, const std::string& mapped_method, const std::string& http_path, ValidatorsStore*& validators, std::string& error_msg, std::unordered_map* param_idxs = nullptr, std::string* query = nullptr); + ValidationError GetValidators(const std::string& method, const std::string& http_path, ValidatorsStore*& validators, + std::string& error_msg, std::unordered_map* param_idxs = nullptr, + std::string* query = nullptr); static std::vector Split(const std::string& str); static rapidjson::Value* ResolvePath(rapidjson::Document& doc, const std::string& path); static void ParseSpecs(const std::string& oas_specs, rapidjson::Document& doc); diff --git a/src/oas_validator.cpp b/src/oas_validator.cpp index efcccf4..a378e3d 100644 --- a/src/oas_validator.cpp +++ b/src/oas_validator.cpp @@ -8,8 +8,9 @@ #include "oas_validator_imp.hpp" -OASValidator::OASValidator(const std::string& oas_specs) - : impl_(new OASValidatorImp(oas_specs)) +OASValidator::OASValidator(const std::string& oas_specs, + const std::unordered_map>& method_map) + : impl_(new OASValidatorImp(oas_specs, method_map)) { } diff --git a/src/oas_validator_imp.cpp b/src/oas_validator_imp.cpp index 2d4f2f1..e9fa0dc 100644 --- a/src/oas_validator_imp.cpp +++ b/src/oas_validator_imp.cpp @@ -9,8 +9,9 @@ #include #include -OASValidatorImp::OASValidatorImp(const std::string& oas_specs, bool head_mapped_get) - : head_mapped_get_(head_mapped_get) +OASValidatorImp::OASValidatorImp(const std::string& oas_specs, + const std::unordered_map>& method_map) + : method_map_(method_map) { rapidjson::Document doc; ParseSpecs(oas_specs, doc); @@ -177,9 +178,18 @@ ValidationError OASValidatorImp::GetValidators(const std::string& method, const CHECK_ERROR(err_code) err_code = GetValidators(method, method, http_path, validators, error_msg, param_idxs, query); - - if (head_mapped_get_ && ValidationError::INVALID_ROUTE == err_code && (method == "head" || method == "HEAD")) { - return GetValidators(method, "get", http_path, validators, error_msg, param_idxs, query); + if (ValidationError::INVALID_ROUTE == err_code) { + try { + auto mapped_methods = method_map_.at(method); + for (const auto& mapped_method : mapped_methods) { + err_code = GetValidators(method, mapped_method, http_path, validators, error_msg, param_idxs, query); + if (ValidationError::NONE == err_code) { + return err_code; + } + } + } catch (const std::out_of_range&) { + return err_code; + } } return err_code;