Skip to content

Commit

Permalink
Added support to run unit tests in a multithreaded context
Browse files Browse the repository at this point in the history
- This is controlled by specifying the 'test_multithreaded' argument
  when running `unit_test`.
- The goal is to detect if the operator/transformation  fails in this
  context.
- In this mode, the test will be executed 5'000 times in 50 threads
  concurrently.
- Allocation & initialization of the operator/transformation is
  performed once in the main thread, while the evaluation is executed in
  the threads.
  - This is consistent with the library's support for multithreading,
    where initialization and loading of rules is expected to run once.
    See issue #3215.
  • Loading branch information
eduar-hte committed Aug 9, 2024
1 parent 7bdc3c8 commit 7a6052a
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 69 deletions.
19 changes: 11 additions & 8 deletions test/common/modsecurity_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,24 @@ bool ModSecurityTest<T>::load_test_json(const std::string &file) {


template <class T>
std::pair<std::string, std::vector<T *>>*
void
ModSecurityTest<T>::load_tests(const std::string &path) {
DIR *dir;
struct dirent *ent;
struct stat buffer;

if ((dir = opendir(path.c_str())) == NULL) {
if ((dir = opendir(path.c_str())) == nullptr) {
/* if target is a file, use it as a single test. */
if (stat(path.c_str(), &buffer) == 0) {
if (load_test_json(path) == false) {
std::cout << "Problems loading from: " << path;
std::cout << std::endl;
}
}
return NULL;
return;
}

while ((ent = readdir(dir)) != NULL) {
while ((ent = readdir(dir)) != nullptr) {
std::string filename = ent->d_name;
std::string json = ".json";
if (filename.size() < json.size()
Expand All @@ -123,16 +123,15 @@ ModSecurityTest<T>::load_tests(const std::string &path) {
}
}
closedir(dir);

return NULL;
}


template <class T>
std::pair<std::string, std::vector<T *>>* ModSecurityTest<T>::load_tests() {
return load_tests(this->target);
void ModSecurityTest<T>::load_tests() {
load_tests(this->target);
}


template <class T>
void ModSecurityTest<T>::cmd_options(int argc, char **argv) {
int i = 1;
Expand All @@ -144,6 +143,10 @@ void ModSecurityTest<T>::cmd_options(int argc, char **argv) {
i++;
m_count_all = true;
}
if (argc > i && strcmp(argv[i], "test_multithreaded") == 0) {
i++;
m_test_multithreaded = true;
}
if (std::getenv("AUTOMAKE_TESTS")) {
m_automake_output = true;
}
Expand Down
8 changes: 5 additions & 3 deletions test/common/modsecurity_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ template <class T> class ModSecurityTest :
ModSecurityTest()
: m_test_number(0),
m_automake_output(false),
m_count_all(false) { }
m_count_all(false),
m_test_multithreaded(false) { }

std::string header();
void cmd_options(int, char **);
std::pair<std::string, std::vector<T *>>* load_tests();
std::pair<std::string, std::vector<T *>>* load_tests(const std::string &path);
void load_tests();
void load_tests(const std::string &path);
bool load_test_json(const std::string &file);

std::string target;
Expand All @@ -48,6 +49,7 @@ template <class T> class ModSecurityTest :
int m_test_number;
bool m_automake_output;
bool m_count_all;
bool m_test_multithreaded;
};

} // namespace modsecurity_test
Expand Down
190 changes: 138 additions & 52 deletions test/unit/unit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

#include <string.h>
#include <cstring>

#include <cassert>
#include <thread>
#include <array>
#include <iostream>
#include <ctime>
#include <string>
Expand All @@ -38,6 +40,7 @@


using modsecurity_test::UnitTest;
using modsecurity_test::UnitTestResult;
using modsecurity_test::ModSecurityTest;
using modsecurity_test::ModSecurityTestResults;
using modsecurity::actions::transformations::Transformation;
Expand All @@ -53,64 +56,149 @@ void print_help() {
}


void perform_unit_test(ModSecurityTest<UnitTest> *test, UnitTest *t,
ModSecurityTestResults<UnitTest>* res) {
std::string error;
struct OperatorTest {
using ItemType = Operator;

static ItemType* init(const UnitTest &t) {
auto op = Operator::instantiate(t.name, t.param);
assert(op != nullptr);

std::string error;
op->init(t.filename, &error);

return op;
}

static UnitTestResult eval(ItemType &op, const UnitTest &t) {
return {op.evaluate(nullptr, nullptr, t.input, nullptr), {}};
}

static bool check(const UnitTestResult &result, const UnitTest &t) {
return result.ret != t.ret;
}
};


struct TransformationTest {
using ItemType = Transformation;

static ItemType* init(const UnitTest &t) {
auto tfn = Transformation::instantiate("t:" + t.name);
assert(tfn != nullptr);

return tfn;
}

static UnitTestResult eval(ItemType &tfn, const UnitTest &t) {
return {1, tfn.evaluate(t.input, nullptr)};
}

static bool check(const UnitTestResult &result, const UnitTest &t) {
return result.output != t.output;
}
};


template<typename TestType>
UnitTestResult perform_unit_test_once(UnitTest &t) {
std::unique_ptr<typename TestType::ItemType> item(TestType::init(t));
assert(item.get() != nullptr);

return TestType::eval(*item.get(), t);
}


template<typename TestType>
UnitTestResult perform_unit_test_multithreaded(UnitTest &t) {

constexpr auto NUM_THREADS = 50;
constexpr auto ITERATIONS = 5'000;

std::array<std::thread, NUM_THREADS> threads;
std::array<UnitTestResult, NUM_THREADS> results;

std::unique_ptr<typename TestType::ItemType> item(TestType::init(t));
assert(item.get() != nullptr);

for (auto i = 0; i != threads.size(); ++i)
{
auto &result = results[i];
threads[i] = std::thread(
[&item, &t, &result]()
{
for (auto j = 0; j != ITERATIONS; ++j)
result = TestType::eval(*item.get(), t);
});
}

UnitTestResult ret;

for (auto i = 0; i != threads.size(); ++i)
{
threads[i].join();
if (TestType::check(results[i], t))
ret = results[i]; // error value, keep iterating to join all threads
else if(i == 0)
ret = results[i]; // initial value
}

return ret;
}


template<typename TestType>
void perform_unit_test_helper(const ModSecurityTest<UnitTest> &test, UnitTest &t,
ModSecurityTestResults<UnitTest> &res) {

if (!test.m_test_multithreaded)
t.result = perform_unit_test_once<TestType>(t);
else
t.result = perform_unit_test_multithreaded<TestType>(t);

if (TestType::check(t.result, t)) {
res.push_back(&t);
if (test.m_automake_output) {
std::cout << "FAIL ";
}
} else if (test.m_automake_output) {
std::cout << "PASS ";
}
}


void perform_unit_test(const ModSecurityTest<UnitTest> &test, UnitTest &t,
ModSecurityTestResults<UnitTest> &res) {
bool found = true;

if (test->m_automake_output) {
if (test.m_automake_output) {
std::cout << ":test-result: ";
}

if (t->resource.empty() == false) {
found = (std::find(resources.begin(), resources.end(), t->resource)
!= resources.end());
if (t.resource.empty() == false) {
found = std::find(resources.begin(), resources.end(), t.resource)
!= resources.end();
}

if (!found) {
t->skipped = true;
res->push_back(t);
if (test->m_automake_output) {
t.skipped = true;
res.push_back(&t);
if (test.m_automake_output) {
std::cout << "SKIP ";
}
}

if (t->type == "op") {
Operator *op = Operator::instantiate(t->name, t->param);
op->init(t->filename, &error);
int ret = op->evaluate(NULL, NULL, t->input, NULL);
t->obtained = ret;
if (ret != t->ret) {
res->push_back(t);
if (test->m_automake_output) {
std::cout << "FAIL ";
}
} else if (test->m_automake_output) {
std::cout << "PASS ";
}
delete op;
} else if (t->type == "tfn") {
Transformation *tfn = Transformation::instantiate("t:" + t->name);
std::string ret = tfn->evaluate(t->input, NULL);
t->obtained = 1;
t->obtainedOutput = ret;
if (ret != t->output) {
res->push_back(t);
if (test->m_automake_output) {
std::cout << "FAIL ";
}
} else if (test->m_automake_output) {
std::cout << "PASS ";
}
delete tfn;
if (t.type == "op") {
perform_unit_test_helper<OperatorTest>(test, t, res);
} else if (t.type == "tfn") {
perform_unit_test_helper<TransformationTest>(test, t, res);
} else {
std::cerr << "Failed. Test type is unknown: << " << t->type;
std::cerr << "Failed. Test type is unknown: << " << t.type;
std::cerr << std::endl;
}

if (test->m_automake_output) {
std::cout << t->name << " "
<< modsecurity::utils::string::toHexIfNeeded(t->input)
if (test.m_automake_output) {
std::cout << t.name << " "
<< modsecurity::utils::string::toHexIfNeeded(t.input)
<< std::endl;
}
}
Expand Down Expand Up @@ -151,17 +239,15 @@ int main(int argc, char **argv) {
test.load_tests("test-cases/secrules-language-tests/transformations");
}

for (std::pair<std::string, std::vector<UnitTest *> *> a : test) {
std::vector<UnitTest *> *tests = a.second;

for (auto& [filename, tests] : test) {
total += tests->size();
for (UnitTest *t : *tests) {
for (auto t : *tests) {
ModSecurityTestResults<UnitTest> r;

if (!test.m_automake_output) {
std::cout << " " << a.first << "...\t";
std::cout << " " << filename << "...\t";
}
perform_unit_test(&test, t, &r);
perform_unit_test(test, *t, r);

if (!test.m_automake_output) {
int skp = 0;
Expand Down Expand Up @@ -191,7 +277,7 @@ int main(int argc, char **argv) {
std::cout << "Total >> " << total << std::endl;
}

for (UnitTest *t : results) {
for (const auto t : results) {
std::cout << t->print() << std::endl;
}

Expand All @@ -216,8 +302,8 @@ int main(int argc, char **argv) {
}

for (auto a : test) {
auto *vec = a.second;
for(auto *t : *vec)
auto vec = a.second;
for(auto t : *vec)
delete t;
delete vec;
}
Expand Down
8 changes: 4 additions & 4 deletions test/unit/unit_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,15 @@ std::string UnitTest::print() {
i << " \"param\": \"" << this->param << "\"" << std::endl;
i << " \"output\": \"" << this->output << "\"" << std::endl;
i << "}" << std::endl;
if (this->ret != this->obtained) {
if (this->ret != this->result.ret) {
i << "Expecting: \"" << this->ret << "\" - returned: \"";
i << this->obtained << "\"" << std::endl;
i << this->result.ret << "\"" << std::endl;
}
if (this->output != this->obtainedOutput) {
if (this->output != this->result.output) {
i << "Expecting: \"";
i << modsecurity::utils::string::toHexIfNeeded(this->output);
i << "\" - returned: \"";
i << modsecurity::utils::string::toHexIfNeeded(this->obtainedOutput);
i << modsecurity::utils::string::toHexIfNeeded(this->result.output);
i << "\"";
i << std::endl;
}
Expand Down
9 changes: 7 additions & 2 deletions test/unit/unit_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@

namespace modsecurity_test {

class UnitTestResult {
public:
int ret;
std::string output;
};

class UnitTest {
public:
static UnitTest *from_yajl_node(const yajl_val &);
Expand All @@ -39,9 +45,8 @@ class UnitTest {
std::string filename;
std::string output;
int ret;
int obtained;
int skipped;
std::string obtainedOutput;
UnitTestResult result;
};

} // namespace modsecurity_test
Expand Down

0 comments on commit 7a6052a

Please sign in to comment.