Skip to content

Commit cec2a37

Browse files
authored
feat: Regasm: should call static methods with Com(Un)RegisterFunctionAttribute (#370)
* Added Feature * Update RegistrationServices.cs * Moved Implementation
1 parent 4fe2f11 commit cec2a37

File tree

2 files changed

+187
-2
lines changed

2 files changed

+187
-2
lines changed

src/dscom.test/tests/RegistrationServicesTest.cs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,53 @@ public void RegisterTypeForComClients_SuspendedTestClassAsLocalServerWithCoResum
168168
// Cleanup
169169
registrationServices.UnregisterTypeForComClients(cookie);
170170
}
171+
172+
[Fact]
173+
public void RegistrationServices_Should_No_ThrowOn_ComRegister_Without_Method()
174+
{
175+
RegistrationServices.CustomRegistrationFunction(typeof(NoRegisterFunctionClass));
176+
RegistrationServices.CustomUnregistrationFunction(typeof(NoRegisterFunctionClass));
177+
}
178+
179+
[Fact]
180+
public void RegistrationServices_Should_Invoke_ComRegisterMethod()
181+
{
182+
RegistrationServices.CustomRegistrationFunction(typeof(SingleRegisterFunctionClass));
183+
RegistrationServices.CustomUnregistrationFunction(typeof(SingleRegisterFunctionClass));
184+
185+
Assert.Equal(typeof(SingleRegisterFunctionClass), SingleRegisterFunctionClass.RegisteredType);
186+
Assert.Equal(typeof(SingleRegisterFunctionClass), SingleRegisterFunctionClass.UnregisteredType);
187+
}
188+
189+
[Fact]
190+
public void RegistrationServices_Should_Throw_On_Instance_ComRegisterMethod()
191+
{
192+
Assert.Throws<InvalidOperationException>(
193+
() => RegistrationServices.CustomRegistrationFunction(typeof(SingleInstanceRegisterFunctionClass)));
194+
195+
Assert.Throws<InvalidOperationException>(
196+
() => RegistrationServices.CustomUnregistrationFunction(typeof(SingleInstanceRegisterFunctionClass)));
197+
}
198+
199+
[Fact]
200+
public void RegistrationServices_Should_Throw_On_Multiple_ComRegisterMethod()
201+
{
202+
Assert.Throws<InvalidOperationException>(
203+
() => RegistrationServices.CustomRegistrationFunction(typeof(MultipleRegisterFunctionClass)));
204+
205+
Assert.Throws<InvalidOperationException>(
206+
() => RegistrationServices.CustomUnregistrationFunction(typeof(MultipleRegisterFunctionClass)));
207+
}
208+
209+
[Fact]
210+
public void RegistrationServices_Should_Throw_On_WrongSignature_ComRegisterMethod()
211+
{
212+
Assert.Throws<InvalidOperationException>(
213+
() => RegistrationServices.CustomRegistrationFunction(typeof(WrongSignatureRegisterFunctionClass)));
214+
215+
Assert.Throws<InvalidOperationException>(
216+
() => RegistrationServices.CustomUnregistrationFunction(typeof(WrongSignatureRegisterFunctionClass)));
217+
}
171218
}
172219

173220
[ComVisible(true)]
@@ -220,3 +267,88 @@ public string GetSuccessString()
220267
return "Success";
221268
}
222269
}
270+
271+
public class NoRegisterFunctionClass
272+
{
273+
274+
}
275+
276+
public class SingleRegisterFunctionClass
277+
{
278+
public static Type? RegisteredType { get; private set; }
279+
280+
public static Type? UnregisteredType { get; private set; }
281+
282+
[ComRegisterFunction]
283+
public static void ComRegister(Type type)
284+
{
285+
RegisteredType = type;
286+
}
287+
288+
[ComUnregisterFunction]
289+
public static void ComUnregister(Type type)
290+
{
291+
UnregisteredType = type;
292+
}
293+
}
294+
295+
public class SingleInstanceRegisterFunctionClass
296+
{
297+
public Type? RegisteredType { get; private set; }
298+
299+
public Type? UnregisteredType { get; private set; }
300+
301+
[ComRegisterFunction]
302+
public void ComRegister(Type type)
303+
{
304+
RegisteredType = type;
305+
}
306+
307+
[ComUnregisterFunction]
308+
public void ComUnregister(Type type)
309+
{
310+
UnregisteredType = type;
311+
}
312+
}
313+
314+
public class MultipleRegisterFunctionClass
315+
{
316+
public static Type? RegisteredType { get; private set; }
317+
318+
public static Type? UnregisteredType { get; private set; }
319+
320+
[ComRegisterFunction]
321+
public static void ComRegisterA(Type type)
322+
{
323+
RegisteredType = type;
324+
}
325+
326+
[ComRegisterFunction]
327+
public static void ComRegisterB(Type type)
328+
{
329+
RegisteredType = type;
330+
}
331+
332+
[ComUnregisterFunction]
333+
public static void ComUnregisterA(Type type)
334+
{
335+
UnregisteredType = type;
336+
}
337+
338+
[ComUnregisterFunction]
339+
public static void ComUnregisterB(Type type)
340+
{
341+
UnregisteredType = type;
342+
}
343+
}
344+
345+
public class WrongSignatureRegisterFunctionClass
346+
{
347+
[ComRegisterFunction]
348+
public static void ComRegister()
349+
{ }
350+
351+
[ComUnregisterFunction]
352+
public static void ComUnregister()
353+
{ }
354+
}

src/dscom/RegistrationServices.cs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ public bool RegisterAssembly(Assembly assembly, bool registerCodeBase, ManagedCa
227227
RegisterManagedType(type, fullName, assemblyVersion, codeBase, runtimeVersion, preferredAction);
228228
}
229229

230-
// Skip: CustomRegistrationFunction
230+
CustomRegistrationFunction(type);
231231
}
232232

233233
// Skip: PIA Regitration
@@ -263,7 +263,8 @@ public bool UnregisterAssembly(Assembly assembly)
263263

264264
foreach (var type in typesToUnregister)
265265
{
266-
// Skip: Custom unregister function
266+
CustomUnregistrationFunction(type);
267+
267268
if (IsComRegistratableValueType(type) && !UnregisterValueType(type, assemblyVersion))
268269
{
269270
typesNotRemoved.Add(type);
@@ -848,4 +849,56 @@ private static bool CanWriteGlobalRegistry()
848849
}
849850
}
850851

852+
public static void CustomRegistrationFunction(Type type)
853+
{
854+
var provider = TryGetComRegisterFunction<ComRegisterFunctionAttribute>(type);
855+
856+
provider?.Invoke(null, new object[] { type });
857+
}
858+
859+
public static void CustomUnregistrationFunction(Type type)
860+
{
861+
var provider = TryGetComRegisterFunction<ComUnregisterFunctionAttribute>(type);
862+
863+
provider?.Invoke(null, new object[] { type });
864+
}
865+
866+
private static MethodInfo? TryGetComRegisterFunction<T>(Type type)
867+
where T : Attribute
868+
{
869+
var registerMethod = type
870+
.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static)
871+
.Where(m => m.GetCustomAttribute<T>() is not null)
872+
.ToList();
873+
874+
// No method
875+
if (registerMethod.Count == 0)
876+
{
877+
return null;
878+
}
879+
880+
// multipe methods -> error
881+
if (registerMethod.Count > 1)
882+
{
883+
throw new InvalidOperationException($"The type '{type.Name}' contains more than on COM-Function for registration.");
884+
}
885+
886+
var provider = registerMethod[0];
887+
888+
// method should be static
889+
if (!provider.IsStatic)
890+
{
891+
throw new InvalidOperationException($"The COM-Function for registration of type '{type.Name}' should be static.");
892+
}
893+
894+
// method should have exactly one parameter of type
895+
var paramters = provider.GetParameters();
896+
if (paramters.Length != 1 || paramters[0].ParameterType != typeof(Type))
897+
{
898+
throw new InvalidOperationException($"The COM-Function for registration of type '{type.Name}' should have exactly one parameter of '{nameof(Type)}.");
899+
}
900+
901+
return provider;
902+
}
903+
851904
}

0 commit comments

Comments
 (0)