Skip to content

Commit

Permalink
Feat/Opaque thread freezer (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
cursey authored Jan 30, 2024
1 parent ac0f562 commit 6e8ed61
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 21 deletions.
18 changes: 7 additions & 11 deletions include/safetyhook/thread_freezer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,24 @@
#include <cstdint>
#include <functional>

#if __has_include(<Windows.h>)
#include <Windows.h>
#elif __has_include(<windows.h>)
#include <windows.h>
#else
#error "Windows.h not found"
#endif

namespace safetyhook {
using ThreadId = uint32_t;
using ThreadHandle = void*;
using ThreadContext = void*;

/// @brief Executes a function while all other threads are frozen. Also allows for visiting each frozen thread and
/// modifying it's context.
/// @param run_fn The function to run while all other threads are frozen.
/// @param visit_fn The function that will be called for each frozen thread.
/// @note The visit function will be called in the order that the threads were frozen.
/// @note The visit function will be called before the run function.
/// @note Keep the logic inside run_fn and visit_fn as simple as possible to avoid deadlocks.
void execute_while_frozen(
const std::function<void()>& run_fn, const std::function<void(uint32_t, HANDLE, CONTEXT&)>& visit_fn = {});
void execute_while_frozen(const std::function<void()>& run_fn,
const std::function<void(ThreadId, ThreadHandle, ThreadContext)>& visit_fn = {});

/// @brief Will modify the context of a thread's IP to point to a new address if its IP is at the old address.
/// @param ctx The thread context to modify.
/// @param old_ip The old IP address.
/// @param new_ip The new IP address.
void fix_ip(CONTEXT& ctx, uint8_t* old_ip, uint8_t* new_ip);
void fix_ip(ThreadContext ctx, uint8_t* old_ip, uint8_t* new_ip);
} // namespace safetyhook
6 changes: 3 additions & 3 deletions src/inline_hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ std::expected<void, InlineHook::Error> InlineHook::e9_hook(const std::shared_ptr
error = result.error();
}
},
[this](uint32_t, HANDLE, CONTEXT& ctx) {
[this](auto, auto, auto ctx) {
for (size_t i = 0; i < m_original_bytes.size(); ++i) {
fix_ip(ctx, m_target + i, m_trampoline.data() + i);
}
Expand Down Expand Up @@ -404,7 +404,7 @@ std::expected<void, InlineHook::Error> InlineHook::ff_hook(const std::shared_ptr
error = result.error();
}
},
[this](uint32_t, HANDLE, CONTEXT& ctx) {
[this](auto, auto, auto ctx) {
for (size_t i = 0; i < m_original_bytes.size(); ++i) {
fix_ip(ctx, m_target + i, m_trampoline.data() + i);
}
Expand All @@ -431,7 +431,7 @@ void InlineHook::destroy() {
std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target);
}
},
[this](uint32_t, HANDLE, CONTEXT& ctx) {
[this](auto, auto, auto ctx) {
for (size_t i = 0; i < m_original_bytes.size(); ++i) {
fix_ip(ctx, m_trampoline.data() + i, m_target + i);
}
Expand Down
17 changes: 10 additions & 7 deletions src/thread_freezer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ NtGetNextThread(HANDLE ProcessHandle, HANDLE ThreadHandle, ACCESS_MASK DesiredAc

namespace safetyhook {
void execute_while_frozen(
const std::function<void()>& run_fn, const std::function<void(uint32_t, HANDLE, CONTEXT&)>& visit_fn) {
const std::function<void()>& run_fn, const std::function<void(ThreadId, ThreadHandle, ThreadContext)>& visit_fn) {
// Freeze all threads.
int num_threads_frozen;
auto first_run = true;
Expand Down Expand Up @@ -73,7 +73,8 @@ void execute_while_frozen(
}

if (visit_fn) {
visit_fn(thread_id, thread, thread_ctx);
visit_fn(static_cast<ThreadId>(thread_id), static_cast<ThreadHandle>(thread),
static_cast<ThreadContext>(&thread_ctx));
}

++num_threads_frozen;
Expand Down Expand Up @@ -116,21 +117,23 @@ void execute_while_frozen(
}
}

void fix_ip(CONTEXT& ctx, uint8_t* old_ip, uint8_t* new_ip) {
void fix_ip(ThreadContext thread_ctx, uint8_t* old_ip, uint8_t* new_ip) {
auto* ctx = reinterpret_cast<CONTEXT*>(thread_ctx);

#ifdef _M_X64
auto ip = ctx.Rip;
auto ip = ctx->Rip;
#else
auto ip = ctx.Eip;
auto ip = ctx->Eip;
#endif

if (ip == reinterpret_cast<uintptr_t>(old_ip)) {
ip = reinterpret_cast<uintptr_t>(new_ip);
}

#ifdef _M_X64
ctx.Rip = ip;
ctx->Rip = ip;
#else
ctx.Eip = ip;
ctx->Eip = ip;
#endif
}
} // namespace safetyhook

0 comments on commit 6e8ed61

Please sign in to comment.