From 181bc2453ce68311d338f4025b4201294ee2b3c5 Mon Sep 17 00:00:00 2001 From: "Matt Ige (from Dev Box)" Date: Mon, 25 Mar 2024 16:32:09 -0700 Subject: [PATCH 1/2] add rtl avl functions Signed-off-by: Matt Ige (from Dev Box) --- inc/usersim/rtl.h | 178 ++++++++++++++++-------- tests/rtl_test.cpp | 329 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 452 insertions(+), 55 deletions(-) diff --git a/inc/usersim/rtl.h b/inc/usersim/rtl.h index 269115f..17db3d1 100644 --- a/inc/usersim/rtl.h +++ b/inc/usersim/rtl.h @@ -148,64 +148,34 @@ _IRQL_requires_max_(DISPATCH_LEVEL) _At_(destination_string->Buffer, _Post_equal USERSIM_API VOID NTAPI RtlInitUnicodeString(_Out_ PUNICODE_STRING destination_string, _In_opt_z_ __drv_aliasesMem PCWSTR source_string); -_IRQL_requires_max_(DISPATCH_LEVEL) -USERSIM_API -VOID -NTAPI -RtlInitUTF8String( - _Out_ PUTF8_STRING DestinationString, - _In_opt_z_ __drv_aliasesMem const char* SourceString - ); - -_IRQL_requires_max_(PASSIVE_LEVEL) -USERSIM_API -VOID -NTAPI -RtlFreeUnicodeString( - _Inout_ _At_(UnicodeString->Buffer, _Frees_ptr_opt_) - PUNICODE_STRING UnicodeString - ); - -_When_(AllocateDestinationString, - _At_(DestinationString->MaximumLength, - _Out_range_(<=, (SourceString->MaximumLength / sizeof(WCHAR))))) -_When_(!AllocateDestinationString, - _At_(DestinationString->Buffer, _Const_) - _At_(DestinationString->MaximumLength, _Const_)) -_IRQL_requires_max_(PASSIVE_LEVEL) -_When_(AllocateDestinationString, _Must_inspect_result_) -USERSIM_API -NTSTATUS -NTAPI -RtlUnicodeStringToUTF8String( - _When_(AllocateDestinationString, _Out_ _At_(DestinationString->Buffer, __drv_allocatesMem(Mem))) - _When_(!AllocateDestinationString, _Inout_) - PUTF8_STRING DestinationString, - _In_ PCUNICODE_STRING SourceString, - _In_ BOOLEAN AllocateDestinationString - ); - -_IRQL_requires_max_(PASSIVE_LEVEL) -_Must_inspect_result_ -NTSYSAPI -NTSTATUS -NTAPI -RtlUTF8StringToUnicodeString( +_IRQL_requires_max_(DISPATCH_LEVEL) USERSIM_API VOID NTAPI + RtlInitUTF8String(_Out_ PUTF8_STRING DestinationString, _In_opt_z_ __drv_aliasesMem const char* SourceString); + +_IRQL_requires_max_(PASSIVE_LEVEL) USERSIM_API VOID NTAPI + RtlFreeUnicodeString(_Inout_ _At_(UnicodeString->Buffer, _Frees_ptr_opt_) PUNICODE_STRING UnicodeString); + +_When_( + AllocateDestinationString, + _At_(DestinationString->MaximumLength, _Out_range_(<=, (SourceString->MaximumLength / sizeof(WCHAR))))) + _When_( + !AllocateDestinationString, + _At_(DestinationString->Buffer, _Const_) _At_(DestinationString->MaximumLength, _Const_)) + _IRQL_requires_max_(PASSIVE_LEVEL) + _When_(AllocateDestinationString, _Must_inspect_result_) USERSIM_API NTSTATUS NTAPI + RtlUnicodeStringToUTF8String( + _When_(AllocateDestinationString, _Out_ _At_(DestinationString->Buffer, __drv_allocatesMem(Mem))) + _When_(!AllocateDestinationString, _Inout_) PUTF8_STRING DestinationString, + _In_ PCUNICODE_STRING SourceString, + _In_ BOOLEAN AllocateDestinationString); + +_IRQL_requires_max_(PASSIVE_LEVEL) _Must_inspect_result_ NTSYSAPI NTSTATUS NTAPI RtlUTF8StringToUnicodeString( _When_(AllocateDestinationString, _Out_ _At_(DestinationString->Buffer, __drv_allocatesMem(Mem))) - _When_(!AllocateDestinationString, _Inout_) - PUNICODE_STRING DestinationString, + _When_(!AllocateDestinationString, _Inout_) PUNICODE_STRING DestinationString, _In_ PUTF8_STRING SourceString, - _In_ BOOLEAN AllocateDestinationString - ); + _In_ BOOLEAN AllocateDestinationString); -_IRQL_requires_max_(PASSIVE_LEVEL) -USERSIM_API -void -NTAPI -RtlFreeUTF8String( - _Inout_ _At_(utf8String->Buffer, _Frees_ptr_opt_) - PUTF8_STRING utf8String - ); +_IRQL_requires_max_(PASSIVE_LEVEL) USERSIM_API void NTAPI + RtlFreeUTF8String(_Inout_ _At_(utf8String->Buffer, _Frees_ptr_opt_) PUTF8_STRING utf8String); typedef struct _OBJECT_ATTRIBUTES { @@ -217,6 +187,104 @@ typedef struct _OBJECT_ATTRIBUTES SECURITY_QUALITY_OF_SERVICE* SecurityQualityOfService; } OBJECT_ATTRIBUTES, *POBJECT_ATTRIBUTES; +typedef ULONG CLONG; + +typedef struct _RTL_BALANCED_LINKS +{ + struct _RTL_BALANCED_LINKS* Parent; + struct _RTL_BALANCED_LINKS* LeftChild; + struct _RTL_BALANCED_LINKS* RightChild; + CHAR Balance; + UCHAR Reserved[3]; +} RTL_BALANCED_LINKS; +typedef RTL_BALANCED_LINKS* PRTL_BALANCED_LINKS; + +typedef enum _TABLE_SEARCH_RESULT +{ + TableEmptyTree, + TableFoundNode, + TableInsertAsLeft, + TableInsertAsRight +} TABLE_SEARCH_RESULT; + +typedef enum _RTL_GENERIC_COMPARE_RESULTS +{ + GenericLessThan, + GenericGreaterThan, + GenericEqual +} RTL_GENERIC_COMPARE_RESULTS; + +typedef RTL_GENERIC_COMPARE_RESULTS (*PRTL_AVL_COMPARE_ROUTINE)( + _In_ struct _RTL_AVL_TABLE* Table, _In_ PVOID FirstStruct, _In_ PVOID SecondStruct); + +typedef PVOID (*PRTL_AVL_ALLOCATE_ROUTINE)(_In_ struct _RTL_AVL_TABLE* Table, _In_ CLONG ByteSize); + +typedef VOID (*PRTL_AVL_FREE_ROUTINE)(_In_ struct _RTL_AVL_TABLE* Table, _In_ PVOID Buffer); + +typedef struct _RTL_AVL_TABLE +{ + RTL_BALANCED_LINKS BalancedRoot; + PVOID OrderedPointer; + ULONG WhichOrderedElement; + ULONG NumberGenericTableElements; + ULONG DepthOfTree; + PRTL_BALANCED_LINKS RestartKey; + ULONG DeleteCount; + PRTL_AVL_COMPARE_ROUTINE CompareRoutine; + PRTL_AVL_ALLOCATE_ROUTINE AllocateRoutine; + PRTL_AVL_FREE_ROUTINE FreeRoutine; + PVOID TableContext; +} RTL_AVL_TABLE, *PRTL_AVL_TABLE; + +NTSYSAPI VOID +RtlInitializeGenericTableAvl( + _Out_ PRTL_AVL_TABLE Table, + _In_ PRTL_AVL_COMPARE_ROUTINE CompareRoutine, + _In_ PRTL_AVL_ALLOCATE_ROUTINE AllocateRoutine, + _In_ PRTL_AVL_FREE_ROUTINE FreeRoutine, + _In_opt_ PVOID TableContext); + +NTSYSAPI BOOLEAN +RtlDeleteElementGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ PVOID Buffer); + +NTSYSAPI PVOID +RtlEnumerateGenericTableWithoutSplayingAvl(_In_ PRTL_AVL_TABLE Table, _Inout_ PVOID* RestartKey); + +NTSYSAPI PVOID +RtlEnumerateGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ BOOLEAN Restart); + +NTSYSAPI PVOID +RtlGetElementGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ ULONG I); + +NTSYSAPI PVOID +RtlInsertElementGenericTableAvl( + _In_ PRTL_AVL_TABLE Table, _In_ PVOID Buffer, _In_ CLONG BufferSize, _Out_opt_ PBOOLEAN NewElement); + +NTSYSAPI PVOID +RtlInsertElementGenericTableFullAvl( + _In_ PRTL_AVL_TABLE Table, + _In_ PVOID Buffer, + _In_ CLONG BufferSize, + _Out_opt_ PBOOLEAN NewElement, + _In_ PVOID NodeOrParent, + _In_ TABLE_SEARCH_RESULT SearchResult); + +NTSYSAPI BOOLEAN +RtlIsGenericTableEmptyAvl(_In_ PRTL_AVL_TABLE Table); + +NTSYSAPI PVOID +RtlLookupElementGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ PVOID Buffer); + +NTSYSAPI PVOID +RtlLookupElementGenericTableFullAvl( + _In_ PRTL_AVL_TABLE Table, _In_ PVOID Buffer, _Out_ PVOID* NodeOrParent, _Out_ TABLE_SEARCH_RESULT* SearchResult); + +NTSYSAPI PVOID +RtlLookupFirstMatchingElementGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ PVOID Buffer, _In_ PVOID* RestartKey); + +NTSYSAPI ULONG +RtlNumberGenericTableElementsAvl(_In_ PRTL_AVL_TABLE Table); + // Include Rtl* implementations from ntdll.lib. #pragma comment(lib, "ntdll.lib") diff --git a/tests/rtl_test.cpp b/tests/rtl_test.cpp index 583d698..4efb157 100644 --- a/tests/rtl_test.cpp +++ b/tests/rtl_test.cpp @@ -109,4 +109,333 @@ TEST_CASE("RtlUTF8StringToUnicodeString", "[rtl]") REQUIRE(unicode_string.MaximumLength == 10); REQUIRE(wcscmp(unicode_string.Buffer, L"test") == 0); RtlFreeUnicodeString(&unicode_string); +} + +static RTL_GENERIC_COMPARE_RESULTS +_test_avl_compare_routine(_In_ RTL_AVL_TABLE* table, _In_ PVOID first_struct, _In_ PVOID second_struct) +{ + int first = *(reinterpret_cast(first_struct)); + int second = *(reinterpret_cast(second_struct)); + + if (first < second) { + return GenericLessThan; + } else if (first > second) { + return GenericGreaterThan; + } else { + return GenericEqual; + } +} + +static PVOID +_test_avl_allocate_routine(_In_ struct _RTL_AVL_TABLE* Table, _In_ CLONG ByteSize) +{ + return malloc(ByteSize); +} + +static VOID +_test_avl_free_routine(_In_ struct _RTL_AVL_TABLE* Table, _In_ PVOID Buffer) +{ + free(Buffer); +} + +TEST_CASE("RtlInitializeGenericTableAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int context = 0; + + // Invoke without context + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + // Invoke with context + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, &context); +} + +TEST_CASE("RtlInsertElementGenericTableAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + BOOLEAN new_element = FALSE; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + // Insert a new entry. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &entry, sizeof(entry), &new_element) != nullptr); + REQUIRE(new_element == TRUE); + + // Re-insert the same entry. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &entry, sizeof(entry), &new_element) != nullptr); + REQUIRE(new_element == FALSE); + + // Insert the another new entry. + entry = 1; + REQUIRE(RtlInsertElementGenericTableAvl(&table, &entry, sizeof(entry), &new_element) != nullptr); + REQUIRE(new_element == TRUE); + + // Remove the entries + entry = 0; + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); + entry = 1; + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); +} + +TEST_CASE("RtlInsertElementGenericTableFullAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + BOOLEAN new_element = FALSE; + PVOID node_or_parent = nullptr; + TABLE_SEARCH_RESULT search_result; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + // Lookup entry + REQUIRE(RtlLookupElementGenericTableFullAvl(&table, &entry, &node_or_parent, &search_result) == nullptr); + + // Insert entry + REQUIRE( + RtlInsertElementGenericTableFullAvl(&table, &entry, sizeof(entry), nullptr, node_or_parent, search_result) != + nullptr); + + // Search for entry while table is populated + entry = 1; + REQUIRE(RtlLookupElementGenericTableFullAvl(&table, &entry, &node_or_parent, &search_result) == nullptr); + + // Insert entry + REQUIRE( + RtlInsertElementGenericTableFullAvl(&table, &entry, sizeof(entry), nullptr, node_or_parent, search_result) != + nullptr); + + // Delete added entries + entry = 0; + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); + entry = 1; + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); +} + +TEST_CASE("RtlDeleteElementGenericTableAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + // Deleting an entry which does not exist should fail. + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == FALSE); + + // Insert and remove the entry. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &entry, sizeof(entry), nullptr) != nullptr); + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); + + // Deleting an already deleted enry should fail. + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == FALSE); +} + +TEST_CASE("RtlIsGenericTableEmptyAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + // Table should be empty after initialization. + REQUIRE(RtlIsGenericTableEmptyAvl(&table) == TRUE); + + // Table should not be empty after inserting an element. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &entry, sizeof(entry), nullptr) != nullptr); + REQUIRE(RtlIsGenericTableEmptyAvl(&table) == FALSE); + + // Table should be empty after removing the element. + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); + REQUIRE(RtlIsGenericTableEmptyAvl(&table) == TRUE); +} + +TEST_CASE("RtlLookupElementGenericTableAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + // Lookup on an empty table should return nullptr. + REQUIRE(RtlLookupElementGenericTableAvl(&table, &entry) == nullptr); + + // Lookup should succeed after inserting an element. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &entry, sizeof(entry), nullptr) != nullptr); + PVOID buffer = RtlLookupElementGenericTableAvl(&table, &entry); + REQUIRE(buffer != nullptr); + REQUIRE(entry == *(reinterpret_cast(buffer))); + + // Lookup should fail after removingthe element. + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); + REQUIRE(RtlLookupElementGenericTableAvl(&table, &entry) == nullptr); +} + +TEST_CASE("RtlLookupElementGenericTableFullAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + PVOID node_or_parent = nullptr; + TABLE_SEARCH_RESULT result; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + // Lookup on an empty table should return nullptr. + REQUIRE(RtlLookupElementGenericTableFullAvl(&table, &entry, &node_or_parent, &result) == nullptr); + REQUIRE(result == TableEmptyTree); + + // Lookup should succeed after inserting an element. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &entry, sizeof(entry), nullptr) != nullptr); + REQUIRE(RtlLookupElementGenericTableFullAvl(&table, &entry, &node_or_parent, &result) != nullptr); + REQUIRE(result == TableFoundNode); + + // Search for an entry greater than the inserted entry. + entry = 1; + REQUIRE(RtlLookupElementGenericTableFullAvl(&table, &entry, &node_or_parent, &result) == nullptr); + REQUIRE(result == TableInsertAsRight); + + // Search for an entry less than the inserted entry. + entry = -1; + REQUIRE(RtlLookupElementGenericTableFullAvl(&table, &entry, &node_or_parent, &result) == nullptr); + REQUIRE(result == TableInsertAsLeft); + + // Delete the entry + entry = 0; + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); +} + +TEST_CASE("RtlLookupFirstMatchingElementGenericTableAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + PVOID restart_key = nullptr; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + // Lookup on an empty table should return nullptr. + REQUIRE(RtlLookupFirstMatchingElementGenericTableAvl(&table, &entry, &restart_key) == nullptr); + REQUIRE(restart_key == nullptr); + + // Lookup should succeed after inserting an element. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &entry, sizeof(entry), nullptr) != nullptr); + REQUIRE(RtlLookupFirstMatchingElementGenericTableAvl(&table, &entry, &restart_key) != nullptr); + REQUIRE(restart_key != nullptr); + + // Delete the entry + entry = 0; + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); +} + +TEST_CASE("RtlGetElementGenericTableAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + for (int i = 0; i < 100; ++i) { + // Get on element i prior to insertion. This should fail. + REQUIRE(RtlGetElementGenericTableAvl(&table, i) == nullptr); + + // Get on element should succeed after insertions. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); + REQUIRE(RtlGetElementGenericTableAvl(&table, i) != nullptr); + } + + // Remove all elements. + for (int i = 0; i < 100; ++i) { + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &i)); + } +} + +TEST_CASE("RtlNumberGenericTableElementsAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + for (int i = 0; i < 100; ++i) { + // Get count of elements prior to insertion. + REQUIRE(RtlNumberGenericTableElementsAvl(&table) == i); + + // Insert a new element and get the count. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); + REQUIRE(RtlNumberGenericTableElementsAvl(&table) == i + 1); + + // Re-insertion should not affect count of elements. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); + REQUIRE(RtlNumberGenericTableElementsAvl(&table) == i + 1); + } + + // Remove all elements. + for (int i = 0; i < 100; ++i) { + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &i)); + } +} + +TEST_CASE("RtlEnumerateGenericTableAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + // Populate the table + for (int i = 0; i < 100; ++i) { + REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); + } + + int expected_entry = 0; + PVOID enumerated_entry = nullptr; + for (enumerated_entry = RtlEnumerateGenericTableAvl(&table, TRUE); enumerated_entry != nullptr; + enumerated_entry = RtlEnumerateGenericTableAvl(&table, FALSE)) { + + REQUIRE(expected_entry++ == *(reinterpret_cast(enumerated_entry))); + } + + // Remove all elements. + for (int i = 0; i < 100; ++i) { + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &i)); + } +} + +TEST_CASE("RtlEnumerateGenericTableWithoutSplayingAvl", "[rtl]") +{ + RTL_AVL_TABLE table = {0}; + int entry = 0; + + RtlInitializeGenericTableAvl( + &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + + // Populate the table + for (int i = 0; i < 100; ++i) { + REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); + } + + int expected_entry = 0; + PVOID enumerated_entry = nullptr; + PVOID restart_key = nullptr; + for (enumerated_entry = RtlEnumerateGenericTableWithoutSplayingAvl(&table, &restart_key); + restart_key != nullptr && enumerated_entry != nullptr; + enumerated_entry = RtlEnumerateGenericTableWithoutSplayingAvl(&table, &restart_key)) { + + REQUIRE(expected_entry++ == *(reinterpret_cast(enumerated_entry))); + } + + // Remove all elements. + for (int i = 0; i < 100; ++i) { + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &i)); + } } \ No newline at end of file From aab42b9f2bf421e6321e81a9ecb15718be56e7de Mon Sep 17 00:00:00 2001 From: "Matt Ige (from Dev Box)" Date: Mon, 25 Mar 2024 16:35:05 -0700 Subject: [PATCH 2/2] reorder Signed-off-by: Matt Ige (from Dev Box) --- inc/usersim/rtl.h | 26 ++++++------ tests/rtl_test.cpp | 102 ++++++++++++++++++++++----------------------- 2 files changed, 64 insertions(+), 64 deletions(-) diff --git a/inc/usersim/rtl.h b/inc/usersim/rtl.h index 17db3d1..02d98d2 100644 --- a/inc/usersim/rtl.h +++ b/inc/usersim/rtl.h @@ -244,18 +244,6 @@ RtlInitializeGenericTableAvl( _In_ PRTL_AVL_FREE_ROUTINE FreeRoutine, _In_opt_ PVOID TableContext); -NTSYSAPI BOOLEAN -RtlDeleteElementGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ PVOID Buffer); - -NTSYSAPI PVOID -RtlEnumerateGenericTableWithoutSplayingAvl(_In_ PRTL_AVL_TABLE Table, _Inout_ PVOID* RestartKey); - -NTSYSAPI PVOID -RtlEnumerateGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ BOOLEAN Restart); - -NTSYSAPI PVOID -RtlGetElementGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ ULONG I); - NTSYSAPI PVOID RtlInsertElementGenericTableAvl( _In_ PRTL_AVL_TABLE Table, _In_ PVOID Buffer, _In_ CLONG BufferSize, _Out_opt_ PBOOLEAN NewElement); @@ -270,7 +258,10 @@ RtlInsertElementGenericTableFullAvl( _In_ TABLE_SEARCH_RESULT SearchResult); NTSYSAPI BOOLEAN -RtlIsGenericTableEmptyAvl(_In_ PRTL_AVL_TABLE Table); +RtlDeleteElementGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ PVOID Buffer); + +NTSYSAPI PVOID +RtlGetElementGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ ULONG I); NTSYSAPI PVOID RtlLookupElementGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ PVOID Buffer); @@ -282,6 +273,15 @@ RtlLookupElementGenericTableFullAvl( NTSYSAPI PVOID RtlLookupFirstMatchingElementGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ PVOID Buffer, _In_ PVOID* RestartKey); +NTSYSAPI PVOID +RtlEnumerateGenericTableAvl(_In_ PRTL_AVL_TABLE Table, _In_ BOOLEAN Restart); + +NTSYSAPI PVOID +RtlEnumerateGenericTableWithoutSplayingAvl(_In_ PRTL_AVL_TABLE Table, _Inout_ PVOID* RestartKey); + +NTSYSAPI BOOLEAN +RtlIsGenericTableEmptyAvl(_In_ PRTL_AVL_TABLE Table); + NTSYSAPI ULONG RtlNumberGenericTableElementsAvl(_In_ PRTL_AVL_TABLE Table); diff --git a/tests/rtl_test.cpp b/tests/rtl_test.cpp index 4efb157..b4efd8a 100644 --- a/tests/rtl_test.cpp +++ b/tests/rtl_test.cpp @@ -235,7 +235,7 @@ TEST_CASE("RtlDeleteElementGenericTableAvl", "[rtl]") REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == FALSE); } -TEST_CASE("RtlIsGenericTableEmptyAvl", "[rtl]") +TEST_CASE("RtlGetElementGenericTableAvl", "[rtl]") { RTL_AVL_TABLE table = {0}; int entry = 0; @@ -243,16 +243,19 @@ TEST_CASE("RtlIsGenericTableEmptyAvl", "[rtl]") RtlInitializeGenericTableAvl( &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); - // Table should be empty after initialization. - REQUIRE(RtlIsGenericTableEmptyAvl(&table) == TRUE); + for (int i = 0; i < 100; ++i) { + // Get on element i prior to insertion. This should fail. + REQUIRE(RtlGetElementGenericTableAvl(&table, i) == nullptr); - // Table should not be empty after inserting an element. - REQUIRE(RtlInsertElementGenericTableAvl(&table, &entry, sizeof(entry), nullptr) != nullptr); - REQUIRE(RtlIsGenericTableEmptyAvl(&table) == FALSE); + // Get on element should succeed after insertions. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); + REQUIRE(RtlGetElementGenericTableAvl(&table, i) != nullptr); + } - // Table should be empty after removing the element. - REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); - REQUIRE(RtlIsGenericTableEmptyAvl(&table) == TRUE); + // Remove all elements. + for (int i = 0; i < 100; ++i) { + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &i)); + } } TEST_CASE("RtlLookupElementGenericTableAvl", "[rtl]") @@ -334,7 +337,7 @@ TEST_CASE("RtlLookupFirstMatchingElementGenericTableAvl", "[rtl]") REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); } -TEST_CASE("RtlGetElementGenericTableAvl", "[rtl]") +TEST_CASE("RtlEnumerateGenericTableAvl", "[rtl]") { RTL_AVL_TABLE table = {0}; int entry = 0; @@ -342,13 +345,17 @@ TEST_CASE("RtlGetElementGenericTableAvl", "[rtl]") RtlInitializeGenericTableAvl( &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + // Populate the table for (int i = 0; i < 100; ++i) { - // Get on element i prior to insertion. This should fail. - REQUIRE(RtlGetElementGenericTableAvl(&table, i) == nullptr); - - // Get on element should succeed after insertions. REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); - REQUIRE(RtlGetElementGenericTableAvl(&table, i) != nullptr); + } + + int expected_entry = 0; + PVOID enumerated_entry = nullptr; + for (enumerated_entry = RtlEnumerateGenericTableAvl(&table, TRUE); enumerated_entry != nullptr; + enumerated_entry = RtlEnumerateGenericTableAvl(&table, FALSE)) { + + REQUIRE(expected_entry++ == *(reinterpret_cast(enumerated_entry))); } // Remove all elements. @@ -357,7 +364,7 @@ TEST_CASE("RtlGetElementGenericTableAvl", "[rtl]") } } -TEST_CASE("RtlNumberGenericTableElementsAvl", "[rtl]") +TEST_CASE("RtlEnumerateGenericTableWithoutSplayingAvl", "[rtl]") { RTL_AVL_TABLE table = {0}; int entry = 0; @@ -365,17 +372,19 @@ TEST_CASE("RtlNumberGenericTableElementsAvl", "[rtl]") RtlInitializeGenericTableAvl( &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); + // Populate the table for (int i = 0; i < 100; ++i) { - // Get count of elements prior to insertion. - REQUIRE(RtlNumberGenericTableElementsAvl(&table) == i); - - // Insert a new element and get the count. REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); - REQUIRE(RtlNumberGenericTableElementsAvl(&table) == i + 1); + } - // Re-insertion should not affect count of elements. - REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); - REQUIRE(RtlNumberGenericTableElementsAvl(&table) == i + 1); + int expected_entry = 0; + PVOID enumerated_entry = nullptr; + PVOID restart_key = nullptr; + for (enumerated_entry = RtlEnumerateGenericTableWithoutSplayingAvl(&table, &restart_key); + restart_key != nullptr && enumerated_entry != nullptr; + enumerated_entry = RtlEnumerateGenericTableWithoutSplayingAvl(&table, &restart_key)) { + + REQUIRE(expected_entry++ == *(reinterpret_cast(enumerated_entry))); } // Remove all elements. @@ -384,7 +393,7 @@ TEST_CASE("RtlNumberGenericTableElementsAvl", "[rtl]") } } -TEST_CASE("RtlEnumerateGenericTableAvl", "[rtl]") +TEST_CASE("RtlIsGenericTableEmptyAvl", "[rtl]") { RTL_AVL_TABLE table = {0}; int entry = 0; @@ -392,26 +401,19 @@ TEST_CASE("RtlEnumerateGenericTableAvl", "[rtl]") RtlInitializeGenericTableAvl( &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); - // Populate the table - for (int i = 0; i < 100; ++i) { - REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); - } - - int expected_entry = 0; - PVOID enumerated_entry = nullptr; - for (enumerated_entry = RtlEnumerateGenericTableAvl(&table, TRUE); enumerated_entry != nullptr; - enumerated_entry = RtlEnumerateGenericTableAvl(&table, FALSE)) { + // Table should be empty after initialization. + REQUIRE(RtlIsGenericTableEmptyAvl(&table) == TRUE); - REQUIRE(expected_entry++ == *(reinterpret_cast(enumerated_entry))); - } + // Table should not be empty after inserting an element. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &entry, sizeof(entry), nullptr) != nullptr); + REQUIRE(RtlIsGenericTableEmptyAvl(&table) == FALSE); - // Remove all elements. - for (int i = 0; i < 100; ++i) { - REQUIRE(RtlDeleteElementGenericTableAvl(&table, &i)); - } + // Table should be empty after removing the element. + REQUIRE(RtlDeleteElementGenericTableAvl(&table, &entry) == TRUE); + REQUIRE(RtlIsGenericTableEmptyAvl(&table) == TRUE); } -TEST_CASE("RtlEnumerateGenericTableWithoutSplayingAvl", "[rtl]") +TEST_CASE("RtlNumberGenericTableElementsAvl", "[rtl]") { RTL_AVL_TABLE table = {0}; int entry = 0; @@ -419,19 +421,17 @@ TEST_CASE("RtlEnumerateGenericTableWithoutSplayingAvl", "[rtl]") RtlInitializeGenericTableAvl( &table, _test_avl_compare_routine, _test_avl_allocate_routine, _test_avl_free_routine, nullptr); - // Populate the table for (int i = 0; i < 100; ++i) { - REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); - } + // Get count of elements prior to insertion. + REQUIRE(RtlNumberGenericTableElementsAvl(&table) == i); - int expected_entry = 0; - PVOID enumerated_entry = nullptr; - PVOID restart_key = nullptr; - for (enumerated_entry = RtlEnumerateGenericTableWithoutSplayingAvl(&table, &restart_key); - restart_key != nullptr && enumerated_entry != nullptr; - enumerated_entry = RtlEnumerateGenericTableWithoutSplayingAvl(&table, &restart_key)) { + // Insert a new element and get the count. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); + REQUIRE(RtlNumberGenericTableElementsAvl(&table) == i + 1); - REQUIRE(expected_entry++ == *(reinterpret_cast(enumerated_entry))); + // Re-insertion should not affect count of elements. + REQUIRE(RtlInsertElementGenericTableAvl(&table, &i, sizeof(i), nullptr) != nullptr); + REQUIRE(RtlNumberGenericTableElementsAvl(&table) == i + 1); } // Remove all elements.