Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions docs/design/libraries/ComInterfaceGenerator/VTableStubs.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,29 @@ public readonly ref struct VirtualMethodTableInfo
}
}

public interface IUnmanagedVirtualMethodTableProvider<T> where T : IEquatable<T>
public interface IUnmanagedVirtualMethodTableProvider
{
protected VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);
protected VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(Type type);

public sealed VirtualMethodTableInfo GetVirtualMethodTableInfoForKey<TUnmanagedInterfaceType>()
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<T>
public sealed VirtualMethodTableInfo GetVirtualMethodTableInfoForKey()
where TUnmanagedInterfaceType : IUnmanagedInterfaceType
{
return GetVirtualMethodTableInfoForKey(TUnmanagedInterfaceType.TypeKey);
// Dispatch from a non-virtual generic to a virtual non-generic with System.Type
// to avoid generic virtual method dispatch, which is very slow.
return GetVirtualMethodTableInfoForKey(typeof(TUnmanagedInterfaceType));
}
}

public interface IUnmanagedInterfaceType<T> where T : IEquatable<T>
public interface IUnmanagedInterfaceType
{
public abstract static T TypeKey { get; }
}
```

## Required API Shapes

The user will be required to implement `IUnmanagedVirtualMethodTableProvider<T>` on the type that provides the method tables, and `IUnmanagedInterfaceType<T>` on the type that defines the unmanaged interface. The `T` types must match between the two interfaces. This mechanism is designed to enable each native API platform to provide their own casting key, for example `IID`s in COM, without interfering with each other or requiring using reflection-based types like `System.Type`.
The user will be required to implement `IUnmanagedVirtualMethodTableProvider` on the type that provides the method tables, and `IUnmanagedInterfaceType` on the type that defines the unmanaged interface.

Previously, each of these interface types was generic on a type `T`. The `T` types were required to match between the two interfaces. This mechanism was designed to enable each native API platform to provide their own casting key, for example `IID`s in COM, without interfering with each other or requiring using reflection-based types like `System.Type`. However, practical implementation showed that providing just a "type key" was not enough information to cover any non-trivial scenarios (like COM) efficiently without effectively forcing a two-level lookup model or hard-coding type support in the `IUnmanagedVirtualMethodTableProvider<T>` implementation. Additionally, we determined that using reflection to get to attributes is considered "okay" and using generic attributes would enable APIs that build on this model like COM to effectively retrieve information from the `System.Type` instance without causing additional problems.

## Example Usage

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ public ManagedToNativeVTableMethodGenerator(
/// <remarks>
/// The generated code assumes it will be in an unsafe context.
/// </remarks>
public BlockSyntax GenerateStubBody(int index, ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv, TypeSyntax containingTypeName, ManagedTypeInfo typeKeyType)
public BlockSyntax GenerateStubBody(int index, ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv, TypeSyntax containingTypeName)
{
var setupStatements = new List<StatementSyntax>
{
// var (<thisParameter>, <virtualMethodTable>) = ((IUnmanagedVirtualMethodTableProvider<<typeKeyType>>)this).GetVirtualMethodTableInfoForKey<<containingTypeName>>();
// var (<thisParameter>, <virtualMethodTable>) = ((IUnmanagedVirtualMethodTableProvider)this).GetVirtualMethodTableInfoForKey<<containingTypeName>>();
ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
Expand All @@ -119,11 +119,7 @@ public BlockSyntax GenerateStubBody(int index, ImmutableArray<FunctionPointerUnm
SyntaxKind.SimpleMemberAccessExpression,
ParenthesizedExpression(
CastExpression(
GenericName(
Identifier(TypeNames.IUnmanagedVirtualMethodTableProvider))
.WithTypeArgumentList(
TypeArgumentList(
SingletonSeparatedList(typeKeyType.Syntax))),
ParseTypeName(TypeNames.IUnmanagedVirtualMethodTableProvider),
ThisExpression())),
GenericName(
Identifier("GetVirtualMethodTableInfoForKey"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

namespace Microsoft.Interop
{
internal sealed record NativeThisInfo(ManagedTypeInfo TypeKeyType) : MarshallingInfo;
internal sealed record NativeThisInfo : MarshallingInfo
{
public static readonly NativeThisInfo Instance = new();
}

internal sealed class NativeToManagedThisMarshallerFactory : IMarshallingGeneratorFactory
{
Expand All @@ -20,14 +23,10 @@ public NativeToManagedThisMarshallerFactory(IMarshallingGeneratorFactory inner)
}

public IMarshallingGenerator Create(TypePositionInfo info, StubCodeContext context)
=> info.MarshallingAttributeInfo is NativeThisInfo(ManagedTypeInfo typeKeyType) ? new Marshaller(typeKeyType) : _inner.Create(info, context);
=> info.MarshallingAttributeInfo is NativeThisInfo ? new Marshaller() : _inner.Create(info, context);

private sealed class Marshaller : IMarshallingGenerator
{
private readonly ManagedTypeInfo _typeKeyType;

public Marshaller(ManagedTypeInfo typeKeyType) => _typeKeyType = typeKeyType;

public ManagedTypeInfo AsNativeType(TypePositionInfo info) => new PointerTypeInfo("void*", "void*", false);
public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
Expand All @@ -44,10 +43,7 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont
IdentifierName(managedIdentifier),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
GenericName(Identifier(TypeNames.IUnmanagedVirtualMethodTableProvider),
TypeArgumentList(
SingletonSeparatedList(
_typeKeyType.Syntax))),
ParseTypeName(TypeNames.IUnmanagedVirtualMethodTableProvider),
GenericName(Identifier("GetObjectForUnmanagedWrapper"),
TypeArgumentList(
SingletonSeparatedList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ internal sealed record IncrementalStubGenerationContext(
MarshallingInfo ExceptionMarshallingInfo,
MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> ManagedToUnmanagedGeneratorFactory,
MarshallingGeneratorFactoryKey<(TargetFramework TargetFramework, Version TargetFrameworkVersion)> UnmanagedToManagedGeneratorFactory,
ManagedTypeInfo TypeKeyType,
ManagedTypeInfo TypeKeyOwner,
SequenceEqualImmutableArray<Diagnostic> Diagnostics);

Expand Down Expand Up @@ -348,20 +347,15 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD

ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = GenerateCallConvSyntaxFromAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute);

var typeKeyOwner = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);
ManagedTypeInfo typeKeyType = SpecialTypeInfo.Byte;
var interfaceType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);

INamedTypeSymbol? iUnmanagedInterfaceTypeInstantiation = symbol.ContainingType.AllInterfaces.FirstOrDefault(iface => SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, iUnmanagedInterfaceTypeType));
if (iUnmanagedInterfaceTypeInstantiation is null)
INamedTypeSymbol expectedUnmanagedInterfaceType = iUnmanagedInterfaceTypeType.Construct(symbol.ContainingType);

bool implementsIUnmanagedInterfaceOfSelf = symbol.ContainingType.AllInterfaces.Any(iface => SymbolEqualityComparer.Default.Equals(iface, expectedUnmanagedInterfaceType));
if (!implementsIUnmanagedInterfaceOfSelf)
{
// TODO: Report invalid configuration
}
else
{
// The type key is the second generic type parameter, so we need to get the info for the
// second argument.
typeKeyType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(iUnmanagedInterfaceTypeInstantiation.TypeArguments[1]);
}

MarshallingInfo exceptionMarshallingInfo = CreateExceptionMarshallingInfo(virtualMethodIndexAttr, symbol, environment.Compilation, generatorDiagnostics, virtualMethodIndexData);

Expand All @@ -375,8 +369,7 @@ private static IncrementalStubGenerationContext CalculateStubInformation(MethodD
exceptionMarshallingInfo,
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged),
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged),
typeKeyType,
typeKeyOwner,
interfaceType,
new SequenceEqualImmutableArray<Diagnostic>(generatorDiagnostics.Diagnostics.ToImmutableArray()));
}

Expand Down Expand Up @@ -442,8 +435,7 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateMan
BlockSyntax code = stubGenerator.GenerateStubBody(
methodStub.VtableIndexData.Index,
methodStub.CallingConvention.Array,
methodStub.TypeKeyOwner.Syntax,
methodStub.TypeKeyType);
methodStub.TypeKeyOwner.Syntax);

return (
methodStub.ContainingSyntaxContext.AddContainingSyntax(
Expand Down Expand Up @@ -518,7 +510,7 @@ private static ImmutableArray<TypePositionInfo> AddImplicitElementInfos(Incremen

var elements = ImmutableArray.CreateBuilder<TypePositionInfo>(originalElements.Length + 2);

elements.Add(new TypePositionInfo(methodStub.TypeKeyOwner, new NativeThisInfo(methodStub.TypeKeyType))
elements.Add(new TypePositionInfo(methodStub.TypeKeyOwner, NativeThisInfo.Instance)
{
InstanceIdentifier = ThisParameterIdentifier,
NativeIndex = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static class TypeNames

public const string IUnmanagedVirtualMethodTableProvider = "System.Runtime.InteropServices.IUnmanagedVirtualMethodTableProvider";

public const string IUnmanagedInterfaceType_Metadata = "System.Runtime.InteropServices.IUnmanagedInterfaceType`2";
public const string IUnmanagedInterfaceType_Metadata = "System.Runtime.InteropServices.IUnmanagedInterfaceType`1";

public const string System_Span_Metadata = "System.Span`1";
public const string System_Span = "System.Span";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,27 @@ public void Deconstruct(out IntPtr thisPointer, out ReadOnlySpan<IntPtr> virtual
}

/// <summary>
/// This interface allows an object to provide information about a virtual method table for a managed interface that implements <see cref="IUnmanagedInterfaceType{TInterface, T}"/> to enable invoking methods in the virtual method table.
/// This interface allows an object to provide information about a virtual method table for a managed interface that implements <see cref="IUnmanagedInterfaceType{TInterface}"/> to enable invoking methods in the virtual method table.
/// </summary>
/// <typeparam name="T">The type to use to represent the the identity of the unmanaged type.</typeparam>
public unsafe interface IUnmanagedVirtualMethodTableProvider<T> where T : IEquatable<T>
public unsafe interface IUnmanagedVirtualMethodTableProvider
{
/// <summary>
/// Get the information about the virtual method table for a given unmanaged interface type represented by <paramref name="typeKey"/>.
/// Get the information about the virtual method table for a given unmanaged interface type represented by <paramref name="type"/>.
/// </summary>
/// <param name="typeKey">The object that represents the identity of the unmanaged interface.</param>
/// <param name="type">The managed type for the unmanaged interface.</param>
/// <returns>The virtual method table information for the unmanaged interface.</returns>
protected VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(T typeKey);
protected VirtualMethodTableInfo GetVirtualMethodTableInfoForKey(Type type);

/// <summary>
/// Get the information about the virtual method table for the given unmanaged interface type.
/// </summary>
/// <typeparam name="TUnmanagedInterfaceType">The managed interface type that represents the unmanaged interface.</typeparam>
/// <returns>The virtual method table information for the unmanaged interface.</returns>
public sealed VirtualMethodTableInfo GetVirtualMethodTableInfoForKey<TUnmanagedInterfaceType>()
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType, T>
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType>
{
return GetVirtualMethodTableInfoForKey(TUnmanagedInterfaceType.TypeKey);
return GetVirtualMethodTableInfoForKey(typeof(TUnmanagedInterfaceType));
}

/// <summary>
Expand All @@ -77,7 +77,7 @@ public sealed VirtualMethodTableInfo GetVirtualMethodTableInfoForKey<TUnmanagedI
/// <typeparam name="TUnmanagedInterfaceType">The managed interface type that represents the unmanaged interface.</typeparam>
/// <returns>The length of the virtual method table for the unmanaged interface.</returns>
public static int GetVirtualMethodTableLength<TUnmanagedInterfaceType>()
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType, T>
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType>
{
return TUnmanagedInterfaceType.VirtualMethodTableLength;
}
Expand All @@ -88,7 +88,7 @@ public static int GetVirtualMethodTableLength<TUnmanagedInterfaceType>()
/// <typeparam name="TUnmanagedInterfaceType">The managed interface type that represents the unmanaged interface.</typeparam>
/// <returns>A pointer to the virtual method table of managed implementations of the unmanaged interface type</returns>
public static void* GetVirtualMethodTableManagedImplementation<TUnmanagedInterfaceType>()
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType, T>
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType>
{
return TUnmanagedInterfaceType.VirtualMethodTableManagedImplementation;
}
Expand All @@ -100,7 +100,7 @@ public static int GetVirtualMethodTableLength<TUnmanagedInterfaceType>()
/// <param name="obj">The managed object that implements the unmanaged interface.</param>
/// <returns>A pointer-sized value that can be passed to unmanaged code that represents <paramref name="obj"/></returns>
public static void* GetUnmanagedWrapperForObject<TUnmanagedInterfaceType>(TUnmanagedInterfaceType obj)
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType, T>
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType>
{
return TUnmanagedInterfaceType.GetUnmanagedWrapperForObject(obj);
}
Expand All @@ -112,7 +112,7 @@ public static int GetVirtualMethodTableLength<TUnmanagedInterfaceType>()
/// <param name="ptr">A pointer-sized value returned by <see cref="GetUnmanagedWrapperForObject{TUnmanagedInterfaceType}(TUnmanagedInterfaceType)"/> or <see cref="IUnmanagedInterfaceType{TInterface, TKey}.GetUnmanagedWrapperForObject(TInterface)"/>.</param>
/// <returns>The object wrapped by <paramref name="ptr"/>.</returns>
public static TUnmanagedInterfaceType GetObjectForUnmanagedWrapper<TUnmanagedInterfaceType>(void* ptr)
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType, T>
where TUnmanagedInterfaceType : IUnmanagedInterfaceType<TUnmanagedInterfaceType>
{
return TUnmanagedInterfaceType.GetObjectForUnmanagedWrapper(ptr);
}
Expand All @@ -123,9 +123,8 @@ public static TUnmanagedInterfaceType GetObjectForUnmanagedWrapper<TUnmanagedInt
/// </summary>
/// <typeparam name="TInterface">The managed interface.</typeparam>
/// <typeparam name="TKey">The type of a value that can represent types from the corresponding unmanaged type system.</typeparam>
public unsafe interface IUnmanagedInterfaceType<TInterface, TKey>
where TInterface : IUnmanagedInterfaceType<TInterface, TKey>
where TKey : IEquatable<TKey>
public unsafe interface IUnmanagedInterfaceType<TInterface>
where TInterface : IUnmanagedInterfaceType<TInterface>
{
/// <summary>
/// Get the length of the virtual method table for the given unmanaged interface type.
Expand All @@ -152,10 +151,5 @@ public unsafe interface IUnmanagedInterfaceType<TInterface, TKey>
/// <param name="ptr">A pointer-sized value returned by <see cref="IUnmanagedVirtualMethodTableProvider{TKey}.GetUnmanagedWrapperForObject{IUnmanagedInterfaceType{TInterface, TKey}}(IUnmanagedInterfaceType{TInterface, TKey})"/> or <see cref="GetUnmanagedWrapperForObject(TInterface)"/>.</param>
/// <returns>The object wrapped by <paramref name="ptr"/>.</returns>
public static abstract TInterface GetObjectForUnmanagedWrapper(void* ptr);

/// <summary>
/// The value that represents the unmanaged type's identity in the corresponding unmanaged type system.
/// </summary>
public static abstract TKey TypeKey { get; }
}
}
Loading