Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.Core;

namespace Microsoft.Identity.Client.ManagedIdentity.V2
{
internal interface IMtlsBindingCache
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add XML docs here to be consistent with the rest of this PR?

{
Task<Tuple<X509Certificate2, string /*endpoint*/, string /*clientId*/>> GetOrCreateAsync(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you create a named tuple?

internal class MtlsBindingInfo
{
    public X509Certificate2 Certificate { get; set; }
    public string Endpoint { get; set; }
    public string ClientId { get; set; }
}

string cacheKey,
Func<Task<Tuple<X509Certificate2, string, string>>> factory,
CancellationToken cancellationToken,
ILoggerAdapter logger);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Security.Cryptography.X509Certificates;
using Microsoft.Identity.Client.Core;

namespace Microsoft.Identity.Client.ManagedIdentity.V2
{
/// <summary>
/// Persistence interface for IMDSv2 mTLS binding certificates.
/// Implementations must be best-effort and non-throwing.
/// </summary>
internal interface IPersistentCertificateCache
{
/// <summary>
/// Reads the newest valid (≥24h remaining, has private key) entry for the alias.
/// </summary>
bool Read(string alias, out CertificateCacheValue value, ILoggerAdapter logger = null);

/// <summary>
/// Persists the certificate for the alias (best-effort). Implementations should
/// tag entries to allow alias scoping and prune expired duplicates conservatively.
/// </summary>
void Write(string alias, X509Certificate2 cert, string endpointBase, ILoggerAdapter logger = null);

/// <summary>
/// Prunes expired entries for the alias (best-effort).
/// </summary>
void Delete(string alias, ILoggerAdapter logger = null);
Comment on lines 27 to 34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should these return bool, like read, to indicate success or failure?

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ internal class ImdsV2ManagedIdentitySource : AbstractManagedIdentity
// Central, process-local cache for mTLS binding (cert + endpoint + canonical client_id).
internal static readonly ICertificateCache s_mtlsCertificateCache = new InMemoryCertificateCache();

// Per-key async de-duplication so concurrent callers don’t double-mint.
internal static readonly ConcurrentDictionary<string, SemaphoreSlim> s_perKeyGates =
new ConcurrentDictionary<string, SemaphoreSlim>(StringComparer.Ordinal);
private readonly IMtlsBindingCache _mtlsCache;

// used in unit tests
public const string ImdsV2ApiVersion = "2.0";
Expand Down Expand Up @@ -195,7 +193,12 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)

internal ImdsV2ManagedIdentitySource(RequestContext requestContext) :
base(requestContext, ManagedIdentitySource.ImdsV2)
{ }
{
IPersistentCertificateCache persisted =
PersistentCertificateCacheFactory.Create(requestContext.Logger);

_mtlsCache = new MtlsBindingCache(s_mtlsCertificateCache, persisted);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use dependency injection instead of putting this in the constructor? This should enhance testability.


private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(
string clientId,
Expand Down Expand Up @@ -440,65 +443,13 @@ private async Task<string> GetAttestationJwtAsync(
return response.AttestationToken;
}

// ...unchanged usings and class header...

/// <summary>
/// Read-through cache: try cache; if missing, run async factory once (per key),
/// store the result, and return it. Thread-safe for the given cacheKey.
/// </summary>
private static async Task<Tuple<X509Certificate2, string, string>> GetOrCreateMtlsBindingAsync(
private Task<Tuple<X509Certificate2, string, string>> GetOrCreateMtlsBindingAsync(
string cacheKey,
Func<Task<Tuple<X509Certificate2, string, string>>> factory,
CancellationToken cancellationToken,
ILoggerAdapter logger)
{
if (string.IsNullOrWhiteSpace(cacheKey))
throw new ArgumentException("cacheKey must be non-empty.", nameof(cacheKey));
if (factory is null)
throw new ArgumentNullException(nameof(factory));

X509Certificate2 cachedCertificate;
string cachedEndpointBase;
string cachedClientId;

// 1) Only lookup by cacheKey
if (s_mtlsCertificateCache.TryGet(cacheKey, out var cached, logger))
{
cachedCertificate = cached.Certificate;
cachedEndpointBase = cached.Endpoint;
cachedClientId = cached.ClientId;

return Tuple.Create(cachedCertificate, cachedEndpointBase, cachedClientId);
}

// 2) Gate per cacheKey
var gate = s_perKeyGates.GetOrAdd(cacheKey, _ => new SemaphoreSlim(1, 1));
await gate.WaitAsync(cancellationToken).ConfigureAwait(false);

try
{
// Re-check after acquiring the gate
if (s_mtlsCertificateCache.TryGet(cacheKey, out cached, logger))
{
cachedCertificate = cached.Certificate;
cachedEndpointBase = cached.Endpoint;
cachedClientId = cached.ClientId;
return Tuple.Create(cachedCertificate, cachedEndpointBase, cachedClientId);
}

// 3) Mint + cache under the provided cacheKey
var created = await factory().ConfigureAwait(false);

s_mtlsCertificateCache.Set(cacheKey,
new CertificateCacheValue(created.Item1, created.Item2, created.Item3),
logger);

return created;
}
finally
{
gate.Release();
}
return _mtlsCache.GetOrCreateAsync(cacheKey, factory, cancellationToken, logger);
}

internal static void ResetCertCacheForTest()
Expand All @@ -508,14 +459,6 @@ internal static void ResetCertCacheForTest()
{
s_mtlsCertificateCache.Clear();
}

foreach (var gate in s_perKeyGates.Values)
{
try
{ gate.Dispose(); }
catch { }
}
s_perKeyGates.Clear();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using Microsoft.Identity.Client.PlatformsCommon.Shared;

namespace Microsoft.Identity.Client.ManagedIdentity.V2
{
/// <summary>
/// Executes paramref name="action"/ under a cross-process, per-alias mutex.
/// We attempt 2 namespaces, in order:
/// 1) <c>Global\</c> — preferred so we dedupe across all sessions on the machine
/// (e.g., service + user session). This can be denied by OS policy or missing
/// SeCreateGlobalPrivilege in some contexts.
/// 2) <c>Local\</c> — fallback to still dedupe within the current session when
/// <c>Global\</c> is not permitted.
/// Using both ensures we never throw (persistence is best-effort) while getting
/// machine-wide dedupe when allowed and session-local dedupe otherwise.
/// Notes:
/// - The mutex name is derived from <c>alias</c> (= cacheKey) via SHA-256 hex (truncated)
/// to avoid invalid characters / length issues.
/// - On non-Windows runtimes the Global/Local prefixes are treated as part of the name;
/// behavior remains correct but dedupe scope is platform-defined.
/// - Abandoned mutexes are treated as acquired to avoid blocking after a crash.
/// </summary>

internal static class InterprocessLock
{
// Prefer Global\ for cross-session dedupe; fall back to Local\
// if ACLs block Global\ to remain non-throwing.
public static bool TryWithAliasLock(
string alias,
TimeSpan timeout,
Action action,
Action<string> logVerbose = null)
{
var nameGlobal = GetMutexNameForAlias(alias, preferGlobal: true);
var nameLocal = GetMutexNameForAlias(alias, preferGlobal: false);

foreach (var name in new[] { nameGlobal, nameLocal })
{
try
{
// Create or open existing
using var m = new Mutex(initiallyOwned: false, name);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if an exception occurs before the using block exposes?


// Wait to acquire
bool entered;

try
{
entered = m.WaitOne(timeout);
}
catch (AbandonedMutexException)
{
entered = true; // prior holder crashed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a log here?

}

if (!entered)
{
logVerbose?.Invoke($"[PersistentCert] Skip persist (lock busy '{name}').");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we log duration waited?

return false;
}

try
{
action();
}
finally
{
try
{
m.ReleaseMutex();
}
catch
{
/* best-effort */
}
}

return true;
}
catch (UnauthorizedAccessException)
{
logVerbose?.Invoke($"[PersistentCert] No access to mutex scope '{name}', trying next.");
continue; // try Local if Global blocked
}
catch (Exception ex)
{
logVerbose?.Invoke($"[PersistentCert] Lock failure '{name}': {ex.Message}");
return false;
}
}

return false;
}

public static string GetMutexNameForAlias(string alias, bool preferGlobal = true)
{
string suffix = HashAlias(Canonicalize(alias));
return (preferGlobal ? @"Global\" : @"Local\") + "MSAL_MI_P_" + suffix;
}

private static string Canonicalize(string alias) => (alias ?? string.Empty).Trim().ToUpperInvariant();

private static string HashAlias(string s)
{
try
{
var hex = new CommonCryptographyManager().CreateSha256HashHex(s);
// Truncate to 32 chars to fit mutex name length limits
return string.IsNullOrEmpty(hex) ? "0" : (hex.Length > 32 ? hex.Substring(0, 32) : hex);
}
catch
{
return "0";
}
}
}
}
Loading