Skip to content

Commit

Permalink
Fix race condition inproduct check (#8296) (#8297)
Browse files Browse the repository at this point in the history
Co-authored-by: Florian Bernd <[email protected]>
  • Loading branch information
github-actions[bot] and flobernd authored Aug 13, 2024
1 parent 4b6c8e4 commit 6e0cc6c
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 419 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,31 @@
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Elastic.Clients.Elasticsearch.Serverless.Requests;
using System.Threading;
using Elastic.Transport;
using Elastic.Transport.Diagnostics;
using Elastic.Transport.Products.Elasticsearch;

#if ELASTICSEARCH_SERVERLESS
using Elastic.Clients.Elasticsearch.Serverless.Requests;
#else
using Elastic.Clients.Elasticsearch.Requests;
#endif

#if ELASTICSEARCH_SERVERLESS
namespace Elastic.Clients.Elasticsearch.Serverless;
#else

namespace Elastic.Clients.Elasticsearch;
#endif

/// <summary>
/// A strongly-typed client for communicating with Elasticsearch server endpoints.
/// </summary>
public partial class ElasticsearchClient
{
private const string OpenTelemetrySpanAttributePrefix = "db.elasticsearch.";

// This should be updated if any of the code uses semantic conventions defined in newer schema versions.
private const string OpenTelemetrySchemaVersion = "https://opentelemetry.io/schemas/1.21.0";

Expand Down Expand Up @@ -82,13 +92,14 @@ internal ElasticsearchClient(ITransport<IElasticsearchClientSettings> transport)
public Serializer SourceSerializer => _transport.Configuration.SourceSerializer;
public ITransport<IElasticsearchClientSettings> Transport => _transport;

private ProductCheckStatus _productCheckStatus;
private int _productCheckStatus;

private enum ProductCheckStatus
{
NotChecked,
Succeeded,
Failed
NotChecked = 0,
InProgress = 1,
Succeeded = 2,
Failed = 3
}

private partial void SetupNamespaces();
Expand Down Expand Up @@ -133,48 +144,115 @@ private ValueTask<TResponse> DoRequestCoreAsync<TRequest, TResponse, TRequestPar
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new()
{
if (_productCheckStatus == ProductCheckStatus.Failed)
throw new UnsupportedProductException(UnsupportedProductException.InvalidProductError);
// The product check modifies request parameters and therefore must not be executed concurrently.
// We use a lockless CAS approach to make sure that only a single product check request is executed at a time.
// We do not guarantee that the product check is always performed on the first request.

var (requestModified, hadRequestConfig, originalHeaders) = AttachProductCheckHeaderIfRequired<TRequest, TRequestParameters>(request);
var (resolvedUrl, urlTemplate, resolvedRouteValues, postData) = PrepareRequest<TRequest, TRequestParameters>(request, forceConfiguration);
var openTelemetryData = PrepareOpenTelemetryData<TRequest, TRequestParameters>(request, resolvedRouteValues);
var productCheckStatus = Interlocked.CompareExchange(
ref _productCheckStatus,
(int)ProductCheckStatus.InProgress,
(int)ProductCheckStatus.NotChecked
);

if (_productCheckStatus == ProductCheckStatus.Succeeded && !requestModified)
return productCheckStatus switch
{
if (isAsync)
return new ValueTask<TResponse>(_transport.RequestAsync<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData, cancellationToken));
else
return new ValueTask<TResponse>(_transport.Request<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData));
(int)ProductCheckStatus.NotChecked => SendRequestWithProductCheck(),
(int)ProductCheckStatus.InProgress or
(int)ProductCheckStatus.Succeeded => SendRequest(),
(int)ProductCheckStatus.Failed => throw new UnsupportedProductException(UnsupportedProductException.InvalidProductError),
_ => throw new InvalidOperationException("unreachable")
};

ValueTask<TResponse> SendRequest()
{
var (resolvedUrl, _, resolvedRouteValues, postData) = PrepareRequest<TRequest, TRequestParameters>(request, forceConfiguration);
var openTelemetryData = PrepareOpenTelemetryData<TRequest, TRequestParameters>(request, resolvedRouteValues);

return isAsync
? new ValueTask<TResponse>(_transport
.RequestAsync<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData, cancellationToken))
: new ValueTask<TResponse>(_transport
.Request<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData));
}

return SendRequest(isAsync);
async ValueTask<TResponse> SendRequestWithProductCheck()
{
try
{
return await SendRequestWithProductCheckCore().ConfigureAwait(false);
}
catch
{
// Re-try product check on next request.

// 32-bit read/write operations are atomic and due to the initial memory barrier, we can be sure that
// no other thread executes the product check at the same time. Locked access is not required here.
if (_productCheckStatus is (int)ProductCheckStatus.InProgress)
_productCheckStatus = (int)ProductCheckStatus.NotChecked;

throw;
}
}

async ValueTask<TResponse> SendRequest(bool isAsync)
async ValueTask<TResponse> SendRequestWithProductCheckCore()
{
// Attach product check header

var hadRequestConfig = false;
HeadersList? originalHeaders = null;

if (request.RequestParameters.RequestConfiguration is null)
request.RequestParameters.RequestConfiguration = new RequestConfiguration();
else
{
originalHeaders = request.RequestParameters.RequestConfiguration.ResponseHeadersToParse;
hadRequestConfig = true;
}

request.RequestParameters.RequestConfiguration.ResponseHeadersToParse = request.RequestParameters.RequestConfiguration.ResponseHeadersToParse.Count == 0
? new HeadersList("x-elastic-product")
: new HeadersList(request.RequestParameters.RequestConfiguration.ResponseHeadersToParse, "x-elastic-product");

// Send request

var (resolvedUrl, _, resolvedRouteValues, postData) = PrepareRequest<TRequest, TRequestParameters>(request, forceConfiguration);
var openTelemetryData = PrepareOpenTelemetryData<TRequest, TRequestParameters>(request, resolvedRouteValues);

TResponse response;

if (isAsync)
response = await _transport.RequestAsync<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData, cancellationToken).ConfigureAwait(false);
{
response = await _transport
.RequestAsync<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData, cancellationToken)
.ConfigureAwait(false);
}
else
response = _transport.Request<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData);
{
response = _transport
.Request<TResponse>(request.HttpMethod, resolvedUrl, postData, request.RequestParameters, in openTelemetryData);
}

// Evaluate product check result

var productCheckSucceeded = response.ApiCallDetails.TryGetHeader("x-elastic-product", out var values) &&
values.FirstOrDefault(x => x.Equals("Elasticsearch", StringComparison.Ordinal)) is not null;

PostRequestProductCheck<TRequest, TResponse>(request, response);
_productCheckStatus = productCheckSucceeded
? (int)ProductCheckStatus.Succeeded
: (int)ProductCheckStatus.Failed;

if (_productCheckStatus == ProductCheckStatus.Failed)
if (_productCheckStatus == (int)ProductCheckStatus.Failed)
throw new UnsupportedProductException(UnsupportedProductException.InvalidProductError);

if (request.RequestParameters.RequestConfiguration is not null)
{
if (!hadRequestConfig)
{
request.RequestParameters.RequestConfiguration = null;
}
else if (originalHeaders.HasValue && originalHeaders.Value.Count > 0)
{
request.RequestParameters.RequestConfiguration.ResponseHeadersToParse = originalHeaders.Value;
}
}
if (request.RequestParameters.RequestConfiguration is null)
return response;

// Reset request configuration

if (!hadRequestConfig)
request.RequestParameters.RequestConfiguration = null;
else if (originalHeaders is { Count: > 0 })
request.RequestParameters.RequestConfiguration.ResponseHeadersToParse = originalHeaders.Value;

return response;
}
Expand Down Expand Up @@ -215,42 +293,6 @@ private static OpenTelemetryData PrepareOpenTelemetryData<TRequest, TRequestPara
return openTelemetryData;
}

private (bool requestModified, bool hadRequestConfig, HeadersList? originalHeaders) AttachProductCheckHeaderIfRequired<TRequest, TRequestParameters>(TRequest request)
where TRequest : Request<TRequestParameters>
where TRequestParameters : RequestParameters, new()
{
var requestModified = false;
var hadRequestConfig = false;
HeadersList? originalHeaders = null;

// If we have not yet checked the product name, add the product header to the list of headers to parse.
if (_productCheckStatus == ProductCheckStatus.NotChecked)
{
requestModified = true;

if (request.RequestParameters.RequestConfiguration is null)
{
request.RequestParameters.RequestConfiguration = new RequestConfiguration();
}
else
{
originalHeaders = request.RequestParameters.RequestConfiguration.ResponseHeadersToParse;
hadRequestConfig = true;
}

if (request.RequestParameters.RequestConfiguration.ResponseHeadersToParse.Count == 0)
{
request.RequestParameters.RequestConfiguration.ResponseHeadersToParse = new HeadersList("x-elastic-product");
}
else
{
request.RequestParameters.RequestConfiguration.ResponseHeadersToParse = new HeadersList(request.RequestParameters.RequestConfiguration.ResponseHeadersToParse, "x-elastic-product");
}
}

return (requestModified, hadRequestConfig, originalHeaders);
}

private (string resolvedUrl, string urlTemplate, Dictionary<string, string>? resolvedRouteValues, PostData data) PrepareRequest<TRequest, TRequestParameters>(TRequest request,
Action<IRequestConfiguration>? forceConfiguration)
where TRequest : Request<TRequestParameters>
Expand Down Expand Up @@ -278,21 +320,6 @@ private static OpenTelemetryData PrepareOpenTelemetryData<TRequest, TRequestPara
return (resolvedUrl, urlTemplate, routeValues, postData);
}

private void PostRequestProductCheck<TRequest, TResponse>(TRequest request, TResponse response)
where TRequest : Request
where TResponse : TransportResponse, new()
{
if (response.ApiCallDetails.HttpStatusCode.HasValue && response.ApiCallDetails.HttpStatusCode.Value >= 200 && response.ApiCallDetails.HttpStatusCode.Value <= 299 && _productCheckStatus == ProductCheckStatus.NotChecked)
{
if (!response.ApiCallDetails.TryGetHeader("x-elastic-product", out var values) || !values.Single().Equals("Elasticsearch", StringComparison.Ordinal))
{
_productCheckStatus = ProductCheckStatus.Failed;
}

_productCheckStatus = ProductCheckStatus.Succeeded;
}
}

private static void ForceConfiguration<TRequestParameters>(Request<TRequestParameters> request, Action<IRequestConfiguration> forceConfiguration)
where TRequestParameters : RequestParameters, new()
{
Expand Down
Loading

0 comments on commit 6e0cc6c

Please sign in to comment.