|
| 1 | +using System; |
| 2 | +using System.Threading.Tasks; |
| 3 | +using System.Net.Http; |
| 4 | +using System.Net.Http.Headers; |
| 5 | +using System.Collections.Generic; |
| 6 | +using NSec.Cryptography; |
| 7 | +using System.Text; |
| 8 | + |
| 9 | +namespace Com.RelationalAI |
| 10 | +{ |
| 11 | + /// <summary>Class <c>ClientCredentialsService</c> is used to get Access Token from authentication API for SDK access on RAICloud services.</summary> |
| 12 | + /// <remarks> It implements the singleton pattern to provide a single object to all the classes in the SDK. |
| 13 | + /// It keeps a Dictionary based cache of Access Tokens. A dictionary has been used to enable the service to support multiple tenants/connections/clouds |
| 14 | + /// It keeps track of the token generation and expiration time and only grabs a new AccessToken when the cached Token is expired. |
| 15 | + /// Currently the cached and/or expired tokens are only evicted when the consumer will call the GetAccessToken method. |
| 16 | + /// </remarks> |
| 17 | + class ClientCredentialsService |
| 18 | + { |
| 19 | + // Private constructor for singleton |
| 20 | + private ClientCredentialsService(){} |
| 21 | + |
| 22 | + // Singleton instance of ClientCredentialsService |
| 23 | + private static ClientCredentialsService instance; |
| 24 | + |
| 25 | + // Constants |
| 26 | + private const string ACCESS_TOKEN_KEY = "access_token"; |
| 27 | + private const string EXPIRES_IN_KEY = "expires_in"; |
| 28 | + private const string CLIENT_ID_KEY = "client_id"; |
| 29 | + private const string CLIENT_SECRET_KEY = "client_secret"; |
| 30 | + private const string AUDIENCE_KEY = "audience"; |
| 31 | + private const string GRANT_TYPE_KEY = "grant_type"; |
| 32 | + private const string CLIENT_CREDENTIALS_KEY = "client_credentials"; |
| 33 | + |
| 34 | + // Locking object for GetInstance class |
| 35 | + private static readonly object syncLock = new object(); |
| 36 | + |
| 37 | + // Authentication API URL Prefix to build the URI |
| 38 | + private static readonly string API_URL_PREFIX = "https://login"; |
| 39 | + |
| 40 | + // Authentication API URL Postfix to build the URI |
| 41 | + private static readonly string API_URL_POSTFIX = ".relationalai.com/oauth/token"; |
| 42 | + |
| 43 | + |
| 44 | + // Dictionary to hold Access Tokens. Using Dictionary to support multiple tenants/connections from the SDK. |
| 45 | + private Dictionary<string, AccessToken> accessTokenCache = new Dictionary<string, AccessToken>(); |
| 46 | + |
| 47 | + /// <summary> Gets the singleton instance of <c>ClientCredentialsService</c> </summary> |
| 48 | + /// <remarks>Thread Safety Singleton using Double-Check Locking </remarks> |
| 49 | + /// <return> <c> ClientCredentialsService</c>.<return> |
| 50 | + public static ClientCredentialsService Instance |
| 51 | + { |
| 52 | + get |
| 53 | + { |
| 54 | + if (instance == null) |
| 55 | + { |
| 56 | + lock (syncLock) |
| 57 | + { |
| 58 | + if (instance == null) { |
| 59 | + instance = new ClientCredentialsService(); |
| 60 | + } |
| 61 | + } |
| 62 | + } |
| 63 | + return instance; |
| 64 | + } |
| 65 | + } |
| 66 | + |
| 67 | + /// <summary> Gets Access Token from authentication API. </summary> |
| 68 | + /// <example> For example: |
| 69 | + /// <code> |
| 70 | + /// ClientCredentialsService.Instance.GetAccessToken(credentials, host); |
| 71 | + /// </code> |
| 72 | + /// results in <c>string</c> Access Token for SDK authentication. |
| 73 | + /// </example> |
| 74 | + /// <param name="credentials">RAICredentials Object. Contains ClientId and ClientSecret from ~/.rai/config</param> |
| 75 | + /// <param name="host">Host value from ~/.rai/config</param> |
| 76 | + /// <exception> Throws ClientCredentialsException if failed to get the access token from remote API. </exception> |
| 77 | + /// <remarks> This function will throw exception in the following scenarios |
| 78 | + /// 1. Client id and/or client secret is wrong. |
| 79 | + /// 2. Client id does not have permission on the API. |
| 80 | + /// 3. Access token generation quota has been exhausted. |
| 81 | + /// 4. Any network communication issue. |
| 82 | + /// 5. The remote API or the audience has been renamed or does not exist. |
| 83 | + /// 6. If the host-name/url is not in proper format. |
| 84 | + /// </remarks> |
| 85 | + public string GetAccessToken(RAICredentials credentials, string host) |
| 86 | + { |
| 87 | + // Create the cache retrieval key. |
| 88 | + // It is a concatenation of client ID and audience for supporting |
| 89 | + // a client with multiple domains. |
| 90 | + string cacheKey = GetCacheKey(credentials.ClientId, host); |
| 91 | + |
| 92 | + // Check if there is already a valid access token is present in the cache. |
| 93 | + AccessToken accessToken = GetValidAccessTokenFromCache(cacheKey); |
| 94 | + // If there is valid/un-expired token, then don't get a new one, just return the stored token. |
| 95 | + if(accessToken != null) |
| 96 | + { |
| 97 | + return accessToken.Token; |
| 98 | + } |
| 99 | + string normalizedHostName = host.StartsWith("https://") ? host : ("https://" + host); |
| 100 | + // Get the new access token from the remote API. |
| 101 | + string apiResult = GetAccessTokenInternal(credentials.ClientId, credentials.ClientScrt, normalizedHostName, GetApiUriFromHost(host)).GetAwaiter().GetResult(); |
| 102 | + // Convert the JSON result into a dictionary to grab the access token and expiration. |
| 103 | + Dictionary<string, string> result = (Dictionary<string, string>) Newtonsoft.Json.JsonConvert.DeserializeObject(apiResult, typeof(Dictionary<string, string>)); |
| 104 | + if(result != null && result.Count > 0) |
| 105 | + { |
| 106 | + // Add the Access Token object in the cache. |
| 107 | + accessTokenCache.Add(cacheKey, new AccessToken(result[ACCESS_TOKEN_KEY], long.Parse(result[EXPIRES_IN_KEY]))); |
| 108 | + // Return the Access Token |
| 109 | + return result[ACCESS_TOKEN_KEY]; |
| 110 | + } |
| 111 | + // Throw ClientCredentialsException because we have failed to get one. |
| 112 | + throw new ClientCredentialsException("Failed to get Access-Token from the remote API"); |
| 113 | + } |
| 114 | + |
| 115 | + /// <summary> Removes a cached access token from the cache. </summary> |
| 116 | + /// <param name="credentials">RAICredentials Object. Contains ClientId and ClientSecret from ~/.rai/config</param> |
| 117 | + /// <param name="host">Host value from ~/.rai/config</param> |
| 118 | + public void InvalidateCache(RAICredentials credentials, string host) |
| 119 | + { |
| 120 | + if(credentials != null) |
| 121 | + { |
| 122 | + string cacheKey = GetCacheKey(credentials.ClientId, host); |
| 123 | + // Do not need to verify if the key is successfully removed or not? |
| 124 | + // In case if the key is not then Remove will return false |
| 125 | + // This won't throw exception unless the key is null. |
| 126 | + accessTokenCache.Remove(cacheKey); |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + /// <summary> Gets Access Token from authentication API.</summary> |
| 131 | + /// <param name="clientId">client_id as mentioned in the ~/.rai/config</param> |
| 132 | + /// <param name="clientSecret">client_secret value from ~/.rai/config</param> |
| 133 | + /// <param name="audience">The token token audience/target API (Machine to Machine Application API)</param> |
| 134 | + /// <param name="apiUrl">Auth token API endpoint.</param> |
| 135 | + /// <exception> Throws ClientCredentialsException if failed to get the access token from remote API. </exception> |
| 136 | + /// <remarks> This function will throw exception in the following scenarios, |
| 137 | + /// 1. Client id and/or client secret is wrong. |
| 138 | + // 2. Client id does not have permission on the API. |
| 139 | + /// 3. Access token generation quota has been exhausted. |
| 140 | + /// 4. Any network communication issue. |
| 141 | + /// 5. The remote API or the audience has been renamed or does not exist. |
| 142 | + /// </remarks> |
| 143 | + /// <return> Access token response as <c>string</c>.</return> |
| 144 | + private async Task<string> GetAccessTokenInternal(string clientId, string clientSecret, string audience, Uri apiUrl) |
| 145 | + { |
| 146 | + // Form the API request body. |
| 147 | + string body = "{\"" + CLIENT_ID_KEY + "\":\""+ clientId + "\",\"" + CLIENT_SECRET_KEY + "\":\"" + clientSecret |
| 148 | + + "\",\"" + AUDIENCE_KEY + "\":\"" + audience + "\",\"" + GRANT_TYPE_KEY + "\":\"" + CLIENT_CREDENTIALS_KEY + "\"}"; |
| 149 | + |
| 150 | + //Define the content object |
| 151 | + var content = new System.Net.Http.StringContent(body); |
| 152 | + try |
| 153 | + { |
| 154 | + // Create HTTP client to send the POST request |
| 155 | + // Using block will destroy the HTTP client automatically |
| 156 | + using (var client = new HttpClient()) |
| 157 | + { |
| 158 | + // Set the API url |
| 159 | + client.BaseAddress = apiUrl; |
| 160 | + // Create the POST request |
| 161 | + var request = new HttpRequestMessage(new HttpMethod("POST"), client.BaseAddress); |
| 162 | + // Set content in the request. |
| 163 | + request.Content = content; |
| 164 | + // Set the content type. |
| 165 | + request.Content.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json"); |
| 166 | + // Set the Accepted Media Type as the response. |
| 167 | + request.Headers.Accept.Add(System.Net.Http.Headers.MediaTypeWithQualityHeaderValue.Parse("application/json")); |
| 168 | + // Get the result back or throws an exception. |
| 169 | + var result = await client.SendAsync(request); |
| 170 | + return await result.Content.ReadAsStringAsync(); |
| 171 | + } |
| 172 | + } |
| 173 | + catch(Exception e) |
| 174 | + { |
| 175 | + // Wrap exception as ClientCredentialsException and throw it. |
| 176 | + throw new ClientCredentialsException(e.Message, e); |
| 177 | + } |
| 178 | + } |
| 179 | + |
| 180 | + /// <summary>Gets a key to store AccessToken in the cache.</summary> |
| 181 | + /// <param name="clientID">client_id as mentioned in the ~/.rai/config</param> |
| 182 | + /// <param name="audience">host value from ~/.rai/config</param> |
| 183 | + /// <remarks>Key is the concatenation of client ID and audience fields</remarks> |
| 184 | + /// <return> Cache key as <c>string</c>.</return> |
| 185 | + private static string GetCacheKey(string clientID, string audience) |
| 186 | + { |
| 187 | + return String.Format("{0}:{1}", clientID, audience); |
| 188 | + } |
| 189 | + |
| 190 | + /// <summary> Gets a valid un-expired Access Token from the cache</summary> |
| 191 | + /// <param name="cacheKey">Cache Key</param> |
| 192 | + /// <return> <c>AccessToken</c> object if an un-expired token is present in the cache. Otherwise, will return Null. </return> |
| 193 | + private AccessToken GetValidAccessTokenFromCache(string cacheKey) |
| 194 | + { |
| 195 | + if(accessTokenCache.ContainsKey(cacheKey)) |
| 196 | + { |
| 197 | + AccessToken accessToken = accessTokenCache[cacheKey]; |
| 198 | + if(!accessToken.IsExpired()) |
| 199 | + { |
| 200 | + return accessToken; |
| 201 | + } |
| 202 | + accessTokenCache.Remove(cacheKey); |
| 203 | + } |
| 204 | + return null; |
| 205 | + } |
| 206 | + |
| 207 | + /// <summary> Formulates the authentication API endpoint from the host value in ~/.rai/config </summary> |
| 208 | + /// <param name="host">Value of host as mentioned in the ~/.rai/config</param> |
| 209 | + /// <example>host=azure-env.relationalai.com </example> |
| 210 | + /// <exception>Will throw exception if the host name/FQDN is not properly defined.</exception> |
| 211 | + /// <remarks> |
| 212 | + /// The Production API Url will be registered with authentication service as https://login.relationalai.com/auth/token |
| 213 | + /// Dev and/or staging API Urls will be registered as https://login-env.relationalai.com/oauth/token. |
| 214 | + /// This function will check for a -env in the host field. If the host is for some dev or stanging environment |
| 215 | + /// then it will return the API Url for the environment otherwise it will return the production API Url. |
| 216 | + /// </remarks> |
| 217 | + /// <return> API Url as <c>Uri</c> object.</return> |
| 218 | + private static Uri GetApiUriFromHost(string host) |
| 219 | + { |
| 220 | + string environment = ""; |
| 221 | + // Search for hyphen, which means the host is some dev or staging environment. |
| 222 | + // If hyphen is present then extract the environment name using IndexOf and Substring function |
| 223 | + // of the string class. |
| 224 | + if(host.Contains("-")) |
| 225 | + { |
| 226 | + int hyphenStart = host.IndexOf('-'); |
| 227 | + int indexOfDot = host.IndexOf('.', hyphenStart + 1); |
| 228 | + if(indexOfDot >= 0) |
| 229 | + { |
| 230 | + environment = host.Substring(hyphenStart + 1, indexOfDot - (hyphenStart + 1)); |
| 231 | + } |
| 232 | + else |
| 233 | + { |
| 234 | + environment = host.Substring(hyphenStart + 1); |
| 235 | + } |
| 236 | + } |
| 237 | + |
| 238 | + // Return API Url for either production or for an environment. |
| 239 | + if(environment != "") |
| 240 | + { |
| 241 | + return new Uri(String.Format("{0}-{1}{2}", API_URL_PREFIX, environment, API_URL_POSTFIX)); |
| 242 | + } |
| 243 | + |
| 244 | + return new Uri(String.Format("{0}{1}", API_URL_PREFIX, API_URL_POSTFIX)); |
| 245 | + } |
| 246 | + } |
| 247 | + |
| 248 | + /// <summary> This class is used to store the AccessToken Object in the cache. </summary> |
| 249 | + class AccessToken |
| 250 | + { |
| 251 | + public string Token { get; } |
| 252 | + public long ExpiresIn { get; } |
| 253 | + public DateTime TimeAcquired { get; } |
| 254 | + |
| 255 | + public AccessToken(string accessToken, long expiresIn) |
| 256 | + { |
| 257 | + Token = accessToken; |
| 258 | + ExpiresIn = expiresIn; |
| 259 | + TimeAcquired = DateTime.Now; |
| 260 | + } |
| 261 | + |
| 262 | + /// <summary> Checks if a Token has been expired or not? </summary> |
| 263 | + public bool IsExpired() |
| 264 | + { |
| 265 | + TimeSpan timeSpan = DateTime.Now - TimeAcquired; |
| 266 | + return (long)timeSpan.TotalSeconds >= ExpiresIn; |
| 267 | + } |
| 268 | + } |
| 269 | +} |
0 commit comments