diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs index 22a3f91f9630fa..5160628cf827be 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs @@ -194,40 +194,44 @@ internal void CloseContext() // Protecting from X509Certificate2 derived classes. X509Certificate2? certEx = MakeEx(certificate); - if (certEx != null) + if (certEx is null) { - if (certEx.HasPrivateKey) - { - if (NetEventSource.Log.IsEnabled()) - NetEventSource.Log.CertIsType2(instance); + return null; + } - return certEx; - } + if (certEx.HasPrivateKey) + { + if (NetEventSource.Log.IsEnabled()) + NetEventSource.Log.CertIsType2(instance); - if (!object.ReferenceEquals(certificate, certEx)) - { - certEx.Dispose(); - } + return certEx; } - string certHash = certEx!.Thumbprint; + Span certHash = stackalloc byte[SHA512.HashSizeInBytes]; + bool ret = certEx.TryGetCertHash(HashAlgorithmName.SHA512, certHash, out int written); + Debug.Assert(ret && written == certHash.Length); + + if (!object.ReferenceEquals(certificate, certEx)) + { + certEx.Dispose(); + } // ELSE Try the MY user and machine stores for private key check. // For server side mode MY machine store takes priority. X509Certificate2? found = - FindCertWithPrivateKey(isServer) ?? - FindCertWithPrivateKey(!isServer); + FindCertWithPrivateKey(isServer, certHash) ?? + FindCertWithPrivateKey(!isServer, certHash); if (found is not null) { return found; } - X509Certificate2? FindCertWithPrivateKey(bool isServer) + X509Certificate2? FindCertWithPrivateKey(bool isServer, ReadOnlySpan certHash) { if (CertificateValidationPal.EnsureStoreOpened(isServer) is X509Store store) { X509Certificate2Collection certs = store.Certificates; - X509Certificate2Collection found = certs.Find(X509FindType.FindByThumbprint, certHash, false); + X509Certificate2Collection found = certs.FindByThumbprint(HashAlgorithmName.SHA512, certHash); X509Certificate2? cert = null; try { @@ -247,19 +251,14 @@ internal void CloseContext() } finally { - for (int i = 0; i < found.Count; i++) + for (int i = 0; i < certs.Count; i++) { - X509Certificate2 toDispose = found[i]; + X509Certificate2 toDispose = certs[i]; if (!ReferenceEquals(toDispose, cert)) { toDispose.Dispose(); } } - - for (int i = 0; i < certs.Count; i++) - { - certs[i].Dispose(); - } } }