Skip to content

Commit

Permalink
InlineHook/MidHook: Add enable() and disable()
Browse files Browse the repository at this point in the history
  • Loading branch information
cursey committed May 13, 2024
1 parent 4faf792 commit 1f835c8
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 51 deletions.
17 changes: 11 additions & 6 deletions include/safetyhook/easy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,34 @@ namespace safetyhook {
/// @brief Easy to use API for creating an InlineHook.
/// @param target The address of the function to hook.
/// @param destination The address of the destination function.
/// @param flags The flags to use.
/// @return The InlineHook object.
[[nodiscard]] InlineHook create_inline(void* target, void* destination);
[[nodiscard]] InlineHook create_inline(void* target, void* destination, InlineHook::Flags flags = InlineHook::Default);

/// @brief Easy to use API for creating an InlineHook.
/// @param target The address of the function to hook.
/// @param destination The address of the destination function.
/// @param flags The flags to use.
/// @return The InlineHook object.
[[nodiscard]] InlineHook create_inline(FnPtr auto target, FnPtr auto destination) {
return create_inline(reinterpret_cast<void*>(target), reinterpret_cast<void*>(destination));
[[nodiscard]] InlineHook create_inline(
FnPtr auto target, FnPtr auto destination, InlineHook::Flags flags = InlineHook::Default) {
return create_inline(reinterpret_cast<void*>(target), reinterpret_cast<void*>(destination), flags);
}

/// @brief Easy to use API for creating a MidHook.
/// @param target the address of the function to hook.
/// @param destination The destination function.
/// @param flags The flags to use.
/// @return The MidHook object.
[[nodiscard]] MidHook create_mid(void* target, MidHookFn destination);
[[nodiscard]] MidHook create_mid(void* target, MidHookFn destination, MidHook::Flags = MidHook::Default);

/// @brief Easy to use API for creating a MidHook.
/// @param target the address of the function to hook.
/// @param destination The destination function.
/// @param flags The flags to use.
/// @return The MidHook object.
[[nodiscard]] MidHook create_mid(FnPtr auto target, MidHookFn destination) {
return create_mid(reinterpret_cast<void*>(target), destination);
[[nodiscard]] MidHook create_mid(FnPtr auto target, MidHookFn destination, MidHook::Flags flags = MidHook::Default) {
return create_mid(reinterpret_cast<void*>(target), destination, flags);
}

/// @brief Easy to use API for creating a VmtHook.
Expand Down
41 changes: 35 additions & 6 deletions include/safetyhook/inline_hook.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,42 +87,54 @@ class InlineHook final {
[[nodiscard]] static Error not_enough_space(uint8_t* ip) { return {.type = NOT_ENOUGH_SPACE, .ip = ip}; }
};

/// @brief Flags for InlineHook.
enum Flags : int {
Default = 0, ///< Default flags.
StartDisabled = 1 << 0, ///< Start the hook disabled.
};

/// @brief Create an inline hook.
/// @param target The address of the function to hook.
/// @param destination The destination address.
/// @param flags The flags to use.
/// @return The InlineHook or an InlineHook::Error if an error occurred.
/// @note This will use the default global Allocator.
/// @note If you don't care about error handling, use the easy API (safetyhook::create_inline).
[[nodiscard]] static std::expected<InlineHook, Error> create(void* target, void* destination);
[[nodiscard]] static std::expected<InlineHook, Error> create(
void* target, void* destination, Flags flags = Default);

/// @brief Create an inline hook.
/// @param target The address of the function to hook.
/// @param destination The destination address.
/// @param flags The flags to use.
/// @return The InlineHook or an InlineHook::Error if an error occurred.
/// @note This will use the default global Allocator.
/// @note If you don't care about error handling, use the easy API (safetyhook::create_inline).
[[nodiscard]] static std::expected<InlineHook, Error> create(FnPtr auto target, FnPtr auto destination) {
return create(reinterpret_cast<void*>(target), reinterpret_cast<void*>(destination));
[[nodiscard]] static std::expected<InlineHook, Error> create(
FnPtr auto target, FnPtr auto destination, Flags flags = Default) {
return create(reinterpret_cast<void*>(target), reinterpret_cast<void*>(destination), flags);
}

/// @brief Create an inline hook with a given Allocator.
/// @param allocator The allocator to use.
/// @param target The address of the function to hook.
/// @param destination The destination address.
/// @param flags The flags to use.
/// @return The InlineHook or an InlineHook::Error if an error occurred.
/// @note If you don't care about error handling, use the easy API (safetyhook::create_inline).
[[nodiscard]] static std::expected<InlineHook, Error> create(
const std::shared_ptr<Allocator>& allocator, void* target, void* destination);
const std::shared_ptr<Allocator>& allocator, void* target, void* destination, Flags flags = Default);

/// @brief Create an inline hook with a given Allocator.
/// @param allocator The allocator to use.
/// @param target The address of the function to hook.
/// @param destination The destination address.
/// @param flags The flags to use.
/// @return The InlineHook or an InlineHook::Error if an error occurred.
/// @note If you don't care about error handling, use the easy API (safetyhook::create_inline).
[[nodiscard]] static std::expected<InlineHook, Error> create(
const std::shared_ptr<Allocator>& allocator, FnPtr auto target, FnPtr auto destination) {
return create(allocator, reinterpret_cast<void*>(target), reinterpret_cast<void*>(destination));
const std::shared_ptr<Allocator>& allocator, FnPtr auto target, FnPtr auto destination, Flags flags = Default) {
return create(allocator, reinterpret_cast<void*>(target), reinterpret_cast<void*>(destination), flags);
}

InlineHook() = default;
Expand Down Expand Up @@ -285,15 +297,32 @@ class InlineHook final {
return original<RetT(SAFETYHOOK_FASTCALL*)(Args...)>()(args...);
}

/// @brief Enable the hook.
[[nodiscard]] std::expected<void, Error> enable();

/// @brief Disable the hook.
[[nodiscard]] std::expected<void, Error> disable();

/// @brief Check if the hook is enabled.
[[nodiscard]] bool enabled() const { return m_enabled; }

private:
friend class MidHook;

enum class Type {
Unset,
E9,
FF,
};

uint8_t* m_target{};
uint8_t* m_destination{};
Allocation m_trampoline{};
std::vector<uint8_t> m_original_bytes{};
uintptr_t m_trampoline_size{};
std::recursive_mutex m_mutex{};
bool m_enabled{};
Type m_type{Type::Unset};

std::expected<void, Error> setup(
const std::shared_ptr<Allocator>& allocator, uint8_t* target, uint8_t* destination);
Expand Down
35 changes: 28 additions & 7 deletions include/safetyhook/mid_hook.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,43 +52,55 @@ class MidHook final {
}
};

/// @brief Flags for MidHook.
enum Flags : int {
Default = 0, ///< Default flags.
StartDisabled = 1, ///< Start the hook disabled.
};

/// @brief Creates a new MidHook object.
/// @param target The address of the function to hook.
/// @param destination_fn The destination function.
/// @param flags The flags to use.
/// @return The MidHook object or a MidHook::Error if an error occurred.
/// @note This will use the default global Allocator.
/// @note If you don't care about error handling, use the easy API (safetyhook::create_mid).
[[nodiscard]] static std::expected<MidHook, Error> create(void* target, MidHookFn destination_fn);
[[nodiscard]] static std::expected<MidHook, Error> create(
void* target, MidHookFn destination_fn, Flags flags = Default);

/// @brief Creates a new MidHook object.
/// @param target The address of the function to hook.
/// @param destination_fn The destination function.
/// @param flags The flags to use.
/// @return The MidHook object or a MidHook::Error if an error occurred.
/// @note This will use the default global Allocator.
/// @note If you don't care about error handling, use the easy API (safetyhook::create_mid).
[[nodiscard]] static std::expected<MidHook, Error> create(FnPtr auto target, MidHookFn destination_fn) {
return create(reinterpret_cast<void*>(target), destination_fn);
[[nodiscard]] static std::expected<MidHook, Error> create(
FnPtr auto target, MidHookFn destination_fn, Flags flags = Default) {
return create(reinterpret_cast<void*>(target), destination_fn, flags);
}

/// @brief Creates a new MidHook object with a given Allocator.
/// @param allocator The Allocator to use.
/// @param target The address of the function to hook.
/// @param destination_fn The destination function.
/// @param flags The flags to use.
/// @return The MidHook object or a MidHook::Error if an error occurred.
/// @note If you don't care about error handling, use the easy API (safetyhook::create_mid).
[[nodiscard]] static std::expected<MidHook, Error> create(
const std::shared_ptr<Allocator>& allocator, void* target, MidHookFn destination_fn);
const std::shared_ptr<Allocator>& allocator, void* target, MidHookFn destination_fn, Flags flags = Default);

/// @brief Creates a new MidHook object with a given Allocator.
/// @tparam T The type of the function to hook.
/// @param allocator The Allocator to use.
/// @param target The address of the function to hook.
/// @param destination_fn The destination function.
/// @param flags The flags to use.
/// @return The MidHook object or a MidHook::Error if an error occurred.
/// @note If you don't care about error handling, use the easy API (safetyhook::create_mid).
[[nodiscard]] static std::expected<MidHook, Error> create(
const std::shared_ptr<Allocator>& allocator, FnPtr auto target, MidHookFn destination_fn) {
return create(allocator, reinterpret_cast<void*>(target), destination_fn);
[[nodiscard]] static std::expected<MidHook, Error> create(const std::shared_ptr<Allocator>& allocator,
FnPtr auto target, MidHookFn destination_fn, Flags flags = Default) {
return create(allocator, reinterpret_cast<void*>(target), destination_fn, flags);
}

MidHook() = default;
Expand Down Expand Up @@ -123,6 +135,15 @@ class MidHook final {
/// @return true if the hook is valid, false otherwise.
explicit operator bool() const { return static_cast<bool>(m_stub); }

/// @brief Enable the hook.
[[nodiscard]] std::expected<void, Error> enable();

/// @brief Disable the hook.
[[nodiscard]] std::expected<void, Error> disable();

/// @brief Check if the hook is enabled.
[[nodiscard]] bool enabled() const { return m_hook.enabled(); }

private:
InlineHook m_hook{};
uint8_t* m_target{};
Expand Down
8 changes: 4 additions & 4 deletions src/easy.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
#include "safetyhook/easy.hpp"

namespace safetyhook {
InlineHook create_inline(void* target, void* destination) {
if (auto hook = InlineHook::create(target, destination)) {
InlineHook create_inline(void* target, void* destination, InlineHook::Flags flags) {
if (auto hook = InlineHook::create(target, destination, flags)) {
return std::move(*hook);
} else {
return {};
}
}

MidHook create_mid(void* target, MidHookFn destination) {
if (auto hook = MidHook::create(target, destination)) {
MidHook create_mid(void* target, MidHookFn destination, MidHook::Flags flags) {
if (auto hook = MidHook::create(target, destination, flags)) {
return std::move(*hook);
} else {
return {};
Expand Down
88 changes: 64 additions & 24 deletions src/inline_hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ static bool decode(ZydisDecodedInstruction* ix, uint8_t* ip) {
return ZYAN_SUCCESS(ZydisDecoderDecodeInstruction(&decoder, nullptr, ip, 15, ix));
}

std::expected<InlineHook, InlineHook::Error> InlineHook::create(void* target, void* destination) {
return create(Allocator::global(), target, destination);
std::expected<InlineHook, InlineHook::Error> InlineHook::create(void* target, void* destination, Flags flags) {
return create(Allocator::global(), target, destination, flags);
}

std::expected<InlineHook, InlineHook::Error> InlineHook::create(
const std::shared_ptr<Allocator>& allocator, void* target, void* destination) {
const std::shared_ptr<Allocator>& allocator, void* target, void* destination, Flags flags) {
InlineHook hook{};

if (const auto setup_result =
Expand All @@ -128,6 +128,12 @@ std::expected<InlineHook, InlineHook::Error> InlineHook::create(
return std::unexpected{setup_result.error()};
}

if (!(flags & StartDisabled)) {
if (auto enable_result = hook.enable(); !enable_result) {
return std::unexpected{enable_result.error()};
}
}

return hook;
}

Expand All @@ -146,10 +152,14 @@ InlineHook& InlineHook::operator=(InlineHook&& other) noexcept {
m_trampoline = std::move(other.m_trampoline);
m_trampoline_size = other.m_trampoline_size;
m_original_bytes = std::move(other.m_original_bytes);
m_enabled = other.m_enabled;
m_type = other.m_type;

other.m_target = nullptr;
other.m_destination = nullptr;
other.m_trampoline_size = 0;
other.m_enabled = false;
other.m_type = Type::Unset;
}

return *this;
Expand Down Expand Up @@ -305,20 +315,7 @@ std::expected<void, InlineHook::Error> InlineHook::e9_hook(const std::shared_ptr
}
#endif

std::optional<Error> error;

// jmp from original to trampoline.
trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &trampoline_epilogue, &error] {
if (auto result = emit_jmp_e9(m_target, reinterpret_cast<uint8_t*>(&trampoline_epilogue->jmp_to_destination),
m_original_bytes.size());
!result) {
error = result.error();
}
});

if (error) {
return std::unexpected{*error};
}
m_type = Type::E9;

return {};
}
Expand Down Expand Up @@ -367,34 +364,77 @@ std::expected<void, InlineHook::Error> InlineHook::ff_hook(const std::shared_ptr
return std::unexpected{result.error()};
}

m_type = Type::FF;

return {};
}
#endif

std::expected<void, InlineHook::Error> InlineHook::enable() {
std::scoped_lock lock{m_mutex};

if (m_enabled) {
return {};
}

std::optional<Error> error;

// jmp from original to trampoline.
trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &error] {
if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size());
!result) {
error = result.error();
if (m_type == Type::E9) {
auto trampoline_epilogue = reinterpret_cast<TrampolineEpilogueE9*>(
m_trampoline.address() + m_trampoline_size - sizeof(TrampolineEpilogueE9));

if (auto result = emit_jmp_e9(m_target,
reinterpret_cast<uint8_t*>(&trampoline_epilogue->jmp_to_destination), m_original_bytes.size());
!result) {
error = result.error();
}
}

#if SAFETYHOOK_ARCH_X86_64
if (m_type == Type::FF) {
if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size());
!result) {
error = result.error();
}
}
#endif
});

if (error) {
return std::unexpected{*error};
}

m_enabled = true;

return {};
}
#endif

void InlineHook::destroy() {
std::expected<void, InlineHook::Error> InlineHook::disable() {
std::scoped_lock lock{m_mutex};

if (!m_trampoline) {
return;
if (!m_enabled) {
return {};
}

trap_threads(m_trampoline.data(), m_target, m_original_bytes.size(),
[this] { std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target); });

m_enabled = false;

return {};
}

void InlineHook::destroy() {
[[maybe_unused]] auto disable_result = disable();

std::scoped_lock lock{m_mutex};

if (!m_trampoline) {
return;
}

m_trampoline.free();
}
} // namespace safetyhook
Loading

0 comments on commit 1f835c8

Please sign in to comment.