Skip to content

Ensure correct SPN when calling SspiContextProvider #3347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,7 @@
</Compile>
<Compile Include="Microsoft\Data\Common\DbConnectionOptions.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\ConcurrentQueueSemaphore.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\ResolvedServerSpn.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIError.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIHandle.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable enable

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This is used to hold the ServerSpn for a given connection. Most connection types have a single format, although TCP connections may allow
/// with and without a port. Depending on how the SPN is registered on the server, either one may be the correct name.
/// </summary>
/// <see href="https://learn.microsoft.com/sql/database-engine/configure-windows/register-a-service-principal-name-for-kerberos-connections?view=sql-server-ver17#spn-formats"/>
/// <param name="primary"></param>
/// <param name="secondary"></param>
/// <remarks>
/// <para>SQL Server SPN format follows these patterns:</para>
/// <list type="bullet">
/// <item>
/// <term>Default instance, no port (primary):</term>
/// <description>MSSQLSvc/fully-qualified-domain-name</description>
/// </item>
/// <item>
/// <term>Default instance, default port (secondary):</term>
/// <description>MSSQLSvc/fully-qualified-domain-name:1433</description>
/// </item>
/// <item>
/// <term>Named instance or custom port:</term>
/// <description>MSSQLSvc/fully-qualified-domain-name:port_or_instance_name</description>
/// </item>
/// </list>
/// <para>For TCP connections to named instances, the port number is used in SPN.</para>
/// <para>For Named Pipe connections to named instances, the instance name is used in SPN.</para>
/// <para>When hostname resolution fails, the user-provided hostname is used instead of FQDN.</para>
/// <para>For default instances with TCP protocol, both forms (with and without port) may be returned.</para>
/// </remarks>
internal readonly struct ResolvedServerSpn(string primary, string? secondary = null)
{
public string Primary => primary;

public string? Secondary => secondary;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Text;
using Microsoft.Data.ProviderBase;
Expand All @@ -34,7 +31,7 @@ internal class SNIProxy
/// <param name="fullServerName">Full server name from connection string</param>
/// <param name="timeout">Timer expiration</param>
/// <param name="instanceName">Instance name</param>
/// <param name="spns">SPNs</param>
/// <param name="resolvedSpn">SPN</param>
/// <param name="serverSPN">pre-defined SPN</param>
/// <param name="flushCache">Flush packet cache</param>
/// <param name="async">Asynchronous connection</param>
Expand All @@ -51,7 +48,7 @@ internal static SNIHandle CreateConnectionHandle(
string fullServerName,
TimeoutTimer timeout,
out byte[] instanceName,
ref string[] spns,
out ResolvedServerSpn resolvedSpn,
string serverSPN,
bool flushCache,
bool async,
Expand All @@ -65,6 +62,7 @@ internal static SNIHandle CreateConnectionHandle(
string serverCertificateFilename)
{
instanceName = new byte[1];
resolvedSpn = default;

bool errorWithLocalDBProcessing;
string localDBDataSource = GetLocalDBDataSource(fullServerName, out errorWithLocalDBProcessing);
Expand Down Expand Up @@ -103,7 +101,7 @@ internal static SNIHandle CreateConnectionHandle(
{
try
{
spns = GetSqlServerSPNs(details, serverSPN);
resolvedSpn = GetSqlServerSPNs(details, serverSPN);
}
catch (Exception e)
{
Expand All @@ -115,12 +113,12 @@ internal static SNIHandle CreateConnectionHandle(
return sniHandle;
}

private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
private static ResolvedServerSpn GetSqlServerSPNs(DataSource dataSource, string serverSPN)
{
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
if (!string.IsNullOrWhiteSpace(serverSPN))
{
return new[] { serverSPN };
return new(serverSPN);
}

string hostName = dataSource.ServerName;
Expand All @@ -138,7 +136,7 @@ private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN
return GetSqlServerSPNs(hostName, postfix, dataSource.ResolvedProtocol);
}

private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
private static ResolvedServerSpn GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
{
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
IPHostEntry hostEntry = null;
Expand Down Expand Up @@ -169,12 +167,12 @@ private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOr
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
// Set both SPNs with and without Port as Port is optional for default instance
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
return new[] { serverSpn, serverSpnWithDefaultPort };
return new(serverSpn, serverSpnWithDefaultPort);
}
// else Named Pipes do not need to valid port

SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
return new[] { serverSpn };
return new(serverSpn);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ internal sealed partial class TdsParser

private bool _is2022 = false;

private string[] _serverSpn = null;

// SqlStatistics
private SqlStatistics _statistics = null;

Expand Down Expand Up @@ -395,7 +393,6 @@ internal void Connect(ServerInfo serverInfo,
}
else
{
_serverSpn = null;
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler.ObjectID,
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
}
Expand All @@ -407,7 +404,6 @@ internal void Connect(ServerInfo serverInfo,
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
}

_serverSpn = null;
_authenticationProvider = null;

// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
Expand Down Expand Up @@ -446,7 +442,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _serverSpn,
out var resolvedServerSpn,
false,
true,
fParallel,
Expand All @@ -459,8 +455,6 @@ internal void Connect(ServerInfo serverInfo,
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand Down Expand Up @@ -546,7 +540,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _serverSpn,
out resolvedServerSpn,
true,
true,
fParallel,
Expand All @@ -559,8 +553,6 @@ internal void Connect(ServerInfo serverInfo,
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand Down Expand Up @@ -599,6 +591,11 @@ internal void Connect(ServerInfo serverInfo,
}
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Prelogin handshake successful");

if (_authenticationProvider is { })
{
_authenticationProvider.Initialize(serverInfo, _physicalStateObj, this, resolvedServerSpn.Primary, resolvedServerSpn.Secondary);
}

if (_fMARS && marsCapable)
{
// if user explicitly disables mars or mars not supported, don't create the session pool
Expand Down Expand Up @@ -744,7 +741,7 @@ private void SendPreLoginHandshake(

// UNDONE - need to do some length verification to ensure packet does not
// get too big!!! Not beyond it's max length!

for (int option = (int)PreLoginOptions.VERSION; option < (int)PreLoginOptions.NUMOPT; option++)
{
int optionDataSize = 0;
Expand Down Expand Up @@ -935,7 +932,7 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake(
string serverCertificateFilename)
{
// Assign default values
marsCapable = _fMARS;
marsCapable = _fMARS;
fedAuthRequired = false;
Debug.Assert(_physicalStateObj._syncOverAsync, "Should not attempt pends in a synchronous call");
TdsOperationStatus result = _physicalStateObj.TryReadNetworkPacket();
Expand Down Expand Up @@ -2181,7 +2178,7 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle
dataStream.BrowseModeInfoConsumed = true;
}
else
{
{
// no dataStream
result = stateObj.TrySkipBytes(tokenLength);
if (result != TdsOperationStatus.Done)
Expand All @@ -2195,7 +2192,7 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle
case TdsEnums.SQLDONE:
case TdsEnums.SQLDONEPROC:
case TdsEnums.SQLDONEINPROC:
{
{
// RunBehavior can be modified - see SQL BU DT 269516 & 290090
result = TryProcessDone(cmdHandler, dataStream, ref runBehavior, stateObj);
if (result != TdsOperationStatus.Done)
Expand Down Expand Up @@ -4122,7 +4119,7 @@ internal TdsOperationStatus TryProcessReturnValue(int length,
{
return result;
}

byte len;
result = stateObj.TryReadByte(out len);
if (result != TdsOperationStatus.Done)
Expand Down Expand Up @@ -4321,7 +4318,7 @@ internal TdsOperationStatus TryProcessReturnValue(int length,
{
return result;
}

if (rec.collation.IsUTF8)
{ // UTF8 collation
rec.encoding = Encoding.UTF8;
Expand Down Expand Up @@ -4776,13 +4773,13 @@ internal TdsOperationStatus TryProcessAltMetaData(int cColumns, TdsParserStateOb
{
// internal meta data class
_SqlMetaData col = altMetaDataSet[i];

result = stateObj.TryReadByte(out _);
if (result != TdsOperationStatus.Done)
{
return result;
}

result = stateObj.TryReadUInt16(out _);
if (result != TdsOperationStatus.Done)
{
Expand Down Expand Up @@ -5466,7 +5463,7 @@ private TdsOperationStatus TryProcessColInfo(_SqlMetaDataSet columns, SqlDataRea
for (int i = 0; i < columns.Length; i++)
{
_SqlMetaData col = columns[i];

TdsOperationStatus result = stateObj.TryReadByte(out _);
if (result != TdsOperationStatus.Done)
{
Expand Down Expand Up @@ -7386,7 +7383,7 @@ private byte[] SerializeSqlMoney(SqlMoney value, int length, TdsParserStateObjec

private void WriteSqlMoney(SqlMoney value, int length, TdsParserStateObject stateObj)
{
// UNDONE: can I use SqlMoney.ToInt64()?
// UNDONE: can I use SqlMoney.ToInt64()?
int[] bits = decimal.GetBits(value.Value);

// this decimal should be scaled by 10000 (regardless of what the incoming decimal was scaled by)
Expand Down Expand Up @@ -9906,7 +9903,7 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet

WriteUDTMetaData(value, names[0], names[1], names[2], stateObj);

// UNDONE - re-org to use code below to write value!
// UNDONE - re-org to use code below to write value!
if (!isNull)
{
WriteUnsignedLong((ulong)udtVal.Length, stateObj); // PLP length
Expand Down Expand Up @@ -12340,7 +12337,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
case TdsEnums.SQLNVARCHAR:
case TdsEnums.SQLNTEXT:
case TdsEnums.SQLXMLTYPE:
case TdsEnums.SQLJSON:
case TdsEnums.SQLJSON:
{
Debug.Assert(!isDataFeed || (value is TextDataFeed || value is XmlDataFeed), "Value must be a TextReader or XmlReader");
Debug.Assert(isDataFeed || (value is string || value is byte[]), "Value is a byte array or string");
Expand Down Expand Up @@ -13556,15 +13553,14 @@ private TdsOperationStatus TryProcessUDTMetaData(SqlMetaDataPriv metaData, TdsPa
+ " _connHandler = {14}\n\t"
+ " _fMARS = {15}\n\t"
+ " _sessionPool = {16}\n\t"
+ " _sniSpnBuffer = {17}\n\t"
+ " _errors = {18}\n\t"
+ " _warnings = {19}\n\t"
+ " _attentionErrors = {20}\n\t"
+ " _attentionWarnings = {21}\n\t"
+ " _statistics = {22}\n\t"
+ " _statisticsIsInTransaction = {23}\n\t"
+ " _fPreserveTransaction = {24}"
+ " _fParallel = {25}"
+ " _errors = {17}\n\t"
+ " _warnings = {18}\n\t"
+ " _attentionErrors = {19}\n\t"
+ " _attentionWarnings = {20}\n\t"
+ " _statistics = {21}\n\t"
+ " _statisticsIsInTransaction = {22}\n\t"
+ " _fPreserveTransaction = {23}"
+ " _fParallel = {24}"
;
internal string TraceString()
{
Expand All @@ -13587,7 +13583,6 @@ internal string TraceString()
_connHandler == null ? "(null)" : _connHandler.ObjectID.ToString((IFormatProvider)null),
_fMARS ? bool.TrueString : bool.FalseString,
_sessionPool == null ? "(null)" : _sessionPool.TraceString(),
_serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Threading.Tasks;
using Microsoft.Data.Common;
using Microsoft.Data.ProviderBase;
using Microsoft.Data.SqlClient.SNI;

namespace Microsoft.Data.SqlClient
{
Expand Down Expand Up @@ -55,7 +56,7 @@ internal TdsParserStateObject(TdsParser parser, TdsParserStateObject physicalCon
AddError(parser.ProcessSNIError(this));
ThrowExceptionAndWarning();
}

// we post a callback that represents the call to dispose; once the
// object is disposed, the next callback will cause the GC Handle to
// be released.
Expand All @@ -71,7 +72,7 @@ internal abstract void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref string[] spns,
out ResolvedServerSpn resolvedSpn,
bool flushCache,
bool async,
bool fParallel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref string[] spns,
out ResolvedServerSpn resolvedSpn,
bool flushCache,
bool async,
bool parallel,
Expand All @@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle(
string hostNameInCertificate,
string serverCertificateFilename)
{
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spns, serverSPN,
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, out resolvedSpn, serverSPN,
flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst,
hostNameInCertificate, serverCertificateFilename);

Expand Down
Loading
Loading