diff --git a/inc/usersim/ex.h b/inc/usersim/ex.h index d5a7f7d..4ba1e38 100644 --- a/inc/usersim/ex.h +++ b/inc/usersim/ex.h @@ -36,8 +36,12 @@ typedef struct _EX_SPIN_LOCK { SRWLOCK lock; } EX_SPIN_LOCK; + typedef cxplat_rundown_reference_t EX_RUNDOWN_REF; +// Initial usersim version of this API will not be cache aware and will use the same type as EX_RUNDOWN_REF. +typedef cxplat_rundown_reference_t EX_RUNDOWN_REF_CACHE_AWARE; + // // Pool Allocation routines (in pool.c) // @@ -104,6 +108,26 @@ USERSIM_API void ExReleaseRundownProtection(_Inout_ EX_RUNDOWN_REF* rundown_ref); +USERSIM_API +EX_RUNDOWN_REF_CACHE_AWARE* +ExAllocateCacheAwareRundownProtection( + _In_ __drv_strictTypeMatch(__drv_typeExpr) POOL_TYPE PoolType, unsigned long PoolTag); + +USERSIM_API +BOOLEAN ExAcquireRundownProtectionCacheAware( + _Inout_ EX_RUNDOWN_REF_CACHE_AWARE* RunRefCacheAware +); + +USERSIM_API +void ExReleaseRundownProtectionCacheAware( + _Inout_ EX_RUNDOWN_REF_CACHE_AWARE* RunRefCacheAware +); + +USERSIM_API +void ExFreeCacheAwareRundownProtection( + _Inout_ EX_RUNDOWN_REF_CACHE_AWARE* RunRefCacheAware +); + USERSIM_API _Acquires_exclusive_lock_(push_lock->lock) void ExAcquirePushLockExclusiveEx( _Inout_ _Requires_lock_not_held_(*_Curr_) _Acquires_lock_(*_Curr_) EX_PUSH_LOCK* push_lock, diff --git a/src/ex.cpp b/src/ex.cpp index 4622536..1c24abe 100644 --- a/src/ex.cpp +++ b/src/ex.cpp @@ -45,6 +45,37 @@ ExReleaseRundownProtection(_Inout_ EX_RUNDOWN_REF* rundown_reference) cxplat_release_rundown_protection(rundown_reference); } +EX_RUNDOWN_REF_CACHE_AWARE* +ExAllocateCacheAwareRundownProtection( + _In_ __drv_strictTypeMatch(__drv_typeExpr) POOL_TYPE PoolType, unsigned long PoolTag) +{ + EX_RUNDOWN_REF_CACHE_AWARE* rundown_reference = + (EX_RUNDOWN_REF_CACHE_AWARE*)ExAllocatePoolWithTag(PoolType, sizeof(EX_RUNDOWN_REF_CACHE_AWARE), PoolTag); + + if (rundown_reference != nullptr) { + cxplat_initialize_rundown_protection(rundown_reference); + } + return rundown_reference; +} + +BOOLEAN +ExAcquireRundownProtectionCacheAware(_Inout_ EX_RUNDOWN_REF_CACHE_AWARE* RunRefCacheAware) +{ + return (BOOLEAN)cxplat_acquire_rundown_protection(RunRefCacheAware); +} + +void +ExReleaseRundownProtectionCacheAware(_Inout_ EX_RUNDOWN_REF_CACHE_AWARE* RunRefCacheAware) +{ + cxplat_release_rundown_protection(RunRefCacheAware); +} + +void +ExFreeCacheAwareRundownProtection(_Inout_ EX_RUNDOWN_REF_CACHE_AWARE* RunRefCacheAware) +{ + ExFreePool(RunRefCacheAware); +} + _Acquires_exclusive_lock_(push_lock->lock) void ExAcquirePushLockExclusiveEx( _Inout_ _Requires_lock_not_held_(*_Curr_) _Acquires_lock_(*_Curr_) EX_PUSH_LOCK* push_lock, _In_ unsigned long flags) diff --git a/tests/ex_test.cpp b/tests/ex_test.cpp index f969e9d..fc767db 100644 --- a/tests/ex_test.cpp +++ b/tests/ex_test.cpp @@ -7,6 +7,9 @@ #include #endif #include "usersim/ex.h" +#include "cxplat_winuser.h" + +#include TEST_CASE("ExAllocatePool", "[ex]") { @@ -116,4 +119,112 @@ TEST_CASE("ExRaiseDatatypeMisalignment", "[ex]") int64_t code = _atoi64(ex); REQUIRE(code == STATUS_DATATYPE_MISALIGNMENT); } +} + +TEST_CASE("EX_RUNDOWN_REF", "[ex]") +{ + EX_RUNDOWN_REF ref; + std::atomic thread_completed = false; + ExInitializeRundownProtection(&ref); + + // Acquire before rundown is initiated. + // Acquire the first rundown protection reference. + REQUIRE(ExAcquireRundownProtection(&ref)); + + // Acquire the second rundown protection reference. + REQUIRE(ExAcquireRundownProtection(&ref)); + + // Wait for the rundown protection to be released. + std::thread thread([&]() { + // Wait for the rundown protection to be released. + ExWaitForRundownProtectionRelease(&ref); + thread_completed = true; + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Thread should be waiting for the rundown protection to be released. + REQUIRE(!thread_completed); + + // Acquire after rundown is initiated. + // Future acquire of the rundown protection should fail. + REQUIRE(!ExAcquireRundownProtection(&ref)); + + // Release the second rundown protection reference. + ExReleaseRundownProtection(&ref); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Thread should be waiting for the rundown protection to be released. + REQUIRE(!thread_completed); + + // Release the first rundown protection reference. + ExReleaseRundownProtection(&ref); + + // Thread should have completed. + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + thread.join(); + + // Thread should be waiting for the rundown protection to be released. + REQUIRE(thread_completed); + + // Acquire after rundown is completed. + + // Future acquire of the rundown protection should fail. + REQUIRE(!ExAcquireRundownProtection(&ref)); +} + +TEST_CASE("EX_RUNDOWN_REF_CACHE_AWARE", "[ex]") +{ + EX_RUNDOWN_REF_CACHE_AWARE* ref = ExAllocateCacheAwareRundownProtection(NonPagedPoolNx, 'tset'); + std::atomic thread_completed = false; + REQUIRE(ref != nullptr); + + // Acquire before rundown is initiated. + // Acquire the first rundown protection reference. + REQUIRE(ExAcquireRundownProtectionCacheAware(ref)); + + // Acquire the second rundown protection reference. + REQUIRE(ExAcquireRundownProtectionCacheAware(ref)); + + // Wait for the rundown protection to be released. + std::thread thread([&]() { + // Wait for the rundown protection to be released. + ExWaitForRundownProtectionRelease(ref); + thread_completed = true; + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Thread should be waiting for the rundown protection to be released. + REQUIRE(!thread_completed); + + // Acquire after rundown is initiated. + // Future acquire of the rundown protection should fail. + REQUIRE(!ExAcquireRundownProtectionCacheAware(ref)); + + // Release the second rundown protection reference. + ExReleaseRundownProtectionCacheAware(ref); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Thread should be waiting for the rundown protection to be released. + REQUIRE(!thread_completed); + + // Release the first rundown protection reference. + ExReleaseRundownProtectionCacheAware(ref); + + // Thread should have completed. + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + thread.join(); + REQUIRE(thread_completed); + + // Acquire after rundown is completed. + + // Future acquire of the rundown protection should fail. + REQUIRE(!ExAcquireRundownProtectionCacheAware(ref)); + + ExFreeCacheAwareRundownProtection(ref); } \ No newline at end of file