Skip to content

Commit

Permalink
Fix fault injection crash (#226)
Browse files Browse the repository at this point in the history
Signed-off-by: Alan Jowett <[email protected]>
Co-authored-by: Alan Jowett <[email protected]>
  • Loading branch information
Alan-Jowett and Alan Jowett authored Nov 4, 2024
1 parent ade618f commit 01df814
Showing 1 changed file with 46 additions and 48 deletions.
94 changes: 46 additions & 48 deletions src/ke.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <format>
#include <mutex>
#include <optional>
#include <sstream>
#include <vector>
#undef ASSERT
Expand All @@ -30,7 +31,7 @@ static uint32_t _usersim_original_priority_class;
static std::vector<std::mutex> _usersim_dispatch_locks;

static TP_POOL* _usersim_threadpool = nullptr;
static TP_CALLBACK_ENVIRON _usersim_threadpool_callback_environment{};
static std::optional<TP_CALLBACK_ENVIRON> _usersim_threadpool_callback_environment = std::nullopt;

static NTSTATUS
_wait_for_kevent(_Inout_ KEVENT* event, _In_opt_ PLARGE_INTEGER timeout);
Expand All @@ -39,7 +40,6 @@ usersim_result_t
usersim_initialize_irql()
{
usersim_result_t result;
bool threadpool_environment_initialized = false;

_usersim_threadpool = CreateThreadpool(nullptr);
if (_usersim_threadpool == nullptr) {
Expand All @@ -53,10 +53,11 @@ usersim_initialize_irql()
goto Exit;
}

InitializeThreadpoolEnvironment(&_usersim_threadpool_callback_environment);
threadpool_environment_initialized = true;
_usersim_threadpool_callback_environment = std::make_optional<TP_CALLBACK_ENVIRON>();

SetThreadpoolCallbackPool(&_usersim_threadpool_callback_environment, _usersim_threadpool);
InitializeThreadpoolEnvironment(&_usersim_threadpool_callback_environment.value());

SetThreadpoolCallbackPool(&_usersim_threadpool_callback_environment.value(), _usersim_threadpool);

_usersim_original_priority_class = GetPriorityClass(GetCurrentProcess());
if (_usersim_original_priority_class == 0) {
Expand All @@ -74,17 +75,17 @@ usersim_initialize_irql()
result = STATUS_SUCCESS;

Exit:
if (result != STATUS_SUCCESS){
if (threadpool_environment_initialized) {
DestroyThreadpoolEnvironment(&_usersim_threadpool_callback_environment);
if (result != STATUS_SUCCESS) {
if (_usersim_threadpool_callback_environment.has_value()) {
DestroyThreadpoolEnvironment(&_usersim_threadpool_callback_environment.value());
_usersim_threadpool_callback_environment.reset();
}
if (_usersim_threadpool != nullptr) {
CloseThreadpool(_usersim_threadpool);
_usersim_threadpool = nullptr;
}
}


return result;
}

Expand All @@ -99,10 +100,14 @@ usersim_clean_up_irql()
_usersim_original_priority_class = 0;
}

DestroyThreadpoolEnvironment(&_usersim_threadpool_callback_environment);

CloseThreadpool(_usersim_threadpool);
_usersim_threadpool = nullptr;
if (_usersim_threadpool_callback_environment.has_value()) {
DestroyThreadpoolEnvironment(&_usersim_threadpool_callback_environment.value());
_usersim_threadpool_callback_environment.reset();
}
if (_usersim_threadpool != nullptr) {
CloseThreadpool(_usersim_threadpool);
_usersim_threadpool = nullptr;
}
}

const int _irql_thread_priority[3] = {
Expand Down Expand Up @@ -180,8 +185,6 @@ KeRaiseIrql(_In_ KIRQL new_irql, _Out_ PKIRQL old_irql)
*old_irql = KfRaiseIrql(new_irql);
}



_IRQL_requires_max_(HIGH_LEVEL) _IRQL_raises_(new_irql) _IRQL_saves_ KIRQL KfRaiseIrql(_In_ KIRQL new_irql)
{
KIRQL old_irql = KeGetCurrentIrql();
Expand Down Expand Up @@ -232,7 +235,8 @@ KeLowerIrql(_In_ KIRQL new_irql)
ASSERT(result);
}

void KfLowerIrql(_In_ KIRQL new_irql)
void
KfLowerIrql(_In_ KIRQL new_irql)
{
return KeLowerIrql(new_irql);
}
Expand All @@ -243,11 +247,13 @@ _IRQL_requires_min_(DISPATCH_LEVEL) NTKERNELAPI LOGICAL KeShouldYieldProcessor(V

void
KeEnterCriticalRegion(void)
{}
{
}

void
KeLeaveCriticalRegion(void)
{}
{
}

#pragma region spin_locks

Expand Down Expand Up @@ -316,16 +322,10 @@ ULONG
KeQueryMaximumProcessorCountEx(_In_ USHORT group_number) { return GetMaximumProcessorCount(group_number); }

ULONG
KeQueryActiveProcessorCount()
{
return KeQueryMaximumProcessorCount();
}
KeQueryActiveProcessorCount() { return KeQueryMaximumProcessorCount(); }

ULONG
KeQueryActiveProcessorCountEx(_In_ USHORT group_number)
{
return KeQueryMaximumProcessorCountEx(group_number);
}
KeQueryActiveProcessorCountEx(_In_ USHORT group_number) { return KeQueryMaximumProcessorCountEx(group_number); }

KAFFINITY
KeSetSystemAffinityThreadEx(KAFFINITY affinity)
Expand Down Expand Up @@ -353,24 +353,21 @@ _IRQL_requires_min_(PASSIVE_LEVEL) _IRQL_requires_max_(APC_LEVEL) NTKERNELAPI VO
KeSetSystemAffinityThreadEx(affinity);
}

void KeSetSystemGroupAffinityThread(
_In_ const PGROUP_AFFINITY Affinity,
_Out_opt_ PGROUP_AFFINITY PreviousAffinity
)
void
KeSetSystemGroupAffinityThread(_In_ const PGROUP_AFFINITY Affinity, _Out_opt_ PGROUP_AFFINITY PreviousAffinity)
{
if (!SetThreadGroupAffinity(GetCurrentThread(), Affinity, PreviousAffinity)) {
DWORD error = GetLastError();
#if defined(NDEBUG)
#if defined(NDEBUG)
UNREFERENCED_PARAMETER(error);
#else
#else
assert(error == 0);
#endif
#endif
}
}

void KeRevertToUserGroupAffinityThread(
PGROUP_AFFINITY PreviousAffinity
)
void
KeRevertToUserGroupAffinityThread(PGROUP_AFFINITY PreviousAffinity)
{
SetThreadGroupAffinity(GetCurrentThread(), PreviousAffinity, NULL);
}
Expand Down Expand Up @@ -875,7 +872,7 @@ KeSetCoalescableTimer(
BOOLEAN running = (timer->threadpool_timer != nullptr);
if (!running) {
timer->threadpool_timer =
CreateThreadpoolTimer(_usersim_timer_callback, timer, &_usersim_threadpool_callback_environment);
CreateThreadpoolTimer(_usersim_timer_callback, timer, &_usersim_threadpool_callback_environment.value());
if (timer->threadpool_timer == nullptr) {
KeBugCheck(0);
return FALSE; // Keep code analysis happy.
Expand Down Expand Up @@ -1051,23 +1048,24 @@ _wait_for_kevent(_Inout_ KEVENT* event, _In_opt_ PLARGE_INTEGER timeout)
}
}

NTSTATUS KeExpandKernelStackAndCalloutEx(
NTSTATUS
KeExpandKernelStackAndCalloutEx(
_In_ PEXPAND_STACK_CALLOUT Callout,
_In_opt_ PVOID Parameter,
_In_ SIZE_T Size,
_In_ BOOLEAN Wait,
_In_opt_ PVOID Context)
{
// This is a mock implementation of KeExpandKernelStackAndCalloutEx that does not
// actually expand the stack. This is sufficient for the purposes of the tests.
UNREFERENCED_PARAMETER(Size);
UNREFERENCED_PARAMETER(Wait);
UNREFERENCED_PARAMETER(Context);
{
// This is a mock implementation of KeExpandKernelStackAndCalloutEx that does not
// actually expand the stack. This is sufficient for the purposes of the tests.
UNREFERENCED_PARAMETER(Size);
UNREFERENCED_PARAMETER(Wait);
UNREFERENCED_PARAMETER(Context);

// Invoke the callout function.
Callout(Parameter);
// Invoke the callout function.
Callout(Parameter);

return STATUS_SUCCESS;
}
return STATUS_SUCCESS;
}

#pragma endregion events

0 comments on commit 01df814

Please sign in to comment.