Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fault injection crash #226

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 46 additions & 48 deletions src/ke.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <format>
#include <mutex>
#include <optional>
#include <sstream>
#include <vector>
#undef ASSERT
Expand All @@ -30,7 +31,7 @@ static uint32_t _usersim_original_priority_class;
static std::vector<std::mutex> _usersim_dispatch_locks;

static TP_POOL* _usersim_threadpool = nullptr;
static TP_CALLBACK_ENVIRON _usersim_threadpool_callback_environment{};
static std::optional<TP_CALLBACK_ENVIRON> _usersim_threadpool_callback_environment = std::nullopt;

static NTSTATUS
_wait_for_kevent(_Inout_ KEVENT* event, _In_opt_ PLARGE_INTEGER timeout);
Expand All @@ -39,7 +40,6 @@ usersim_result_t
usersim_initialize_irql()
{
usersim_result_t result;
bool threadpool_environment_initialized = false;

_usersim_threadpool = CreateThreadpool(nullptr);
if (_usersim_threadpool == nullptr) {
Expand All @@ -53,10 +53,11 @@ usersim_initialize_irql()
goto Exit;
}

InitializeThreadpoolEnvironment(&_usersim_threadpool_callback_environment);
threadpool_environment_initialized = true;
_usersim_threadpool_callback_environment = std::make_optional<TP_CALLBACK_ENVIRON>();

SetThreadpoolCallbackPool(&_usersim_threadpool_callback_environment, _usersim_threadpool);
InitializeThreadpoolEnvironment(&_usersim_threadpool_callback_environment.value());

SetThreadpoolCallbackPool(&_usersim_threadpool_callback_environment.value(), _usersim_threadpool);

_usersim_original_priority_class = GetPriorityClass(GetCurrentProcess());
if (_usersim_original_priority_class == 0) {
Expand All @@ -74,17 +75,17 @@ usersim_initialize_irql()
result = STATUS_SUCCESS;

Exit:
if (result != STATUS_SUCCESS){
if (threadpool_environment_initialized) {
DestroyThreadpoolEnvironment(&_usersim_threadpool_callback_environment);
if (result != STATUS_SUCCESS) {
if (_usersim_threadpool_callback_environment.has_value()) {
DestroyThreadpoolEnvironment(&_usersim_threadpool_callback_environment.value());
_usersim_threadpool_callback_environment.reset();
}
if (_usersim_threadpool != nullptr) {
CloseThreadpool(_usersim_threadpool);
_usersim_threadpool = nullptr;
}
}


return result;
}

Expand All @@ -99,10 +100,14 @@ usersim_clean_up_irql()
_usersim_original_priority_class = 0;
}

DestroyThreadpoolEnvironment(&_usersim_threadpool_callback_environment);

CloseThreadpool(_usersim_threadpool);
_usersim_threadpool = nullptr;
if (_usersim_threadpool_callback_environment.has_value()) {
DestroyThreadpoolEnvironment(&_usersim_threadpool_callback_environment.value());
_usersim_threadpool_callback_environment.reset();
}
if (_usersim_threadpool != nullptr) {
CloseThreadpool(_usersim_threadpool);
_usersim_threadpool = nullptr;
}
}

const int _irql_thread_priority[3] = {
Expand Down Expand Up @@ -180,8 +185,6 @@ KeRaiseIrql(_In_ KIRQL new_irql, _Out_ PKIRQL old_irql)
*old_irql = KfRaiseIrql(new_irql);
}



_IRQL_requires_max_(HIGH_LEVEL) _IRQL_raises_(new_irql) _IRQL_saves_ KIRQL KfRaiseIrql(_In_ KIRQL new_irql)
{
KIRQL old_irql = KeGetCurrentIrql();
Expand Down Expand Up @@ -232,7 +235,8 @@ KeLowerIrql(_In_ KIRQL new_irql)
ASSERT(result);
}

void KfLowerIrql(_In_ KIRQL new_irql)
void
KfLowerIrql(_In_ KIRQL new_irql)
{
return KeLowerIrql(new_irql);
}
Expand All @@ -243,11 +247,13 @@ _IRQL_requires_min_(DISPATCH_LEVEL) NTKERNELAPI LOGICAL KeShouldYieldProcessor(V

void
KeEnterCriticalRegion(void)
{}
{
}

void
KeLeaveCriticalRegion(void)
{}
{
}

#pragma region spin_locks

Expand Down Expand Up @@ -316,16 +322,10 @@ ULONG
KeQueryMaximumProcessorCountEx(_In_ USHORT group_number) { return GetMaximumProcessorCount(group_number); }

ULONG
KeQueryActiveProcessorCount()
{
return KeQueryMaximumProcessorCount();
}
KeQueryActiveProcessorCount() { return KeQueryMaximumProcessorCount(); }

ULONG
KeQueryActiveProcessorCountEx(_In_ USHORT group_number)
{
return KeQueryMaximumProcessorCountEx(group_number);
}
KeQueryActiveProcessorCountEx(_In_ USHORT group_number) { return KeQueryMaximumProcessorCountEx(group_number); }

KAFFINITY
KeSetSystemAffinityThreadEx(KAFFINITY affinity)
Expand Down Expand Up @@ -353,24 +353,21 @@ _IRQL_requires_min_(PASSIVE_LEVEL) _IRQL_requires_max_(APC_LEVEL) NTKERNELAPI VO
KeSetSystemAffinityThreadEx(affinity);
}

void KeSetSystemGroupAffinityThread(
_In_ const PGROUP_AFFINITY Affinity,
_Out_opt_ PGROUP_AFFINITY PreviousAffinity
)
void
KeSetSystemGroupAffinityThread(_In_ const PGROUP_AFFINITY Affinity, _Out_opt_ PGROUP_AFFINITY PreviousAffinity)
{
if (!SetThreadGroupAffinity(GetCurrentThread(), Affinity, PreviousAffinity)) {
DWORD error = GetLastError();
#if defined(NDEBUG)
#if defined(NDEBUG)
UNREFERENCED_PARAMETER(error);
#else
#else
assert(error == 0);
#endif
#endif
}
}

void KeRevertToUserGroupAffinityThread(
PGROUP_AFFINITY PreviousAffinity
)
void
KeRevertToUserGroupAffinityThread(PGROUP_AFFINITY PreviousAffinity)
{
SetThreadGroupAffinity(GetCurrentThread(), PreviousAffinity, NULL);
}
Expand Down Expand Up @@ -875,7 +872,7 @@ KeSetCoalescableTimer(
BOOLEAN running = (timer->threadpool_timer != nullptr);
if (!running) {
timer->threadpool_timer =
CreateThreadpoolTimer(_usersim_timer_callback, timer, &_usersim_threadpool_callback_environment);
CreateThreadpoolTimer(_usersim_timer_callback, timer, &_usersim_threadpool_callback_environment.value());
if (timer->threadpool_timer == nullptr) {
KeBugCheck(0);
return FALSE; // Keep code analysis happy.
Expand Down Expand Up @@ -1051,23 +1048,24 @@ _wait_for_kevent(_Inout_ KEVENT* event, _In_opt_ PLARGE_INTEGER timeout)
}
}

NTSTATUS KeExpandKernelStackAndCalloutEx(
NTSTATUS
KeExpandKernelStackAndCalloutEx(
_In_ PEXPAND_STACK_CALLOUT Callout,
_In_opt_ PVOID Parameter,
_In_ SIZE_T Size,
_In_ BOOLEAN Wait,
_In_opt_ PVOID Context)
{
// This is a mock implementation of KeExpandKernelStackAndCalloutEx that does not
// actually expand the stack. This is sufficient for the purposes of the tests.
UNREFERENCED_PARAMETER(Size);
UNREFERENCED_PARAMETER(Wait);
UNREFERENCED_PARAMETER(Context);
{
// This is a mock implementation of KeExpandKernelStackAndCalloutEx that does not
// actually expand the stack. This is sufficient for the purposes of the tests.
UNREFERENCED_PARAMETER(Size);
UNREFERENCED_PARAMETER(Wait);
UNREFERENCED_PARAMETER(Context);

// Invoke the callout function.
Callout(Parameter);
// Invoke the callout function.
Callout(Parameter);

return STATUS_SUCCESS;
}
return STATUS_SUCCESS;
}

#pragma endregion events