diff --git a/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs b/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs index e76fa9b67e5bcb..e08c5ae5bfee4d 100644 --- a/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs +++ b/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs @@ -173,6 +173,8 @@ internal static partial class WinHttp public const uint WINHTTP_OPTION_STREAM_ERROR_CODE = 159; public const uint WINHTTP_OPTION_REQUIRE_STREAM_END = 160; + public const uint WINHTTP_OPTION_CONNECTION_GUID = 178; + public enum WINHTTP_WEB_SOCKET_BUFFER_TYPE { WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE = 0, diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpRequestCallback.cs b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpRequestCallback.cs index c30694a20460b1..8b85399f300941 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpRequestCallback.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpRequestCallback.cs @@ -60,6 +60,10 @@ private static void RequestCallback( OnRequestHandleClosing(state); return; + case Interop.WinHttp.WINHTTP_CALLBACK_STATUS_REQUEST_SENT: + OnRequestRequestSent(state); + return; + case Interop.WinHttp.WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE: OnRequestSendRequestComplete(state); return; @@ -131,6 +135,44 @@ private static void OnRequestHandleClosing(WinHttpRequestState state) state.Dispose(); } + private static unsafe Guid GetGuidForConnection(SafeWinHttpHandle handle) + { + Guid guid = Guid.Empty; + Guid* pGuid = &guid; + uint guidSize = (uint)sizeof(Guid); + if (!Interop.WinHttp.WinHttpQueryOption( + handle, + Interop.WinHttp.WINHTTP_OPTION_CONNECTION_GUID, + (IntPtr)pGuid, + ref guidSize)) + { + int lastError = Marshal.GetLastWin32Error(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(null, $"Error getting WINHTTP_OPTION_CONNECTION_GUID, {lastError}"); + return Guid.Empty; + } + return guid; + } + + private static void OnRequestRequestSent(WinHttpRequestState state) + { + Debug.Assert(state != null, "OnRequestRequestSent: state is null"); + Debug.Assert(state.RequestHandle != null, "OnRequestRequestSent: state.RequestHandle is null"); + Guid connectionGuid = GetGuidForConnection(state.RequestHandle); + if (connectionGuid == Guid.Empty) + { + Guid guid = Guid.NewGuid(); + unsafe + { + if (!Interop.WinHttp.WinHttpSetOption(state.RequestHandle!, Interop.WinHttp.WINHTTP_OPTION_CONNECTION_GUID, (IntPtr)(&guid), (uint)sizeof(Guid))) + { + int lastError = Marshal.GetLastWin32Error(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(null, $"Error setting WINHTTP_OPTION_CONNECTION_GUID, {lastError}"); + } + } + + } + } + private static void OnRequestSendRequestComplete(WinHttpRequestState state) { Debug.Assert(state != null, "OnRequestSendRequestComplete: state is null"); @@ -244,7 +286,9 @@ private static void OnRequestSendingRequest(WinHttpRequestState state) // the TransportContext object. state.TransportContext.SetChannelBinding(state.RequestHandle); - if (state.ServerCertificateValidationCallback != null) + Guid connectionGuid = GetGuidForConnection(state.RequestHandle); + + if (state.ServerCertificateValidationCallback != null && connectionGuid == Guid.Empty) { IntPtr certHandle = IntPtr.Zero; uint certHandleSize = (uint)IntPtr.Size; diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs index cc2b97bdde6da7..3883e2e1671b49 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs @@ -9,7 +9,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; - +using TestUtilities; using Xunit; using Xunit.Abstractions; @@ -46,6 +46,89 @@ public void SendAsync_SimpleGet_Success() } } + [OuterLoop] + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsWindows10Version22000OrGreater))] + public async Task SendAsync_ServerCertificateValidationCallback_CalledOnce() + { + int callbackCount = 0; + var handler = new WinHttpHandler() + { + ServerCertificateValidationCallback = (m, cert, chain, err) => + { + Interlocked.Increment(ref callbackCount); + return true; + } + }; + using (var client = new HttpClient(handler)) + { + for (int i = 0; i < 5; i++) + { + var response = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, Configuration.Http.SecureRemoteEchoServer) + { + Version = HttpVersion.Version11 + }); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + _ = await response.Content.ReadAsStringAsync(); + } + Assert.Equal(1, callbackCount); + } + } + + [OuterLoop] + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsWindows10Version22000OrGreater))] + public async Task SendAsync_ServerCertificateValidationCallbackHttp2_CalledOnce() + { + using TestEventListener testEventListener = new TestEventListener(_output, TestEventListener.NetworkingEvents); + int callbackCount = 0; + var handler = new WinHttpHandler() + { + ServerCertificateValidationCallback = (m, cert, chain, err) => + { + Interlocked.Increment(ref callbackCount); + return true; + } + }; + using (var client = new HttpClient(handler)) + { + for (int i = 0; i < 5; i++) + { + var response = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, System.Net.Test.Common.Configuration.Http.Http2RemoteEchoServer) + { + Version = HttpVersion20.Value + }); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + _ = await response.Content.ReadAsStringAsync(); + } + Assert.Equal(1, callbackCount); + } + } + + [OuterLoop] + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsWindows10Version20348OrLower))] + public async Task SendAsync_ServerCertificateValidationCallback_CalledPerRequest() + { + int callbackCount = 0; + const int RequestCount = 5; + var handler = new WinHttpHandler() + { + ServerCertificateValidationCallback = (m, cert, chain, err) => + { + Interlocked.Increment(ref callbackCount); + return true; + } + }; + using (var client = new HttpClient(handler)) + { + for (int i = 0; i < RequestCount; i++) + { + var response = client.GetAsync(System.Net.Test.Common.Configuration.Http.SecureRemoteEchoServer).Result; + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + _ = await response.Content.ReadAsStringAsync(); + } + Assert.Equal(RequestCount, callbackCount); + } + } + [OuterLoop] [Theory] [InlineData(CookieUsePolicy.UseInternalCookieStoreOnly, "cookieName1", "cookieValue1")]