Skip to content

Commit 1787def

Browse files
committed
sync with master
1 parent edf39fc commit 1787def

File tree

13 files changed

+571
-121
lines changed

13 files changed

+571
-121
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
This is a Client SDK for RelationalAI
44

5-
- API version: 1.2.2
5+
- API version: 1.2.3
66

77
## Frameworks supported
88

RelationalAI/AuthType.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
namespace Com.RelationalAI
2+
{
3+
public enum AuthType
4+
{
5+
ACCESS_KEY,
6+
CLIENT_CREDENTIALS
7+
8+
}
9+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
namespace Com.RelationalAI
3+
{
4+
5+
/// <summary> Class to describe Access Token retrieval exception. </summary>
6+
public class ClientCredentialsException : Exception
7+
{
8+
public ClientCredentialsException() { }
9+
public ClientCredentialsException(string message) : base(message) { }
10+
public ClientCredentialsException(string message, System.Exception inner) : base(message, inner) { }
11+
protected ClientCredentialsException(
12+
System.Runtime.Serialization.SerializationInfo info,
13+
System.Runtime.Serialization.StreamingContext context) : base(info, context) { }
14+
}
15+
}
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
using System;
2+
using System.Threading.Tasks;
3+
using System.Net.Http;
4+
using System.Net.Http.Headers;
5+
using System.Collections.Generic;
6+
using NSec.Cryptography;
7+
using System.Text;
8+
9+
namespace Com.RelationalAI
10+
{
11+
/// <summary>Class <c>ClientCredentialsService</c> is used to get Access Token from authentication API for SDK access on RAICloud services.</summary>
12+
/// <remarks> It implements the singleton pattern to provide a single object to all the classes in the SDK.
13+
/// It keeps a Dictionary based cache of Access Tokens. A dictionary has been used to enable the service to support multiple tenants/connections/clouds
14+
/// It keeps track of the token generation and expiration time and only grabs a new AccessToken when the cached Token is expired.
15+
/// Currently the cached and/or expired tokens are only evicted when the consumer will call the GetAccessToken method.
16+
/// </remarks>
17+
class ClientCredentialsService
18+
{
19+
// Private constructor for singleton
20+
private ClientCredentialsService(){}
21+
22+
// Singleton instance of ClientCredentialsService
23+
private static ClientCredentialsService instance;
24+
25+
// Constants
26+
private const string ACCESS_TOKEN_KEY = "access_token";
27+
private const string EXPIRES_IN_KEY = "expires_in";
28+
private const string CLIENT_ID_KEY = "client_id";
29+
private const string CLIENT_SECRET_KEY = "client_secret";
30+
private const string AUDIENCE_KEY = "audience";
31+
private const string GRANT_TYPE_KEY = "grant_type";
32+
private const string CLIENT_CREDENTIALS_KEY = "client_credentials";
33+
34+
// Locking object for GetInstance class
35+
private static readonly object syncLock = new object();
36+
37+
// Authentication API URL Prefix to build the URI
38+
private static readonly string API_URL_PREFIX = "https://login";
39+
40+
// Authentication API URL Postfix to build the URI
41+
private static readonly string API_URL_POSTFIX = ".relationalai.com/oauth/token";
42+
43+
44+
// Dictionary to hold Access Tokens. Using Dictionary to support multiple tenants/connections from the SDK.
45+
private Dictionary<string, AccessToken> accessTokenCache = new Dictionary<string, AccessToken>();
46+
47+
/// <summary> Gets the singleton instance of <c>ClientCredentialsService</c> </summary>
48+
/// <remarks>Thread Safety Singleton using Double-Check Locking </remarks>
49+
/// <return> <c> ClientCredentialsService</c>.<return>
50+
public static ClientCredentialsService Instance
51+
{
52+
get
53+
{
54+
if (instance == null)
55+
{
56+
lock (syncLock)
57+
{
58+
if (instance == null) {
59+
instance = new ClientCredentialsService();
60+
}
61+
}
62+
}
63+
return instance;
64+
}
65+
}
66+
67+
/// <summary> Gets Access Token from authentication API. </summary>
68+
/// <example> For example:
69+
/// <code>
70+
/// ClientCredentialsService.Instance.GetAccessToken(credentials, host);
71+
/// </code>
72+
/// results in <c>string</c> Access Token for SDK authentication.
73+
/// </example>
74+
/// <param name="credentials">RAICredentials Object. Contains ClientId and ClientSecret from ~/.rai/config</param>
75+
/// <param name="host">Host value from ~/.rai/config</param>
76+
/// <exception> Throws ClientCredentialsException if failed to get the access token from remote API. </exception>
77+
/// <remarks> This function will throw exception in the following scenarios
78+
/// 1. Client id and/or client secret is wrong.
79+
/// 2. Client id does not have permission on the API.
80+
/// 3. Access token generation quota has been exhausted.
81+
/// 4. Any network communication issue.
82+
/// 5. The remote API or the audience has been renamed or does not exist.
83+
/// 6. If the host-name/url is not in proper format.
84+
/// </remarks>
85+
public string GetAccessToken(RAICredentials credentials, string host)
86+
{
87+
// Create the cache retrieval key.
88+
// It is a concatenation of client ID and audience for supporting
89+
// a client with multiple domains.
90+
string cacheKey = GetCacheKey(credentials.ClientId, host);
91+
92+
// Check if there is already a valid access token is present in the cache.
93+
AccessToken accessToken = GetValidAccessTokenFromCache(cacheKey);
94+
// If there is valid/un-expired token, then don't get a new one, just return the stored token.
95+
if(accessToken != null)
96+
{
97+
return accessToken.Token;
98+
}
99+
string normalizedHostName = host.StartsWith("https://") ? host : ("https://" + host);
100+
// Get the new access token from the remote API.
101+
string apiResult = GetAccessTokenInternal(credentials.ClientId, credentials.ClientScrt, normalizedHostName, GetApiUriFromHost(host)).GetAwaiter().GetResult();
102+
// Convert the JSON result into a dictionary to grab the access token and expiration.
103+
Dictionary<string, string> result = (Dictionary<string, string>) Newtonsoft.Json.JsonConvert.DeserializeObject(apiResult, typeof(Dictionary<string, string>));
104+
if(result != null && result.Count > 0)
105+
{
106+
// Add the Access Token object in the cache.
107+
accessTokenCache.Add(cacheKey, new AccessToken(result[ACCESS_TOKEN_KEY], long.Parse(result[EXPIRES_IN_KEY])));
108+
// Return the Access Token
109+
return result[ACCESS_TOKEN_KEY];
110+
}
111+
// Throw ClientCredentialsException because we have failed to get one.
112+
throw new ClientCredentialsException("Failed to get Access-Token from the remote API");
113+
}
114+
115+
/// <summary> Removes a cached access token from the cache. </summary>
116+
/// <param name="credentials">RAICredentials Object. Contains ClientId and ClientSecret from ~/.rai/config</param>
117+
/// <param name="host">Host value from ~/.rai/config</param>
118+
public void InvalidateCache(RAICredentials credentials, string host)
119+
{
120+
if(credentials != null)
121+
{
122+
string cacheKey = GetCacheKey(credentials.ClientId, host);
123+
// Do not need to verify if the key is successfully removed or not?
124+
// In case if the key is not then Remove will return false
125+
// This won't throw exception unless the key is null.
126+
accessTokenCache.Remove(cacheKey);
127+
}
128+
}
129+
130+
/// <summary> Gets Access Token from authentication API.</summary>
131+
/// <param name="clientId">client_id as mentioned in the ~/.rai/config</param>
132+
/// <param name="clientSecret">client_secret value from ~/.rai/config</param>
133+
/// <param name="audience">The token token audience/target API (Machine to Machine Application API)</param>
134+
/// <param name="apiUrl">Auth token API endpoint.</param>
135+
/// <exception> Throws ClientCredentialsException if failed to get the access token from remote API. </exception>
136+
/// <remarks> This function will throw exception in the following scenarios,
137+
/// 1. Client id and/or client secret is wrong.
138+
// 2. Client id does not have permission on the API.
139+
/// 3. Access token generation quota has been exhausted.
140+
/// 4. Any network communication issue.
141+
/// 5. The remote API or the audience has been renamed or does not exist.
142+
/// </remarks>
143+
/// <return> Access token response as <c>string</c>.</return>
144+
private async Task<string> GetAccessTokenInternal(string clientId, string clientSecret, string audience, Uri apiUrl)
145+
{
146+
// Form the API request body.
147+
string body = "{\"" + CLIENT_ID_KEY + "\":\""+ clientId + "\",\"" + CLIENT_SECRET_KEY + "\":\"" + clientSecret
148+
+ "\",\"" + AUDIENCE_KEY + "\":\"" + audience + "\",\"" + GRANT_TYPE_KEY + "\":\"" + CLIENT_CREDENTIALS_KEY + "\"}";
149+
150+
//Define the content object
151+
var content = new System.Net.Http.StringContent(body);
152+
try
153+
{
154+
// Create HTTP client to send the POST request
155+
// Using block will destroy the HTTP client automatically
156+
using (var client = new HttpClient())
157+
{
158+
// Set the API url
159+
client.BaseAddress = apiUrl;
160+
// Create the POST request
161+
var request = new HttpRequestMessage(new HttpMethod("POST"), client.BaseAddress);
162+
// Set content in the request.
163+
request.Content = content;
164+
// Set the content type.
165+
request.Content.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json");
166+
// Set the Accepted Media Type as the response.
167+
request.Headers.Accept.Add(System.Net.Http.Headers.MediaTypeWithQualityHeaderValue.Parse("application/json"));
168+
// Get the result back or throws an exception.
169+
var result = await client.SendAsync(request);
170+
return await result.Content.ReadAsStringAsync();
171+
}
172+
}
173+
catch(Exception e)
174+
{
175+
// Wrap exception as ClientCredentialsException and throw it.
176+
throw new ClientCredentialsException(e.Message, e);
177+
}
178+
}
179+
180+
/// <summary>Gets a key to store AccessToken in the cache.</summary>
181+
/// <param name="clientID">client_id as mentioned in the ~/.rai/config</param>
182+
/// <param name="audience">host value from ~/.rai/config</param>
183+
/// <remarks>Key is the concatenation of client ID and audience fields</remarks>
184+
/// <return> Cache key as <c>string</c>.</return>
185+
private static string GetCacheKey(string clientID, string audience)
186+
{
187+
return String.Format("{0}:{1}", clientID, audience);
188+
}
189+
190+
/// <summary> Gets a valid un-expired Access Token from the cache</summary>
191+
/// <param name="cacheKey">Cache Key</param>
192+
/// <return> <c>AccessToken</c> object if an un-expired token is present in the cache. Otherwise, will return Null. </return>
193+
private AccessToken GetValidAccessTokenFromCache(string cacheKey)
194+
{
195+
if(accessTokenCache.ContainsKey(cacheKey))
196+
{
197+
AccessToken accessToken = accessTokenCache[cacheKey];
198+
if(!accessToken.IsExpired())
199+
{
200+
return accessToken;
201+
}
202+
accessTokenCache.Remove(cacheKey);
203+
}
204+
return null;
205+
}
206+
207+
/// <summary> Formulates the authentication API endpoint from the host value in ~/.rai/config </summary>
208+
/// <param name="host">Value of host as mentioned in the ~/.rai/config</param>
209+
/// <example>host=azure-env.relationalai.com </example>
210+
/// <exception>Will throw exception if the host name/FQDN is not properly defined.</exception>
211+
/// <remarks>
212+
/// The Production API Url will be registered with authentication service as https://login.relationalai.com/auth/token
213+
/// Dev and/or staging API Urls will be registered as https://login-env.relationalai.com/oauth/token.
214+
/// This function will check for a -env in the host field. If the host is for some dev or stanging environment
215+
/// then it will return the API Url for the environment otherwise it will return the production API Url.
216+
/// </remarks>
217+
/// <return> API Url as <c>Uri</c> object.</return>
218+
private static Uri GetApiUriFromHost(string host)
219+
{
220+
string environment = "";
221+
// Search for hyphen, which means the host is some dev or staging environment.
222+
// If hyphen is present then extract the environment name using IndexOf and Substring function
223+
// of the string class.
224+
if(host.Contains("-"))
225+
{
226+
int hyphenStart = host.IndexOf('-');
227+
int indexOfDot = host.IndexOf('.', hyphenStart + 1);
228+
if(indexOfDot >= 0)
229+
{
230+
environment = host.Substring(hyphenStart + 1, indexOfDot - (hyphenStart + 1));
231+
}
232+
else
233+
{
234+
environment = host.Substring(hyphenStart + 1);
235+
}
236+
}
237+
238+
// Return API Url for either production or for an environment.
239+
if(environment != "")
240+
{
241+
return new Uri(String.Format("{0}-{1}{2}", API_URL_PREFIX, environment, API_URL_POSTFIX));
242+
}
243+
244+
return new Uri(String.Format("{0}{1}", API_URL_PREFIX, API_URL_POSTFIX));
245+
}
246+
}
247+
248+
/// <summary> This class is used to store the AccessToken Object in the cache. </summary>
249+
class AccessToken
250+
{
251+
public string Token { get; }
252+
public long ExpiresIn { get; }
253+
public DateTime TimeAcquired { get; }
254+
255+
public AccessToken(string accessToken, long expiresIn)
256+
{
257+
Token = accessToken;
258+
ExpiresIn = expiresIn;
259+
TimeAcquired = DateTime.Now;
260+
}
261+
262+
/// <summary> Checks if a Token has been expired or not? </summary>
263+
public bool IsExpired()
264+
{
265+
TimeSpan timeSpan = DateTime.Now - TimeAcquired;
266+
return (long)timeSpan.TotalSeconds >= ExpiresIn;
267+
}
268+
}
269+
}

RelationalAI/KGMSClient.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
using System.Net;
88
using System.Net.Http;
99
using System.Net.Http.Headers;
10-
using System.Text;
1110
using System.Threading;
1211
using System.Threading.Tasks;
1312
using System.Web;
@@ -22,7 +21,7 @@ public partial class GeneratedRelationalAIClient
2221

2322
public const string JSON_CONTENT_TYPE = "application/json";
2423
public const string CSV_CONTENT_TYPE = "text/csv";
25-
public const string USER_AGENT_HEADER = "KGMSClient/1.2.2/csharp";
24+
public const string USER_AGENT_HEADER = "KGMSClient/1.2.3/csharp";
2625

2726
public int DebugLevel = Connection.DEFAULT_DEBUG_LEVEL;
2827

@@ -72,9 +71,9 @@ partial void PrepareRequest(Transaction body, HttpClient client, HttpRequestMess
7271
//Set the content type header
7372
request.Content.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8");
7473

75-
// sign request here
74+
// Set Auth here
7675
var raiRequest = new RAIRequest(request, conn);
77-
raiRequest.Sign(debugLevel: DebugLevel);
76+
raiRequest.SetAuth();
7877
KGMSClient.AddExtraHeaders(request);
7978

8079
// use HTTP 2.0 (to handle keep-alive)

RelationalAI/ManagementClient.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ partial void PrepareRequest(System.Net.Http.HttpClient client, System.Net.Http.H
6767
request.Content.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json");
6868

6969
RAIRequest raiReq = new RAIRequest(request, conn);
70-
raiReq.Sign();
70+
raiReq.SetAuth();
7171
KGMSClient.AddExtraHeaders(request);
7272
}
7373
}
@@ -86,6 +86,7 @@ public ManagementClient(Connection conn) : base(KGMSClient.GetHttpClient(conn.Ba
8686
this.conn = conn;
8787
this.conn.CloudClient = this;
8888
this.BaseUrl = conn.BaseUrl.ToString();
89+
System.AppDomain.CurrentDomain.UnhandledException += GlobalExceptionHandler;
8990
}
9091

9192
public ICollection<ComputeInfoProtocol> ListComputes(RAIComputeFilters filters = null)
@@ -182,5 +183,25 @@ public GetAccountCreditsResponse GetAccountCreditUsage(Period period=Period.Curr
182183
{
183184
return this.AccountCreditsGetAsync(period).Result;
184185
}
186+
187+
///<summary> This global exception handler will be invoked in case of any exception.
188+
/// It can be used for multiple purposes, like logging. But, currently it is being
189+
/// used to invalidate the Client Credentials Cache.
190+
/// </summary>
191+
private void GlobalExceptionHandler(object sender, UnhandledExceptionEventArgs e) {
192+
if (e.ExceptionObject is Exception)
193+
{
194+
Exception exception = (Exception)e.ExceptionObject;
195+
if(exception.InnerException is ApiException
196+
&& conn.Creds.AuthType == AuthType.CLIENT_CREDENTIALS)
197+
{
198+
ApiException apiException = (ApiException)exception.InnerException;
199+
if(apiException.StatusCode == 400 || apiException.StatusCode == 401)
200+
{
201+
ClientCredentialsService.Instance.InvalidateCache(conn.Creds, conn.Host);
202+
}
203+
}
204+
}
205+
}
185206
}
186207
}

0 commit comments

Comments
 (0)