@@ -40,6 +40,22 @@ static Diagnostics()
4040 + " App implementable interfaces are only meant to be resolved through 'AppImplementationFactory'."
4141 ) ;
4242
43+ public static readonly DiagnosticDescriptor DangerousAppImplementationFactoryUse = Error (
44+ "ALTINNINT0003" ,
45+ Category . General ,
46+ "Dangerous 'AppImplementationFactory' service resolution" ,
47+ "App implementable service interface '{0}' is resolved through 'AppImplementationFactory' in a constructor."
48+ + " App implementable interfaces are only meant to be resolved lazily (as late as possible, right before use) to ensure expected lifetime."
49+ ) ;
50+
51+ public static readonly DiagnosticDescriptor NonAppImplementableServiceThroughAppImplementationFactory = Error (
52+ "ALTINNINT0004" ,
53+ Category . General ,
54+ "Invalid 'AppImplementationFactory' service resolution" ,
55+ "Non-appimplementable service interface '{0}' is resolved through 'AppImplementationFactory'."
56+ + " You can use 'IServiceProvider' for these instead."
57+ ) ;
58+
4359 private static DiagnosticDescriptor Error ( string id , string category , string title , string messageFormat ) =>
4460 Create ( id , title , messageFormat , category , DiagnosticSeverity . Error ) ;
4561
@@ -61,7 +77,7 @@ private static class Category
6177public sealed class AppImplementationInjectionAnalyzer : DiagnosticAnalyzer
6278{
6379 private const string MarkerAttributeName = "ImplementableByAppsAttribute" ;
64- private static readonly SymbolEqualityComparer _symbolComparer = SymbolEqualityComparer . Default ;
80+ private static readonly SymbolEqualityComparer _comparer = SymbolEqualityComparer . Default ;
6581
6682 public override ImmutableArray < DiagnosticDescriptor > SupportedDiagnostics => Diagnostics . All ;
6783
@@ -72,6 +88,102 @@ public override void Initialize(AnalysisContext context)
7288
7389 context . RegisterSyntaxNodeAction ( AnalyzeConstructors , SyntaxKind . ParameterList ) ;
7490 context . RegisterSyntaxNodeAction ( AnalyzeDIServiceCalls , SyntaxKind . InvocationExpression ) ;
91+ context . RegisterSyntaxNodeAction (
92+ AnalyzeAppImplementationFactoryConstructorUse ,
93+ SyntaxKind . InvocationExpression
94+ ) ;
95+ context . RegisterSyntaxNodeAction (
96+ AnalyzeNonAppimplementableServicesThroughAppImplementationFactory ,
97+ SyntaxKind . InvocationExpression
98+ ) ;
99+ }
100+
101+ private static void AnalyzeNonAppimplementableServicesThroughAppImplementationFactory (
102+ SyntaxNodeAnalysisContext context
103+ )
104+ {
105+ var semanticModel = context . SemanticModel ;
106+ var invocation = ( InvocationExpressionSyntax ) context . Node ;
107+ var methodSymbol = semanticModel . GetSymbolInfo ( invocation , context . CancellationToken ) . Symbol as IMethodSymbol ;
108+ if ( methodSymbol is null )
109+ return ;
110+
111+ var appImplementationFactoryType = GetAppImplementationFactorySymbol ( context ) ;
112+ if ( appImplementationFactoryType is null )
113+ return ;
114+
115+ if ( ! _comparer . Equals ( methodSymbol . ContainingType , appImplementationFactoryType ) )
116+ return ;
117+
118+ var typeInfo = GetAppImplementationFactoryInvocationTypeArgument (
119+ context ,
120+ appImplementationFactoryType ,
121+ invocation
122+ ) ;
123+ if ( typeInfo is null )
124+ return ;
125+
126+ if ( IsMarkedWithAttribute ( typeInfo ) )
127+ return ;
128+
129+ var diagnostic = Diagnostic . Create (
130+ Diagnostics . NonAppImplementableServiceThroughAppImplementationFactory ,
131+ invocation . GetLocation ( ) ,
132+ typeInfo . Name
133+ ) ;
134+ context . ReportDiagnostic ( diagnostic ) ;
135+ }
136+
137+ private static void AnalyzeAppImplementationFactoryConstructorUse ( SyntaxNodeAnalysisContext context )
138+ {
139+ var invocation = ( InvocationExpressionSyntax ) context . Node ;
140+
141+ var methodSymbol =
142+ context . SemanticModel . GetSymbolInfo ( invocation , context . CancellationToken ) . Symbol as IMethodSymbol ;
143+ if ( methodSymbol is null )
144+ return ;
145+
146+ var appImplementationFactoryType = GetAppImplementationFactorySymbol ( context ) ;
147+ if ( appImplementationFactoryType is null )
148+ return ;
149+
150+ if ( ! _comparer . Equals ( methodSymbol . ContainingType , appImplementationFactoryType ) )
151+ return ;
152+
153+ var parent = context . Node . Parent ;
154+ while ( parent != null )
155+ {
156+ // Checks if the invocation is in a constructor body or field initializer (primary constructor)
157+ if ( parent is ConstructorDeclarationSyntax or FieldDeclarationSyntax or PropertyDeclarationSyntax )
158+ {
159+ // If this is a property with a null initializer, it is probably an arrow property, which is lazy
160+ if ( parent is PropertyDeclarationSyntax { Initializer : null } )
161+ break ;
162+
163+ var typeInfo = GetAppImplementationFactoryInvocationTypeArgument (
164+ context ,
165+ appImplementationFactoryType ,
166+ invocation
167+ ) ;
168+ if ( typeInfo is null || ! IsMarkedWithAttribute ( typeInfo ) )
169+ return ;
170+ var diagnostic = Diagnostic . Create (
171+ Diagnostics . DangerousAppImplementationFactoryUse ,
172+ invocation . GetLocation ( ) ,
173+ typeInfo . Name
174+ ) ;
175+ context . ReportDiagnostic ( diagnostic ) ;
176+ break ;
177+ }
178+
179+ if ( parent is MethodDeclarationSyntax or ClassDeclarationSyntax )
180+ {
181+ // We can stop checking
182+ break ;
183+ }
184+
185+ parent = parent . Parent ;
186+ }
75187 }
76188
77189 private static void AnalyzeDIServiceCalls ( SyntaxNodeAnalysisContext context )
@@ -91,12 +203,10 @@ private static void AnalyzeDIServiceCalls(SyntaxNodeAnalysisContext context)
91203 if ( methodSymbol is null )
92204 return ;
93205
94- var serviceProviderType = context . SemanticModel . Compilation . GetTypeByMetadataName ( "System.IServiceProvider" ) ;
206+ var serviceProviderType = GetIServiceProviderSymbol ( context ) ;
95207 if ( serviceProviderType is null )
96208 return ;
97- var enumerableType = context
98- . SemanticModel . Compilation . GetTypeByMetadataName ( "System.Collections.Generic.IEnumerable`1" )
99- ? . ConstructUnboundGenericType ( ) ;
209+ var enumerableType = GetIEnumerableSymbol ( context ) ;
100210 if ( enumerableType is null )
101211 return ;
102212
@@ -109,19 +219,17 @@ private static void AnalyzeDIServiceCalls(SyntaxNodeAnalysisContext context)
109219 var firstArgType = context
110220 . SemanticModel . GetTypeInfo ( arguments [ 0 ] . Expression , context . CancellationToken )
111221 . Type ;
112- isLongFormExtMethodCall = _symbolComparer . Equals ( firstArgType , serviceProviderType ) ;
222+ isLongFormExtMethodCall = _comparer . Equals ( firstArgType , serviceProviderType ) ;
113223 }
114- if ( ! _symbolComparer . Equals ( methodSymbol . ReceiverType , serviceProviderType ) && ! isLongFormExtMethodCall )
224+ if ( ! _comparer . Equals ( methodSymbol . ReceiverType , serviceProviderType ) && ! isLongFormExtMethodCall )
115225 return ;
116226 }
117227 else
118228 {
119- if ( ! _symbolComparer . Equals ( methodSymbol . ContainingType , serviceProviderType ) )
229+ if ( ! _comparer . Equals ( methodSymbol . ContainingType , serviceProviderType ) )
120230 return ;
121231 }
122232
123- // System.Diagnostics.Debugger.Launch();
124-
125233 // check the generic form, e.g. GetService<T>()
126234 TypeSyntax ? typeSyntax = null ;
127235 var typeArgumentList = invocation . DescendantNodes ( ) . OfType < TypeArgumentListSyntax > ( ) . FirstOrDefault ( ) ;
@@ -141,21 +249,21 @@ private static void AnalyzeDIServiceCalls(SyntaxNodeAnalysisContext context)
141249 if ( typeSyntax is null or PredefinedTypeSyntax )
142250 return ;
143251
144- var typeInfoSymbol = context . SemanticModel . GetTypeInfo ( typeSyntax , context . CancellationToken ) . Type ;
252+ var typeInfoSymbol = ResolveTypeSyntaxToPotentialAppImplementableType ( context , typeSyntax ) ;
145253 if ( typeInfoSymbol is not INamedTypeSymbol typeInfo )
146254 return ;
147255
148- if ( typeInfo . IsGenericType && _symbolComparer . Equals ( typeInfo . ConstructUnboundGenericType ( ) , enumerableType ) )
256+ if ( typeInfo . IsGenericType && _comparer . Equals ( typeInfo . ConstructUnboundGenericType ( ) , enumerableType ) )
149257 {
150- if ( typeInfo . TypeArguments . FirstOrDefault ( ) is not INamedTypeSymbol innerType )
258+ var enumerableTypeArgument = typeInfo . TypeArguments . FirstOrDefault ( ) ;
259+ if ( enumerableTypeArgument is ITypeParameterSymbol typeParameterSymbol )
260+ enumerableTypeArgument = GetMostRelevantConstraintType ( typeParameterSymbol ) ;
261+ if ( enumerableTypeArgument is not INamedTypeSymbol innerType )
151262 return ;
152263 typeInfo = innerType ;
153264 }
154265
155- if (
156- typeInfo . TypeKind == TypeKind . Interface
157- && typeInfo . GetAttributes ( ) . Any ( attr => attr . AttributeClass ? . Name == MarkerAttributeName )
158- )
266+ if ( IsMarkedWithAttribute ( typeInfo ) )
159267 {
160268 var diagnostic = Diagnostic . Create (
161269 Diagnostics . DangerousServiceProviderServiceResolution ,
@@ -181,9 +289,7 @@ private static void AnalyzeConstructors(SyntaxNodeAnalysisContext context)
181289 if ( typeIdentifier is null )
182290 return ;
183291
184- var enumerableType = context
185- . SemanticModel . Compilation . GetTypeByMetadataName ( "System.Collections.Generic.IEnumerable`1" )
186- ? . ConstructUnboundGenericType ( ) ;
292+ var enumerableType = GetIEnumerableSymbol ( context ) ;
187293 if ( enumerableType is null )
188294 return ;
189295
@@ -248,11 +354,12 @@ INamedTypeSymbol enumerableType
248354 context . SemanticModel . GetSymbolInfo ( syntax , context . CancellationToken ) . Symbol as INamedTypeSymbol ;
249355 if ( typeInfo is null )
250356 return ;
251- if (
252- typeInfo . IsGenericType && _symbolComparer . Equals ( typeInfo . ConstructUnboundGenericType ( ) , enumerableType )
253- )
357+ if ( typeInfo . IsGenericType && _comparer . Equals ( typeInfo . ConstructUnboundGenericType ( ) , enumerableType ) )
254358 {
255- if ( typeInfo . TypeArguments . FirstOrDefault ( ) is not INamedTypeSymbol innerType )
359+ var enumerableTypeArgument = typeInfo . TypeArguments . FirstOrDefault ( ) ;
360+ if ( enumerableTypeArgument is ITypeParameterSymbol typeParameterSymbol )
361+ enumerableTypeArgument = GetMostRelevantConstraintType ( typeParameterSymbol ) ;
362+ if ( enumerableTypeArgument is not INamedTypeSymbol innerType )
256363 return ;
257364 typeInfo = innerType ;
258365 }
@@ -261,10 +368,125 @@ INamedTypeSymbol enumerableType
261368 var key = syntax . GetLocation ( ) ;
262369 if ( typesReferenced . ContainsKey ( key ) )
263370 return ;
264- if ( ! typeInfo . GetAttributes ( ) . Any ( attr => attr . AttributeClass ? . Name == MarkerAttributeName ) )
371+ if ( ! IsMarkedWithAttribute ( typeInfo ) )
265372 return ;
266373
267374 typesReferenced . Add ( key , ( typeInfo , syntax ) ) ;
268375 }
269376 }
377+
378+ private static bool IsMarkedWithAttribute ( ITypeSymbol symbol )
379+ {
380+ var attributes = symbol . GetAttributes ( ) ;
381+ foreach ( var attribute in attributes )
382+ {
383+ if ( attribute . AttributeClass ? . Name == MarkerAttributeName )
384+ return true ;
385+ }
386+
387+ return false ;
388+ }
389+
390+ private static INamedTypeSymbol ? GetAppImplementationFactoryInvocationTypeArgument (
391+ SyntaxNodeAnalysisContext context ,
392+ INamedTypeSymbol appImplementationFactoryType ,
393+ // Some method invocation syntax on the `AppImplementationFactory` type
394+ // e.g. `GetRequired<T>()`
395+ InvocationExpressionSyntax invocation
396+ )
397+ {
398+ TypeArgumentListSyntax ? typeArgumentList = null ;
399+ var typeArgumentLists = invocation . DescendantNodes ( ) . OfType < TypeArgumentListSyntax > ( ) . ToArray ( ) ;
400+ if ( typeArgumentLists . Length == 0 )
401+ return null ;
402+ if ( typeArgumentLists . Length == 1 )
403+ {
404+ typeArgumentList = typeArgumentLists [ 0 ] ;
405+ }
406+ else
407+ {
408+ var factoryMethods = appImplementationFactoryType
409+ . GetMembers ( )
410+ . Where ( m =>
411+ m is IMethodSymbol method && ! method . IsStatic && ! method . IsAbstract && method . IsGenericMethod
412+ )
413+ . Cast < IMethodSymbol > ( )
414+ . ToArray ( ) ;
415+ foreach ( var typeArgumentListSyntax in typeArgumentLists )
416+ {
417+ if ( typeArgumentListSyntax . Parent is not NameSyntax methodName )
418+ continue ;
419+ if ( context . SemanticModel . GetSymbolInfo ( methodName ) . Symbol is not IMethodSymbol method )
420+ continue ;
421+ if ( ! factoryMethods . Contains ( method . OriginalDefinition , _comparer ) )
422+ continue ;
423+
424+ typeArgumentList = typeArgumentListSyntax ;
425+ break ;
426+ }
427+ }
428+
429+ if ( typeArgumentList is null )
430+ return null ;
431+ var typeSyntax = typeArgumentList . Arguments . FirstOrDefault ( ) ;
432+ if ( typeSyntax is null )
433+ return null ;
434+ return ResolveTypeSyntaxToPotentialAppImplementableType ( context , typeSyntax ) ;
435+ }
436+
437+ private static INamedTypeSymbol ? ResolveTypeSyntaxToPotentialAppImplementableType (
438+ SyntaxNodeAnalysisContext context ,
439+ TypeSyntax typeSyntax
440+ )
441+ {
442+ var typeInfoSymbol = context . SemanticModel . GetTypeInfo ( typeSyntax , context . CancellationToken ) . Type ;
443+ if ( typeInfoSymbol is ITypeParameterSymbol typeParameterSymbol )
444+ typeInfoSymbol = GetMostRelevantConstraintType ( typeParameterSymbol ) ;
445+
446+ if ( typeInfoSymbol is not INamedTypeSymbol symbol )
447+ return null ;
448+
449+ return symbol ;
450+ }
451+
452+ private static ITypeSymbol ? GetMostRelevantConstraintType ( ITypeParameterSymbol typeParameterSymbol )
453+ {
454+ // If the type is a type parameter (T), we just take the first constraint type
455+ // that is marked by the attribute or fallback to the first constraint type
456+ // if none is marked with the attribute.
457+ if ( typeParameterSymbol . ConstraintTypes . Length == 0 )
458+ return null ;
459+
460+ for ( int i = 0 ; i < typeParameterSymbol . ConstraintTypes . Length ; i ++ )
461+ {
462+ var constraintType = typeParameterSymbol . ConstraintTypes [ i ] ;
463+ if ( IsMarkedWithAttribute ( constraintType ) )
464+ return constraintType ;
465+ }
466+
467+ return typeParameterSymbol . ConstraintTypes [ 0 ] ;
468+ }
469+
470+ private static INamedTypeSymbol ? GetIEnumerableSymbol ( SyntaxNodeAnalysisContext context )
471+ {
472+ var enumerableType = context
473+ . SemanticModel . Compilation . GetTypeByMetadataName ( "System.Collections.Generic.IEnumerable`1" )
474+ ? . ConstructUnboundGenericType ( ) ;
475+
476+ return enumerableType ;
477+ }
478+
479+ private static INamedTypeSymbol ? GetIServiceProviderSymbol ( SyntaxNodeAnalysisContext context )
480+ {
481+ var serviceProviderType = context . SemanticModel . Compilation . GetTypeByMetadataName ( "System.IServiceProvider" ) ;
482+ return serviceProviderType ;
483+ }
484+
485+ private static INamedTypeSymbol ? GetAppImplementationFactorySymbol ( SyntaxNodeAnalysisContext context )
486+ {
487+ var appImplementationFactoryType = context . SemanticModel . Compilation . GetTypeByMetadataName (
488+ "Altinn.App.Core.Features.AppImplementationFactory"
489+ ) ;
490+ return appImplementationFactoryType ;
491+ }
270492}
0 commit comments