Skip to content

Commit 92ad3c6

Browse files
Semantic Rerank: Adds Semantic Rerank API (#5445)
# Pull Request Template ## Description This pull request introduces a new semantic reranking feature to the Azure Cosmos DB .NET SDK, enabling users to rerank documents using an inference service that leverages Azure Active Directory (AAD) authentication. The main changes include the addition of the `InferenceService` class, new API surface for semantic reranking, and appropriate integration into the SDK's authorization and client context infrastructure. Notably, this functionality is only available when using AAD authentication. **Semantic Reranking Feature Integration:** * Added the `InferenceService` class, which handles communication with the Cosmos DB Inference Service for semantic reranking, including HTTP client configuration, payload construction, and response handling. This service enforces AAD authentication and manages its own authorization and disposal. * Introduced a new public (under `PREVIEW`) or internal API `SemanticRerankAsync` to the `Container` class, allowing users to rerank a list of documents based on a context/query string. This is implemented in `ContainerInlineCore` and routed through the client context. [[1]](diffhunk://#diff-e3b7704253edcfb63d18e851d9b0a8de3ea60889d704c363c2da1899a3ac39b7R1682-R1702) [[2]](diffhunk://#diff-d7119df75f749b2d2ebcadc708f02d419663febceeaab4147a898c4be777e33dR700-R708) * To use this feature, the environment variable "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT", must be set with the inference endpoint from the service. * Additionally, the environment variable "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_SERVICE_MAX_CONNECTION_LIMIT", can be set to change the inference client's max connection limit. **Authorization and Token Handling Updates:** * Extended the `AuthorizationTokenProvider` abstraction and its implementations to support a new method, `AddInferenceAuthorizationHeaderAsync`, which is only valid for AAD-based token providers. Non-AAD providers throw a `NotImplementedException` for this method. [[1]](diffhunk://#diff-839842554f21dd8c2180a3c4f5a25abffdc392902a0f10f1130397f064d84db5R55-R60) [[2]](diffhunk://#diff-93d0d6522a71c823da524b5cdceb024a8aae4497b14b383dcac41a98edf7092aR18-R29) [[3]](diffhunk://#diff-93d0d6522a71c823da524b5cdceb024a8aae4497b14b383dcac41a98edf7092aR78-R92) [[4]](diffhunk://#diff-1e0a2b64595baccd2b71f2328bd8f6ca9a81fb8175c5d489e5cdadba1b5c3a05R217-R221) [[5]](diffhunk://#diff-cca89b75cc8e6136d1f595f20169546258617f88cf0b5e0e345554c073e06e78R95-R99) [[6]](diffhunk://#diff-3ad037c5217ed55cda63aaa1a0e1897c9b1c6603b6e85b9649cdd8cf457ac6fbR128-R132) **Client Context and Resource Management:** * Updated `ClientContextCore` and `CosmosClientContext` to manage the lifecycle of the `InferenceService`, including creation, caching, and disposal. Added methods for invoking semantic reranking and for retrieving or creating the inference service instance. [[1]](diffhunk://#diff-a41317d6b53931e411d3091f58700758c53e80ca27dccedeafd76226852e58d2R38) [[2]](diffhunk://#diff-a41317d6b53931e411d3091f58700758c53e80ca27dccedeafd76226852e58d2R472-R497) [[3]](diffhunk://#diff-a41317d6b53931e411d3091f58700758c53e80ca27dccedeafd76226852e58d2R515) [[4]](diffhunk://#diff-b0bd965dfed52d866ec53af3bbb3b7afb4f420cdda75ac9071515ac471c0f7baR8) [[5]](diffhunk://#diff-b0bd965dfed52d866ec53af3bbb3b7afb4f420cdda75ac9071515ac471c0f7baR136-R148) [[6]](diffhunk://#diff-a41317d6b53931e411d3091f58700758c53e80ca27dccedeafd76226852e58d2R8) **Dependency Updates:** * Added a dependency on the `Azure.Identity` package in the test project to support AAD authentication scenarios. Please delete options that are not relevant. **Example** ```csharp //Sample code to demonstrate Semantic Reranking // Assume 'container' is an instance of Cosmos.Container // This example queries items from a fitness store with full-text search and then reranks them semantically. string search_text = "integrated pull-up bar"; string queryString = $@" SELECT TOP 15 c.id, c.Name, c.Brand, c.Description FROM c WHERE FullTextContains(c.Description, ""{search_text}"") ORDER BY RANK FullTextScore(c.Description, ""{search_text}"") "; string reranking_context = "most economical with multiple pulley adjustmnets and ideal for home gyms"; List<string> documents = new List<string>(); FeedIterator<dynamic> resultSetIterator = container.GetItemQueryIterator<dynamic>( new QueryDefinition(queryString), requestOptions: new QueryRequestOptions() { MaxItemCount = 15, }); while (resultSetIterator.HasMoreResults) { FeedResponse<dynamic> response = await resultSetIterator.ReadNextAsync(); foreach (JsonElement item in response) { documents.Add(item.ToString()); } } Dictionary<string, dynamic> options = new Dictionary<string, dynamic> { { "return_documents", true }, { "top_k", 10 }, { "batch_size", 32 }, { "sort", true } }; SemanticRerankResult results = await container.SemanticRerankAsync( reranking_context, documents, options); // get the best resulting document from the query results.RerankScores.First().Document; // or the index of the document in the original list results.RerankScores.First().Index; // or the reranking score results.RerankScores.First().Score; // get the latency information from the reranking operation Dictonary<string, object. latencyInfo = results.Latency; // get the token usage information from the reranking operation Dictonary<string, object> tokenUseageInfo = results.TokenUseage; ``` - [] New feature (non-breaking change which adds functionality) ## Closing issues To automatically close an issue: closes #IssueNumber
1 parent 4402de8 commit 92ad3c6

File tree

17 files changed

+834
-0
lines changed

17 files changed

+834
-0
lines changed

Microsoft.Azure.Cosmos.Encryption.Custom/src/EncryptionContainer.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,21 @@ public override Task<bool> IsFeedRangePartOfAsync(
10081008
}
10091009
#endif
10101010

1011+
#if PREVIEW && SDKPROJECTREF
1012+
public override Task<SemanticRerankResult> SemanticRerankAsync(
1013+
string rerankContext,
1014+
IEnumerable<string> documents,
1015+
IDictionary<string, object> options = null,
1016+
CancellationToken cancellationToken = default)
1017+
{
1018+
return this.container.SemanticRerankAsync(
1019+
rerankContext,
1020+
documents,
1021+
options,
1022+
cancellationToken);
1023+
}
1024+
#endif
1025+
10111026
private async Task<ResponseMessage> ReadManyItemsHelperAsync(
10121027
IReadOnlyList<(string id, PartitionKey partitionKey)> items,
10131028
ReadManyRequestOptions readManyRequestOptions = null,

Microsoft.Azure.Cosmos.Encryption/src/EncryptionContainer.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,21 @@ public override FeedIterator<T> GetItemQueryIterator<T>(
732732
}
733733

734734
#if ENCRYPTIONPREVIEW
735+
#if SDKPROJECTREF
736+
public override Task<SemanticRerankResult> SemanticRerankAsync(
737+
string rerankContext,
738+
IEnumerable<string> documents,
739+
IDictionary<string, object> options = null,
740+
CancellationToken cancellationToken = default)
741+
{
742+
return this.Container.SemanticRerankAsync(
743+
rerankContext,
744+
documents,
745+
options,
746+
cancellationToken);
747+
}
748+
749+
#endif
735750
public override async Task<ResponseMessage> DeleteAllItemsByPartitionKeyStreamAsync(
736751
Cosmos.PartitionKey partitionKey,
737752
RequestOptions requestOptions = null,

Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProvider.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ public abstract ValueTask<string> GetUserAuthorizationTokenAsync(
5252
AuthorizationTokenType tokenType,
5353
ITrace trace);
5454

55+
public abstract ValueTask AddInferenceAuthorizationHeaderAsync(
56+
INameValueCollection headersCollection,
57+
Uri requestAddress,
58+
string verb,
59+
AuthorizationTokenType tokenType);
60+
5561
public abstract void TraceUnauthorized(
5662
DocumentClientException dce,
5763
string authorizationToken,

Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderMasterKey.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,11 @@ private void Dispose(bool disposing)
214214
this.authKeyHashFunction = null;
215215
}
216216

217+
public override ValueTask AddInferenceAuthorizationHeaderAsync(INameValueCollection headersCollection, Uri requestAddress, string verb, AuthorizationTokenType tokenType)
218+
{
219+
throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD");
220+
}
221+
217222
// Use C# finalizer syntax for finalization code.
218223
// This finalizer will run only if the Dispose method does not get called.
219224
// It gives your base class the opportunity to finalize.

Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderResourceToken.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ private void Dispose(bool disposing)
9292
// Do nothing
9393
}
9494

95+
public override ValueTask AddInferenceAuthorizationHeaderAsync(INameValueCollection headersCollection, Uri requestAddress, string verb, AuthorizationTokenType tokenType)
96+
{
97+
throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD");
98+
}
99+
95100
// Use C# finalizer syntax for finalization code.
96101
// This finalizer will run only if the Dispose method does not get called.
97102
// It gives your base class the opportunity to finalize.

Microsoft.Azure.Cosmos/src/Authorization/AuthorizationTokenProviderTokenCredential.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@ namespace Microsoft.Azure.Cosmos
1515

1616
internal sealed class AuthorizationTokenProviderTokenCredential : AuthorizationTokenProvider
1717
{
18+
private const string InferenceTokenPrefix = "Bearer ";
1819
internal readonly TokenCredentialCache tokenCredentialCache;
1920
private bool isDisposed = false;
2021

22+
internal readonly TokenCredential tokenCredential;
23+
2124
public AuthorizationTokenProviderTokenCredential(
2225
TokenCredential tokenCredential,
2326
Uri accountEndpoint,
2427
TimeSpan? backgroundTokenCredentialRefreshInterval)
2528
{
29+
this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential));
2630
this.tokenCredentialCache = new TokenCredentialCache(
2731
tokenCredential: tokenCredential,
2832
accountEndpoint: accountEndpoint,
@@ -71,6 +75,21 @@ public override async ValueTask AddAuthorizationHeaderAsync(
7175
}
7276
}
7377

78+
public override async ValueTask AddInferenceAuthorizationHeaderAsync(
79+
INameValueCollection headersCollection,
80+
Uri requestAddress,
81+
string verb,
82+
AuthorizationTokenType tokenType)
83+
{
84+
using (Trace trace = Trace.GetRootTrace(nameof(GetUserAuthorizationTokenAsync), TraceComponent.Authorization, TraceLevel.Info))
85+
{
86+
string token = await this.tokenCredentialCache.GetTokenAsync(trace);
87+
88+
string inferenceToken = $"{InferenceTokenPrefix}{token}";
89+
headersCollection.Add(HttpConstants.HttpHeaders.Authorization, inferenceToken);
90+
}
91+
}
92+
7493
public override void TraceUnauthorized(
7594
DocumentClientException dce,
7695
string authorizationToken,

Microsoft.Azure.Cosmos/src/Authorization/AzureKeyCredentialAuthorizationTokenProvider.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,5 +125,10 @@ private void CheckAndRefreshTokenProvider()
125125
}
126126
}
127127
}
128+
129+
public override ValueTask AddInferenceAuthorizationHeaderAsync(INameValueCollection headersCollection, Uri requestAddress, string verb, AuthorizationTokenType tokenType)
130+
{
131+
throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD");
132+
}
128133
}
129134
}
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
//------------------------------------------------------------
2+
// Copyright (c) Microsoft Corporation. All rights reserved.
3+
//------------------------------------------------------------
4+
5+
namespace Microsoft.Azure.Cosmos
6+
{
7+
using System;
8+
using System.Collections.Generic;
9+
using System.Diagnostics;
10+
using System.Linq;
11+
using System.Net.Http;
12+
using System.Net.Http.Headers;
13+
using System.Text;
14+
using System.Threading;
15+
using System.Threading.Tasks;
16+
using global::Azure.Core;
17+
using Microsoft.Azure.Documents;
18+
using Microsoft.Azure.Documents.Collections;
19+
20+
/// <summary>
21+
/// Provides functionality to interact with the Cosmos DB Inference Service for semantic reranking.
22+
/// </summary>
23+
internal class InferenceService : IDisposable
24+
{
25+
// Base path for the inference service endpoint.
26+
private const string basePath = "/inference/semanticReranking";
27+
// User agent string for inference requests.
28+
private const string inferenceUserAgent = "cosmos-inference-dotnet";
29+
// Default scope for AAD authentication.
30+
private const string inferenceServiceDefaultScope = "https://dbinference.azure.com/.default";
31+
private const int inferenceServiceDefaultMaxConnectionLimit = 50;
32+
33+
private readonly int inferenceServiceMaxConnectionLimit;
34+
private readonly string inferenceServiceBaseUrl;
35+
private readonly Uri inferenceEndpoint;
36+
37+
private HttpClient httpClient;
38+
private AuthorizationTokenProvider cosmosAuthorization;
39+
40+
private bool disposedValue;
41+
42+
/// <summary>
43+
/// Initializes a new instance of the <see cref="InferenceService"/> class.
44+
/// </summary>
45+
/// <param name="client">The CosmosClient instance.</param>
46+
/// <exception cref="InvalidOperationException">Thrown if AAD authentication is not used.</exception>
47+
public InferenceService(CosmosClient client)
48+
{
49+
this.inferenceServiceBaseUrl = ConfigurationManager.GetEnvironmentVariable<string>("AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT", null);
50+
51+
if (string.IsNullOrEmpty(this.inferenceServiceBaseUrl))
52+
{
53+
throw new ArgumentNullException("Set environment variable AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT to use inference service");
54+
}
55+
56+
this.inferenceServiceMaxConnectionLimit = ConfigurationManager.GetEnvironmentVariable<int?>(
57+
"AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_SERVICE_MAX_CONNECTION_LIMIT",
58+
inferenceServiceDefaultMaxConnectionLimit) ?? inferenceServiceDefaultMaxConnectionLimit;
59+
60+
// Create and configure HttpClient for inference requests.
61+
HttpMessageHandler httpMessageHandler = CosmosHttpClientCore.CreateHttpClientHandler(
62+
gatewayModeMaxConnectionLimit: this.inferenceServiceMaxConnectionLimit,
63+
webProxy: null,
64+
serverCertificateCustomValidationCallback: client.DocumentClient.ConnectionPolicy.ServerCertificateCustomValidationCallback);
65+
66+
this.httpClient = new HttpClient(httpMessageHandler);
67+
68+
this.CreateClientHelper(this.httpClient);
69+
70+
// Construct the inference service endpoint URI.
71+
this.inferenceEndpoint = new Uri($"{this.inferenceServiceBaseUrl}/{basePath}");
72+
73+
// Ensure AAD authentication is used.
74+
if (client.DocumentClient.cosmosAuthorization.GetType() != typeof(AuthorizationTokenProviderTokenCredential))
75+
{
76+
throw new InvalidOperationException("InferenceService only supports AAD authentication.");
77+
}
78+
79+
// Set up token credential for authorization.
80+
// This is done to ensure the correct scope, which is different than the scope of the client, is used for the inference service.
81+
AuthorizationTokenProviderTokenCredential defaultOperationTokenProvider = client.DocumentClient.cosmosAuthorization as AuthorizationTokenProviderTokenCredential;
82+
TokenCredential tokenCredential = defaultOperationTokenProvider.tokenCredential;
83+
84+
this.cosmosAuthorization = new AuthorizationTokenProviderTokenCredential(
85+
tokenCredential: tokenCredential,
86+
accountEndpoint: new Uri(inferenceServiceDefaultScope),
87+
backgroundTokenCredentialRefreshInterval: client.ClientOptions?.TokenCredentialBackgroundRefreshInterval);
88+
}
89+
90+
/// <summary>
91+
/// Sends a semantic rerank request to the inference service.
92+
/// </summary>
93+
/// <param name="rerankContext">The context/query for reranking.</param>
94+
/// <param name="documents">The documents to be reranked.</param>
95+
/// <param name="options">Optional additional options for the request.</param>
96+
/// <param name="cancellationToken">Cancellation token.</param>
97+
/// <returns>A dictionary containing the reranked results.</returns>
98+
public async Task<SemanticRerankResult> SemanticRerankAsync(
99+
string rerankContext,
100+
IEnumerable<string> documents,
101+
IDictionary<string, object> options = null,
102+
CancellationToken cancellationToken = default)
103+
{
104+
// Prepare HTTP request for semantic reranking.
105+
HttpRequestMessage message = new HttpRequestMessage(HttpMethod.Post, this.inferenceEndpoint);
106+
INameValueCollection additionalHeaders = new RequestNameValueCollection();
107+
await this.cosmosAuthorization.AddInferenceAuthorizationHeaderAsync(
108+
headersCollection: additionalHeaders,
109+
this.inferenceEndpoint,
110+
HttpConstants.HttpMethods.Post,
111+
AuthorizationTokenType.AadToken);
112+
additionalHeaders.Add(HttpConstants.HttpHeaders.UserAgent, inferenceUserAgent);
113+
114+
// Add all headers to the HTTP request.
115+
foreach (string key in additionalHeaders.AllKeys())
116+
{
117+
message.Headers.Add(key, additionalHeaders[key]);
118+
}
119+
120+
// Build the request payload.
121+
Dictionary<string, object> body = this.AddSemanticRerankPayload(rerankContext, documents, options);
122+
123+
message.Content = new StringContent(
124+
Newtonsoft.Json.JsonConvert.SerializeObject(body),
125+
Encoding.UTF8,
126+
RuntimeConstants.MediaTypes.Json);
127+
128+
// Send the request and ensure success.
129+
HttpResponseMessage responseMessage = await this.httpClient.SendAsync(message, cancellationToken);
130+
responseMessage.EnsureSuccessStatusCode();
131+
132+
// Deserialize and return the response content as a dictionary.
133+
return await SemanticRerankResult.DeserializeSemanticRerankResultAsync(responseMessage);
134+
}
135+
136+
/// <summary>
137+
/// Configures the provided HttpClient with default headers and settings for inference requests.
138+
/// </summary>
139+
/// <param name="httpClient">The HttpClient to configure.</param>
140+
private void CreateClientHelper(HttpClient httpClient)
141+
{
142+
httpClient.Timeout = TimeSpan.FromSeconds(120);
143+
httpClient.DefaultRequestHeaders.CacheControl = new CacheControlHeaderValue { NoCache = true };
144+
145+
// Set requested API version header for version enforcement.
146+
httpClient.DefaultRequestHeaders.Add(HttpConstants.HttpHeaders.Version,
147+
HttpConstants.Versions.CurrentVersion);
148+
149+
httpClient.DefaultRequestHeaders.Add(HttpConstants.HttpHeaders.Accept, RuntimeConstants.MediaTypes.Json);
150+
}
151+
152+
/// <summary>
153+
/// Constructs the payload for the semantic rerank request.
154+
/// </summary>
155+
/// <param name="rerankContext">The context/query for reranking.</param>
156+
/// <param name="documents">The documents to be reranked.</param>
157+
/// <param name="options">Optional additional options.</param>
158+
/// <returns>A dictionary representing the request payload.</returns>
159+
private Dictionary<string, object> AddSemanticRerankPayload(string rerankContext, IEnumerable<string> documents, IDictionary<string, object> options)
160+
{
161+
Dictionary<string, object> payload = new Dictionary<string, object>
162+
{
163+
{ "query", rerankContext },
164+
{ "documents", documents.ToArray() }
165+
};
166+
167+
if (options == null)
168+
{
169+
return payload;
170+
}
171+
172+
// Add any additional options to the payload.
173+
foreach (string option in options.Keys)
174+
{
175+
payload.Add(option, options[option]);
176+
}
177+
178+
return payload;
179+
}
180+
181+
/// <summary>
182+
/// Disposes managed resources used by the service.
183+
/// </summary>
184+
/// <param name="disposing">Indicates if called from Dispose.</param>
185+
protected void Dispose(bool disposing)
186+
{
187+
if (!this.disposedValue)
188+
{
189+
if (disposing)
190+
{
191+
this.httpClient.Dispose();
192+
this.cosmosAuthorization.Dispose();
193+
this.httpClient = null;
194+
this.cosmosAuthorization = null;
195+
}
196+
197+
this.disposedValue = true;
198+
}
199+
}
200+
201+
/// <summary>
202+
/// Disposes the service and its resources.
203+
/// </summary>
204+
public void Dispose()
205+
{
206+
this.Dispose(true);
207+
}
208+
}
209+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//------------------------------------------------------------
2+
// Copyright (c) Microsoft Corporation. All rights reserved.
3+
//------------------------------------------------------------
4+
5+
namespace Microsoft.Azure.Cosmos
6+
{
7+
/// <summary>
8+
/// Represents the score assigned to a document after a reranking operation.
9+
/// </summary>
10+
#if PREVIEW
11+
public
12+
#else
13+
internal
14+
#endif
15+
16+
class RerankScore
17+
{
18+
/// <summary>
19+
/// Gets the document content or identifier that was reranked.
20+
/// </summary>
21+
public object Document { get; }
22+
23+
/// <summary>
24+
/// Gets the score assigned to the document after reranking.
25+
/// </summary>
26+
public double Score { get; }
27+
28+
/// <summary>
29+
/// Gets the original index or position of the document before reranking.
30+
/// </summary>
31+
public int Index { get; }
32+
33+
/// <summary>
34+
/// Initializes a new instance of the <see cref="RerankScore"/> class.
35+
/// </summary>
36+
/// <param name="document">The document content or identifier.</param>
37+
/// <param name="score">The reranked score for the document.</param>
38+
/// <param name="index">The original index of the document.</param>
39+
public RerankScore(object document, double score, int index)
40+
{
41+
this.Document = document;
42+
this.Score = score;
43+
this.Index = index;
44+
}
45+
}
46+
}

0 commit comments

Comments
 (0)