Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup | SNI Native Wrapper #3056

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ private static IntPtr UserInstanceDLLHandle
{
if (s_userInstanceDLLHandle == IntPtr.Zero)
{
SniNativeWrapper.SNIQueryInfo(QueryType.SNI_QUERY_LOCALDB_HMODULE, ref s_userInstanceDLLHandle);
SniNativeWrapper.SniQueryInfo(QueryType.SNI_QUERY_LOCALDB_HMODULE, ref s_userInstanceDLLHandle);
if (s_userInstanceDLLHandle != IntPtr.Zero)
{
SqlClientEventSource.Log.TryTraceEvent("LocalDBAPI.UserInstanceDLLHandle | LocalDB - handle obtained");
}
else
{
SniNativeWrapper.SNIGetLastError(out SniError sniError);
SniNativeWrapper.SniGetLastError(out SniError sniError);
throw CreateLocalDBException(errorMessage: StringsHelper.GetString("LocalDB_FailedGetDLLHandle"), sniError: sniError.sniError);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ private SNIErrorDetails GetSniErrorDetails()
}
else
{
SniNativeWrapper.SNIGetLastError(out SniError sniError);
SniNativeWrapper.SniGetLastError(out SniError sniError);
details.sniErrorNumber = sniError.sniError;
details.errorMessage = sniError.errorMessage;
details.nativeError = sniError.nativeError;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache
result = SniNativeWrapper.SniGetConnectionPort(Handle, ref portFromSNI);
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionPort");

result = SniNativeWrapper.SniGetConnectionIPString(Handle, ref IPStringFromSNI);
result = SniNativeWrapper.SniGetConnectionIpString(Handle, ref IPStringFromSNI);
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString");

pendingDNSInfo = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString());
Expand Down Expand Up @@ -187,7 +187,7 @@ internal override void CreatePhysicalSNIHandle(
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
{
Debug.Assert(packet.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
return SniNativeWrapper.SNIPacketGetData(packet.NativePointer, _inBuff, ref dataSize);
return SniNativeWrapper.SniPacketGetData(packet.NativePointer, _inBuff, ref dataSize);
}

protected override bool CheckPacket(PacketHandle packet, TaskCompletionSource<object> source)
Expand Down Expand Up @@ -267,7 +267,7 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint
throw ADP.ClosedConnectionError();
}
IntPtr readPacketPtr = IntPtr.Zero;
error = SniNativeWrapper.SNIReadSyncOverAsync(handle, ref readPacketPtr, GetTimeoutRemaining());
error = SniNativeWrapper.SniReadSyncOverAsync(handle, ref readPacketPtr, GetTimeoutRemaining());
return PacketHandle.FromNativePointer(readPacketPtr);
}

Expand All @@ -284,20 +284,20 @@ internal override bool IsPacketEmpty(PacketHandle readPacket)
internal override void ReleasePacket(PacketHandle syncReadPacket)
{
Debug.Assert(syncReadPacket.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
SniNativeWrapper.SNIPacketRelease(syncReadPacket.NativePointer);
SniNativeWrapper.SniPacketRelease(syncReadPacket.NativePointer);
}

internal override uint CheckConnection()
{
SNIHandle handle = Handle;
return handle == null ? TdsEnums.SNI_SUCCESS : SniNativeWrapper.SNICheckConnection(handle);
return handle == null ? TdsEnums.SNI_SUCCESS : SniNativeWrapper.SniCheckConnection(handle);
}

internal override PacketHandle ReadAsync(SessionHandle handle, out uint error)
{
Debug.Assert(handle.Type == SessionHandle.NativeHandleType, "unexpected handle type when requiring NativePointer");
IntPtr readPacketPtr = IntPtr.Zero;
error = SniNativeWrapper.SNIReadAsync(handle.NativeHandle, ref readPacketPtr);
error = SniNativeWrapper.SniReadAsync(handle.NativeHandle, ref readPacketPtr);
return PacketHandle.FromNativePointer(readPacketPtr);
}

Expand All @@ -313,7 +313,7 @@ internal override PacketHandle CreateAndSetAttentionPacket()
internal override uint WritePacket(PacketHandle packet, bool sync)
{
Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket");
return SniNativeWrapper.SNIWritePacket(Handle, packet.NativePacket, sync);
return SniNativeWrapper.SniWritePacket(Handle, packet.NativePacket, sync);
}

internal override PacketHandle AddPacketToPendingList(PacketHandle packetToAdd)
Expand Down Expand Up @@ -346,7 +346,7 @@ internal override PacketHandle GetResetWritePacket(int dataSize)
{
if (_sniPacket != null)
{
SniNativeWrapper.SNIPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI);
SniNativeWrapper.SniPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI);
}
else
{
Expand Down Expand Up @@ -375,17 +375,17 @@ internal override void ClearAllWritePackets()
internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed)
{
Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket");
SniNativeWrapper.SNIPacketSetData(packet.NativePacket, buffer, bytesUsed);
SniNativeWrapper.SniPacketSetData(packet.NativePacket, buffer, bytesUsed);
}

internal override uint SniGetConnectionId(ref Guid clientConnectionId)
=> SniNativeWrapper.SniGetConnectionId(Handle, ref clientConnectionId);

internal override uint DisableSsl()
=> SniNativeWrapper.SNIRemoveProvider(Handle, Provider.SSL_PROV);
=> SniNativeWrapper.SniRemoveProvider(Handle, Provider.SSL_PROV);

internal override uint EnableMars(ref uint info)
=> SniNativeWrapper.SNIAddProvider(Handle, Provider.SMUX_PROV, ref info);
=> SniNativeWrapper.SniAddProvider(Handle, Provider.SMUX_PROV, ref info);

internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCertificateFilename)
{
Expand All @@ -395,15 +395,15 @@ internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCert
authInfo.serverCertFileName = serverCertificateFilename;

// Add SSL (Encryption) SNI provider.
return SniNativeWrapper.SNIAddProvider(Handle, Provider.SSL_PROV, ref authInfo);
return SniNativeWrapper.SniAddProvider(Handle, Provider.SSL_PROV, ref authInfo);
}

internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
=> SniNativeWrapper.SNISetInfo(Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);
=> SniNativeWrapper.SniSetInfo(Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);

internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
{
uint returnValue = SniNativeWrapper.SNIWaitForSSLHandshakeToComplete(Handle, GetTimeoutRemaining(), out uint nativeProtocolVersion);
uint returnValue = SniNativeWrapper.SniWaitForSslHandshakeToComplete(Handle, GetTimeoutRemaining(), out uint nativeProtocolVersion);
var nativeProtocol = (NativeProtocols)nativeProtocolVersion;

#pragma warning disable CA5398 // Avoid hardcoded SslProtocols values
Expand Down Expand Up @@ -472,7 +472,7 @@ public SNIPacket Take(SNIHandle sniHandle)
{
// Success - reset the packet
packet = _packets.Pop();
SniNativeWrapper.SNIPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI);
SniNativeWrapper.SniPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ static IntPtr UserInstanceDLLHandle
Monitor.Enter(s_dllLock, ref lockTaken);
if (s_userInstanceDLLHandle == IntPtr.Zero)
{
SniNativeWrapper.SNIQueryInfo(QueryType.SNI_QUERY_LOCALDB_HMODULE, ref s_userInstanceDLLHandle);
SniNativeWrapper.SniQueryInfo(QueryType.SNI_QUERY_LOCALDB_HMODULE, ref s_userInstanceDLLHandle);
if (s_userInstanceDLLHandle != IntPtr.Zero)
{
SqlClientEventSource.Log.TryTraceEvent("<sc.LocalDBAPI.UserInstanceDLLHandle> LocalDB - handle obtained");
}
else
{
SniError sniError = new SniError();
SniNativeWrapper.SNIGetLastError(out sniError);
SniNativeWrapper.SniGetLastError(out sniError);
throw CreateLocalDBException(errorMessage: StringsHelper.GetString("LocalDB_FailedGetDLLHandle"), sniError: sniError.sniError);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ internal void RemoveEncryption()
uint error = 0;

// Remove SSL (Encryption) SNI provider since we only wanted to encrypt login.
error = SniNativeWrapper.SNIRemoveProvider(_physicalStateObj.Handle, Provider.SSL_PROV);
error = SniNativeWrapper.SniRemoveProvider(_physicalStateObj.Handle, Provider.SSL_PROV);
if (error != TdsEnums.SNI_SUCCESS)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand All @@ -727,7 +727,7 @@ internal void EnableMars()
uint info = 0;

// Add SMUX (MARS) SNI provider.
error = SniNativeWrapper.SNIAddProvider(_pMarsPhysicalConObj.Handle, Provider.SMUX_PROV, ref info);
error = SniNativeWrapper.SniAddProvider(_pMarsPhysicalConObj.Handle, Provider.SMUX_PROV, ref info);

if (error != TdsEnums.SNI_SUCCESS)
{
Expand All @@ -748,12 +748,12 @@ internal void EnableMars()
{
_pMarsPhysicalConObj.IncrementPendingCallbacks();

error = SniNativeWrapper.SNIReadAsync(_pMarsPhysicalConObj.Handle, ref temp);
error = SniNativeWrapper.SniReadAsync(_pMarsPhysicalConObj.Handle, ref temp);

if (temp != IntPtr.Zero)
{
// Be sure to release packet, otherwise it will be leaked by native.
SniNativeWrapper.SNIPacketRelease(temp);
SniNativeWrapper.SniPacketRelease(temp);
}
}
Debug.Assert(IntPtr.Zero == temp, "unexpected syncReadPacket without corresponding SNIPacketRelease");
Expand Down Expand Up @@ -1026,7 +1026,7 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ

Debug.Assert((_encryptionOption & EncryptionOptions.CLIENT_CERT) == 0, "Client certificate authentication support has been removed");

error = SniNativeWrapper.SNIAddProvider(_physicalStateObj.Handle, Provider.SSL_PROV, authInfo);
error = SniNativeWrapper.SniAddProvider(_physicalStateObj.Handle, Provider.SSL_PROV, authInfo);

if (error != TdsEnums.SNI_SUCCESS)
{
Expand All @@ -1038,7 +1038,7 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ
// wait for SSL handshake to complete, so that the SSL context is fully negotiated before we try to use its
// Channel Bindings as part of the Windows Authentication context build (SSL handshake must complete
// before calling SNISecGenClientContext).
error = SniNativeWrapper.SNIWaitForSSLHandshakeToComplete(_physicalStateObj.Handle, _physicalStateObj.GetTimeoutRemaining(), out uint protocolVersion);
error = SniNativeWrapper.SniWaitForSslHandshakeToComplete(_physicalStateObj.Handle, _physicalStateObj.GetTimeoutRemaining(), out uint protocolVersion);

if (error != TdsEnums.SNI_SUCCESS)
{
Expand Down Expand Up @@ -1592,7 +1592,7 @@ internal SqlError ProcessSNIError(TdsParserStateObject stateObj)
Debug.Assert(SniContext.Undefined != stateObj.DebugOnlyCopyOfSniContext || ((_fMARS) && ((_state == TdsParserState.Closed) || (_state == TdsParserState.Broken))), "SniContext must not be None");
#endif
SniError sniError = new SniError();
SniNativeWrapper.SNIGetLastError(out sniError);
SniNativeWrapper.SniGetLastError(out sniError);

if (sniError.sniError != 0)
{
Expand Down Expand Up @@ -2915,7 +2915,7 @@ private TdsOperationStatus TryProcessEnvChange(int tokenLength, TdsParserStateOb

// Update SNI ConsumerInfo value to be resulting packet size
uint unsignedPacketSize = (uint)packetSize;
uint bufferSizeResult = SniNativeWrapper.SNISetInfo(_physicalStateObj.Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);
uint bufferSizeResult = SniNativeWrapper.SniSetInfo(_physicalStateObj.Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);

Debug.Assert(bufferSizeResult == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SNISetInfo");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ internal void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey)
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionPort");


result = SniNativeWrapper.SniGetConnectionIPString(_physicalStateObj.Handle, ref IPStringFromSNI);
result = SniNativeWrapper.SniGetConnectionIpString(_physicalStateObj.Handle, ref IPStringFromSNI);
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString");

_connHandler.pendingSQLDNSObject = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,20 +285,20 @@ internal PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error)
{
SNIHandle handle = Handle ?? throw ADP.ClosedConnectionError();
PacketHandle readPacket = default;
error = SniNativeWrapper.SNIReadSyncOverAsync(handle, ref readPacket, timeoutRemaining);
error = SniNativeWrapper.SniReadSyncOverAsync(handle, ref readPacket, timeoutRemaining);
return readPacket;
}

internal PacketHandle ReadAsync(SessionHandle handle, out uint error)
{
PacketHandle readPacket = default;
error = SniNativeWrapper.SNIReadAsync(handle.NativeHandle, ref readPacket);
error = SniNativeWrapper.SniReadAsync(handle.NativeHandle, ref readPacket);
return readPacket;
}

internal uint CheckConnection() => SniNativeWrapper.SNICheckConnection(Handle);
internal uint CheckConnection() => SniNativeWrapper.SniCheckConnection(Handle);

internal void ReleasePacket(PacketHandle syncReadPacket) => SniNativeWrapper.SNIPacketRelease(syncReadPacket);
internal void ReleasePacket(PacketHandle syncReadPacket) => SniNativeWrapper.SniPacketRelease(syncReadPacket);

[ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
internal int DecrementPendingCallbacks(bool release)
Expand Down Expand Up @@ -416,7 +416,7 @@ internal bool ValidateSNIConnection()
SNIHandle handle = Handle;
if (handle != null)
{
error = SniNativeWrapper.SNICheckConnection(handle);
error = SniNativeWrapper.SniCheckConnection(handle);
}
}
finally
Expand Down Expand Up @@ -543,7 +543,7 @@ public void ProcessSniPacket(PacketHandle packet, uint error)
{
uint dataSize = 0;

uint getDataError = SniNativeWrapper.SNIPacketGetData(packet, _inBuff, ref dataSize);
uint getDataError = SniNativeWrapper.SniPacketGetData(packet, _inBuff, ref dataSize);

if (getDataError == TdsEnums.SNI_SUCCESS)
{
Expand Down Expand Up @@ -1169,7 +1169,7 @@ private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out uint sniErro
}
finally
{
sniError = SniNativeWrapper.SNIWritePacket(handle, packet, sync);
sniError = SniNativeWrapper.SniWritePacket(handle, packet, sync);
}

if (sniError == TdsEnums.SNI_SUCCESS_IO_PENDING)
Expand Down Expand Up @@ -1281,7 +1281,7 @@ internal void SendAttention(bool mustTakeWriteLock = false, bool asyncClose = fa
SNIPacket attnPacket = new SNIPacket(Handle);
_sniAsyncAttnPacket = attnPacket;

SniNativeWrapper.SNIPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN, null, null);
SniNativeWrapper.SniPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN, null, null);

RuntimeHelpers.PrepareConstrainedRegions();
try
Expand Down Expand Up @@ -1345,7 +1345,7 @@ private Task WriteSni(bool canAccumulate)
{
// Prepare packet, and write to packet.
SNIPacket packet = GetResetWritePacket();
SniNativeWrapper.SNIPacketSetData(packet, _outBuff, _outBytesUsed, _securePasswords, _securePasswordOffsetsInBuffer);
SniNativeWrapper.SniPacketSetData(packet, _outBuff, _outBytesUsed, _securePasswords, _securePasswordOffsetsInBuffer);

Debug.Assert(Parser.Connection._parserLock.ThreadMayHaveLock(), "Thread is writing without taking the connection lock");
Task task = SNIWritePacket(Handle, packet, out _, canAccumulate, callerHasConnectionLock: true);
Expand Down Expand Up @@ -1400,7 +1400,7 @@ internal SNIPacket GetResetWritePacket()
{
if (_sniPacket != null)
{
SniNativeWrapper.SNIPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI);
SniNativeWrapper.SniPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI);
}
else
{
Expand Down
Loading
Loading