Skip to content

Ensure correct SPN when calling SspiContextProvider #3347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ internal sealed partial class TdsParser

private bool _is2022 = false;

private string[] _serverSpn = null;

// SqlStatistics
private SqlStatistics _statistics = null;

Expand Down Expand Up @@ -395,7 +393,6 @@ internal void Connect(ServerInfo serverInfo,
}
else
{
_serverSpn = null;
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler.ObjectID,
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
}
Expand All @@ -407,7 +404,6 @@ internal void Connect(ServerInfo serverInfo,
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
}

_serverSpn = null;
_authenticationProvider = null;

// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
Expand Down Expand Up @@ -441,12 +437,14 @@ internal void Connect(ServerInfo serverInfo,

_connHandler.pendingSQLDNSObject = null;

string[] serverSpn = null;

// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
_physicalStateObj.CreatePhysicalSNIHandle(
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _serverSpn,
ref serverSpn,
false,
true,
fParallel,
Expand All @@ -459,8 +457,6 @@ internal void Connect(ServerInfo serverInfo,
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand Down Expand Up @@ -546,7 +542,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _serverSpn,
ref serverSpn,
true,
true,
fParallel,
Expand All @@ -559,8 +555,6 @@ internal void Connect(ServerInfo serverInfo,
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand Down Expand Up @@ -599,6 +593,8 @@ internal void Connect(ServerInfo serverInfo,
}
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Prelogin handshake successful");

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, serverSpn);

if (_fMARS && marsCapable)
{
// if user explicitly disables mars or mars not supported, don't create the session pool
Expand Down Expand Up @@ -744,7 +740,7 @@ private void SendPreLoginHandshake(

// UNDONE - need to do some length verification to ensure packet does not
// get too big!!! Not beyond it's max length!

for (int option = (int)PreLoginOptions.VERSION; option < (int)PreLoginOptions.NUMOPT; option++)
{
int optionDataSize = 0;
Expand Down Expand Up @@ -935,7 +931,7 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake(
string serverCertificateFilename)
{
// Assign default values
marsCapable = _fMARS;
marsCapable = _fMARS;
fedAuthRequired = false;
Debug.Assert(_physicalStateObj._syncOverAsync, "Should not attempt pends in a synchronous call");
TdsOperationStatus result = _physicalStateObj.TryReadNetworkPacket();
Expand Down Expand Up @@ -2181,7 +2177,7 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle
dataStream.BrowseModeInfoConsumed = true;
}
else
{
{
// no dataStream
result = stateObj.TrySkipBytes(tokenLength);
if (result != TdsOperationStatus.Done)
Expand All @@ -2195,7 +2191,7 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle
case TdsEnums.SQLDONE:
case TdsEnums.SQLDONEPROC:
case TdsEnums.SQLDONEINPROC:
{
{
// RunBehavior can be modified - see SQL BU DT 269516 & 290090
result = TryProcessDone(cmdHandler, dataStream, ref runBehavior, stateObj);
if (result != TdsOperationStatus.Done)
Expand Down Expand Up @@ -4122,7 +4118,7 @@ internal TdsOperationStatus TryProcessReturnValue(int length,
{
return result;
}

byte len;
result = stateObj.TryReadByte(out len);
if (result != TdsOperationStatus.Done)
Expand Down Expand Up @@ -4321,7 +4317,7 @@ internal TdsOperationStatus TryProcessReturnValue(int length,
{
return result;
}

if (rec.collation.IsUTF8)
{ // UTF8 collation
rec.encoding = Encoding.UTF8;
Expand Down Expand Up @@ -4776,13 +4772,13 @@ internal TdsOperationStatus TryProcessAltMetaData(int cColumns, TdsParserStateOb
{
// internal meta data class
_SqlMetaData col = altMetaDataSet[i];

result = stateObj.TryReadByte(out _);
if (result != TdsOperationStatus.Done)
{
return result;
}

result = stateObj.TryReadUInt16(out _);
if (result != TdsOperationStatus.Done)
{
Expand Down Expand Up @@ -5466,7 +5462,7 @@ private TdsOperationStatus TryProcessColInfo(_SqlMetaDataSet columns, SqlDataRea
for (int i = 0; i < columns.Length; i++)
{
_SqlMetaData col = columns[i];

TdsOperationStatus result = stateObj.TryReadByte(out _);
if (result != TdsOperationStatus.Done)
{
Expand Down Expand Up @@ -7386,7 +7382,7 @@ private byte[] SerializeSqlMoney(SqlMoney value, int length, TdsParserStateObjec

private void WriteSqlMoney(SqlMoney value, int length, TdsParserStateObject stateObj)
{
// UNDONE: can I use SqlMoney.ToInt64()?
// UNDONE: can I use SqlMoney.ToInt64()?
int[] bits = decimal.GetBits(value.Value);

// this decimal should be scaled by 10000 (regardless of what the incoming decimal was scaled by)
Expand Down Expand Up @@ -9906,7 +9902,7 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet

WriteUDTMetaData(value, names[0], names[1], names[2], stateObj);

// UNDONE - re-org to use code below to write value!
// UNDONE - re-org to use code below to write value!
if (!isNull)
{
WriteUnsignedLong((ulong)udtVal.Length, stateObj); // PLP length
Expand Down Expand Up @@ -12340,7 +12336,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int
case TdsEnums.SQLNVARCHAR:
case TdsEnums.SQLNTEXT:
case TdsEnums.SQLXMLTYPE:
case TdsEnums.SQLJSON:
case TdsEnums.SQLJSON:
{
Debug.Assert(!isDataFeed || (value is TextDataFeed || value is XmlDataFeed), "Value must be a TextReader or XmlReader");
Debug.Assert(isDataFeed || (value is string || value is byte[]), "Value is a byte array or string");
Expand Down Expand Up @@ -13556,15 +13552,14 @@ private TdsOperationStatus TryProcessUDTMetaData(SqlMetaDataPriv metaData, TdsPa
+ " _connHandler = {14}\n\t"
+ " _fMARS = {15}\n\t"
+ " _sessionPool = {16}\n\t"
+ " _sniSpnBuffer = {17}\n\t"
+ " _errors = {18}\n\t"
+ " _warnings = {19}\n\t"
+ " _attentionErrors = {20}\n\t"
+ " _attentionWarnings = {21}\n\t"
+ " _statistics = {22}\n\t"
+ " _statisticsIsInTransaction = {23}\n\t"
+ " _fPreserveTransaction = {24}"
+ " _fParallel = {25}"
+ " _errors = {17}\n\t"
+ " _warnings = {18}\n\t"
+ " _attentionErrors = {19}\n\t"
+ " _attentionWarnings = {20}\n\t"
+ " _statistics = {21}\n\t"
+ " _statisticsIsInTransaction = {22}\n\t"
+ " _fPreserveTransaction = {23}"
+ " _fParallel = {24}"
;
internal string TraceString()
{
Expand All @@ -13587,7 +13582,6 @@ internal string TraceString()
_connHandler == null ? "(null)" : _connHandler.ObjectID.ToString((IFormatProvider)null),
_fMARS ? bool.TrueString : bool.FalseString,
_sessionPool == null ? "(null)" : _sessionPool.TraceString(),
_serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ internal sealed partial class TdsParser

private bool _is2022 = false;

private string _serverSpn = null;

// SqlStatistics
private SqlStatistics _statistics = null;

Expand Down Expand Up @@ -396,6 +394,8 @@ internal void Connect(ServerInfo serverInfo,
Debug.Fail("SNI returned status != success, but no error thrown?");
}

string serverSpn = null;

//Create LocalDB instance if necessary
if (connHandler.ConnectionOptions.LocalDBInstance != null)
{
Expand All @@ -415,21 +415,20 @@ internal void Connect(ServerInfo serverInfo,

if (!string.IsNullOrEmpty(serverInfo.ServerSPN))
{
_serverSpn = serverInfo.ServerSPN;
serverSpn = serverInfo.ServerSPN;
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Server SPN `{0}` from the connection string is used.", serverInfo.ServerSPN);
}
else
{
// Empty signifies to interop layer that SPN needs to be generated
_serverSpn = string.Empty;
serverSpn = string.Empty;
}

SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> SSPI or Active Directory Authentication Library for SQL Server based integrated authentication");
}
else
{
_authenticationProvider = null;
_serverSpn = null;

switch (authType)
{
Expand Down Expand Up @@ -508,7 +507,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _serverSpn,
ref serverSpn,
false,
true,
fParallel,
Expand All @@ -518,8 +517,6 @@ internal void Connect(ServerInfo serverInfo,
FQDNforDNSCache,
hostNameInCertificate);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand Down Expand Up @@ -602,7 +599,7 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _serverSpn,
ref serverSpn,
true,
true,
fParallel,
Expand All @@ -612,8 +609,6 @@ internal void Connect(ServerInfo serverInfo,
serverInfo.ResolvedServerName,
hostNameInCertificate);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand Down Expand Up @@ -648,6 +643,8 @@ internal void Connect(ServerInfo serverInfo,
}
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Prelogin handshake successful");

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, serverSpn);

if (_fMARS && marsCapable)
{
// if user explicitly disables mars or mars not supported, don't create the session pool
Expand Down Expand Up @@ -13669,15 +13666,14 @@ internal ulong PlpBytesTotalLength(TdsParserStateObject stateObj)
+ " _connHandler = {14}\n\t"
+ " _fMARS = {15}\n\t"
+ " _sessionPool = {16}\n\t"
+ " _sniSpnBuffer = {17}\n\t"
+ " _errors = {18}\n\t"
+ " _warnings = {19}\n\t"
+ " _attentionErrors = {20}\n\t"
+ " _attentionWarnings = {21}\n\t"
+ " _statistics = {22}\n\t"
+ " _statisticsIsInTransaction = {23}\n\t"
+ " _fPreserveTransaction = {24}"
+ " _fParallel = {25}"
+ " _errors = {17}\n\t"
+ " _warnings = {18}\n\t"
+ " _attentionErrors = {19}\n\t"
+ " _attentionWarnings = {20}\n\t"
+ " _statistics = {21}\n\t"
+ " _statisticsIsInTransaction = {22}\n\t"
+ " _fPreserveTransaction = {23}"
+ " _fParallel = {24}"
;
internal string TraceString()
{
Expand All @@ -13700,7 +13696,6 @@ internal string TraceString()
_connHandler == null ? "(null)" : _connHandler.ObjectID.ToString((IFormatProvider)null),
_fMARS ? bool.TrueString : bool.FalseString,
_sessionPool == null ? "(null)" : _sessionPool.TraceString(),
_serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Buffers;
using System.Diagnostics;
using System.Net.Security;

#nullable enable
Expand All @@ -10,13 +11,16 @@ namespace Microsoft.Data.SqlClient
{
internal sealed class NegotiateSspiContextProvider : SspiContextProvider
{
private NegotiateAuthentication? _negotiateAuth = null;
private NegotiateAuthentication? _negotiateAuth;

protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SspiAuthenticationParameters authParams)
{
NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials;

_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = authParams.Resource });

Debug.Assert(_negotiateAuth.TargetName == authParams.Resource, "SSPI resource does not match TargetName");

var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!;

// Log session id, status code and the actual SPN used in the negotiation
Expand All @@ -29,6 +33,8 @@ protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
return true;
}

// Reset _negotiateAuth to be generated again for next SPN.
_negotiateAuth = null;
return false;
}
}
Expand Down
Loading
Loading