diff --git a/main.cpp b/main.cpp index f4ecab8..74f94a9 100644 --- a/main.cpp +++ b/main.cpp @@ -1,25 +1,33 @@ // 小彭老师作业05:假装是多线程 HTTP 服务器 - 富连网大厂面试官觉得很赞 #include #include +#include #include #include #include #include #include +#include +#include +#include +typedef std::shared_lock RLock; +typedef std::lock_guard WLock; struct User { std::string password; std::string school; std::string phone; }; +std::shared_mutex s_mtx; std::map users; -std::map has_login; // 换成 std::chrono::seconds 之类的 +std::map > has_login; // 换成 std::chrono::seconds 之类的 // 作业要求1:把这些函数变成多线程安全的 // 提示:能正确利用 shared_mutex 加分,用 lock_guard 系列加分 std::string do_register(std::string username, std::string password, std::string school, std::string phone) { + WLock lock_guard(s_mtx); User user = {password, school, phone}; if (users.emplace(username, user).second) return "注册成功"; @@ -28,10 +36,12 @@ std::string do_register(std::string username, std::string password, std::string } std::string do_login(std::string username, std::string password) { + RLock lock_guard(s_mtx); // 作业要求2:把这个登录计时器改成基于 chrono 的 - long now = time(NULL); // C 语言当前时间 + // long now = time(NULL); // C 语言当前时间 + std::chrono::time_point now = std::chrono::system_clock::now(); if (has_login.find(username) != has_login.end()) { - int sec = now - has_login.at(username); // C 语言算时间差 + int sec = std::chrono::duration_cast(now - has_login.at(username)).count(); return std::to_string(sec) + "秒内登录过"; } has_login[username] = now; @@ -44,6 +54,8 @@ std::string do_login(std::string username, std::string password) { } std::string do_queryuser(std::string username) { + RLock lock_guard(s_mtx); + if(users.find(username) == users.end()) return ""; auto &user = users.at(username); std::stringstream ss; ss << "用户名: " << username << std::endl; @@ -53,15 +65,186 @@ std::string do_queryuser(std::string username) { } -struct ThreadPool { - void create(std::function start) { - // 作业要求3:如何让这个线程保持在后台执行不要退出? - // 提示:改成 async 和 future 且用法正确也可以加分 - std::thread thr(start); - } -}; +namespace thread_pool{ + + class join_threads { + private: + std::vector &threads; + public: + explicit join_threads(std::vector &threads_) : + threads(threads_) {} + + ~join_threads() { + for (unsigned long i = 0; i < threads.size(); ++i) { + if (threads[i].joinable()) threads[i].join(); + } + } + + }; + + // push使用head_mutex,pop使用tail_mutex,使用细粒度锁可以提高并行。 + template + class queue{ + private: + struct node{ + std::shared_ptr data; + std::unique_ptr next; + }; + std::mutex head_mtx_; + std::mutex tail_mtx_; + std::unique_ptr head_; + node *tail_; + std::condition_variable data_cond_; + + node* get_tail(){ + std::lock_guard lockGuard(tail_mtx_); + return tail_; + } + std::unique_ptr pop_head(){ + std::unique_ptr old_head = std::move(head_); + head_ = std::move(old_head->next); + return old_head; + } + + std::unique_ptr try_pop_head(){ + std::lock_guard lockGuard(head_mtx_); + if (head_.get() == get_tail()){ + return std::unique_ptr(); + } + return pop_head(); + } + std::unique_ptr try_pop_head(T &value){ + std::lock_guard lockGuard(head_mtx_); + if (head_.get() == get_tail()){ + return std::unique_ptr(); + } + value = std::move(*head_->data); + return pop_head(); + } + + public: + queue(): head_(new node), tail_(head_.get()){} + queue(const queue &other) = delete; + queue& operator=(const queue &other) = delete; + + void push(const T &&new_value){ + std::shared_ptr new_data = std::make_shared(new_value); + std::unique_ptr tmp(new node); + node* new_tail = tmp.get(); + { + std::lock_guard lockGuard(tail_mtx_); + tail_->data = new_data; + tail_->next = std::move(tmp); + tail_ = new_tail; + } + data_cond_.notify_all(); + } + + void push(T &&new_value){ + std::shared_ptr new_data = std::make_shared(std::move(new_value)); + std::unique_ptr tmp(new node); + node* new_tail = tmp.get(); + { + std::lock_guard lockGuard(tail_mtx_); + tail_->data = new_data; + tail_->next = std::move(tmp); + tail_ = new_tail; + } + data_cond_.notify_all(); + } + + std::shared_ptr try_pop(){ + std::unique_ptr old_head = try_pop_head(); + return old_head ? old_head->data : std::shared_ptr(); + } + + bool try_pop(T &value){ + std::unique_ptr old_head = try_pop_head(value); + return old_head != nullptr; + } + + + std::shared_ptr wait_and_pop(){ + std::unique_lock uniqueLock(head_mtx_); + while (head_.get() == get_tail()){ + data_cond_.wait(uniqueLock); + } + std::unique_ptr old_head = pop_head(); + return old_head->data; + } + void wait_and_pop(T &value){ + std::unique_lock uniqueLock(head_mtx_); + while (head_.get() == get_tail()){ + data_cond_.wait(uniqueLock); + } + std::unique_ptr old_head = pop_head(); + value = std::move(*old_head->data); + } -ThreadPool tpool; + bool empty(){ + std::lock_guard lockGuard(head_mtx_); + return head_.get() == get_tail(); + } + }; + + class thread_pool { + + private: + + // TODO the declaration order is important. + std::atomic_bool done; + queue > work_queue; + std::vector threads; + join_threads joiner; + + void worker_thread(){ + while (!done){ + std::function task; + if (work_queue.try_pop(task)){ + // std::cout << "task\n"; + task(); + }else{ +// std::this_thread::sleep_for(std::chrono::seconds(1)); +// std::cout << "yield\n"; + std::this_thread::yield(); + + } + } + } + + public: + thread_pool(): done(false), joiner(threads){ + unsigned const thread_count = std::thread::hardware_concurrency(); + + try { + for (int i = 0; i < thread_count; ++i) { + threads.emplace_back(&thread_pool::worker_thread, this); // add work thread + } + } catch (...) { + done = true; + throw; + } + } + + ~thread_pool(){ + while(!work_queue.empty()){ + std::this_thread::yield(); + } + done = true; // TODO, when done is set to true, + // the worker thread will exit, even there are still tasks in work_queue. + } + + void create(std::function f){ + work_queue.push(std::move(f)); + } + + }; + +} + + + +thread_pool::thread_pool tpool; namespace test { // 测试用例?出水用力! @@ -72,6 +255,10 @@ std::string phone[] = {"110", "119", "120", "12315"}; } int main() { + std::cout << do_register(test::username[rand() % 4], test::password[rand() % 4], test::school[rand() % 4], test::phone[rand() % 4]) << std::endl; + std::cout << do_login(test::username[rand() % 4], test::password[rand() % 4]) << std::endl; + std::cout << do_queryuser(test::username[rand() % 4]) << std::endl; + for (int i = 0; i < 262144; i++) { tpool.create([&] { std::cout << do_register(test::username[rand() % 4], test::password[rand() % 4], test::school[rand() % 4], test::phone[rand() % 4]) << std::endl; @@ -83,6 +270,7 @@ int main() { std::cout << do_queryuser(test::username[rand() % 4]) << std::endl; }); } + // std::this_thread::sleep_for(std::chrono::seconds(100)); // 作业要求4:等待 tpool 中所有线程都结束后再退出 return 0; diff --git a/pr.md b/pr.md new file mode 100644 index 0000000..ac01af3 --- /dev/null +++ b/pr.md @@ -0,0 +1,4 @@ + +- 使用shared_mutex区分读和写,std::shared_lock是读锁,std::lock_guard是写锁。 +- 借鉴Cpp并发编程中ThreadPool的写法,创建了一个固定线程数量的线程。 +- 为了让tpool中所有线程都结束后再析构tpool,在tpool的析构函数里检查当前任务队列是否为空。 \ No newline at end of file