diff --git a/src/Altinn.App.Internal.Analyzers/AnalyzerReleases.Unshipped.md b/src/Altinn.App.Internal.Analyzers/AnalyzerReleases.Unshipped.md index ac01bf2bb..a089007b8 100644 --- a/src/Altinn.App.Internal.Analyzers/AnalyzerReleases.Unshipped.md +++ b/src/Altinn.App.Internal.Analyzers/AnalyzerReleases.Unshipped.md @@ -7,4 +7,6 @@ Rule ID | Category | Severity | Notes --------|----------|----------|------- ALTINNINT0001 | General | Error | Dangerous constructor injection ALTINNINT0002 | General | Error | Dangerous 'IServiceProvider' service resolution +ALTINNINT0003 | General | Error | Dangerous 'AppImplementationFactory' service resolution +ALTINNINT0004 | General | Error | Invalid 'AppImplementationFactory' service resolution ALTINNINT9999 | General | Error | Unknown error diff --git a/src/Altinn.App.Internal.Analyzers/AppImplementationInjectionAnalyzer.cs b/src/Altinn.App.Internal.Analyzers/AppImplementationInjectionAnalyzer.cs index 12cad19d1..5cbe47c01 100644 --- a/src/Altinn.App.Internal.Analyzers/AppImplementationInjectionAnalyzer.cs +++ b/src/Altinn.App.Internal.Analyzers/AppImplementationInjectionAnalyzer.cs @@ -40,6 +40,22 @@ static Diagnostics() + " App implementable interfaces are only meant to be resolved through 'AppImplementationFactory'." ); + public static readonly DiagnosticDescriptor DangerousAppImplementationFactoryUse = Error( + "ALTINNINT0003", + Category.General, + "Dangerous 'AppImplementationFactory' service resolution", + "App implementable service interface '{0}' is resolved through 'AppImplementationFactory' in a constructor." + + " App implementable interfaces are only meant to be resolved lazily (as late as possible, right before use) to ensure expected lifetime." + ); + + public static readonly DiagnosticDescriptor NonAppImplementableServiceThroughAppImplementationFactory = Error( + "ALTINNINT0004", + Category.General, + "Invalid 'AppImplementationFactory' service resolution", + "Non-appimplementable service interface '{0}' is resolved through 'AppImplementationFactory'." + + " You can use 'IServiceProvider' for these instead." + ); + private static DiagnosticDescriptor Error(string id, string category, string title, string messageFormat) => Create(id, title, messageFormat, category, DiagnosticSeverity.Error); @@ -61,7 +77,7 @@ private static class Category public sealed class AppImplementationInjectionAnalyzer : DiagnosticAnalyzer { private const string MarkerAttributeName = "ImplementableByAppsAttribute"; - private static readonly SymbolEqualityComparer _symbolComparer = SymbolEqualityComparer.Default; + private static readonly SymbolEqualityComparer _comparer = SymbolEqualityComparer.Default; public override ImmutableArray SupportedDiagnostics => Diagnostics.All; @@ -72,6 +88,102 @@ public override void Initialize(AnalysisContext context) context.RegisterSyntaxNodeAction(AnalyzeConstructors, SyntaxKind.ParameterList); context.RegisterSyntaxNodeAction(AnalyzeDIServiceCalls, SyntaxKind.InvocationExpression); + context.RegisterSyntaxNodeAction( + AnalyzeAppImplementationFactoryConstructorUse, + SyntaxKind.InvocationExpression + ); + context.RegisterSyntaxNodeAction( + AnalyzeNonAppimplementableServicesThroughAppImplementationFactory, + SyntaxKind.InvocationExpression + ); + } + + private static void AnalyzeNonAppimplementableServicesThroughAppImplementationFactory( + SyntaxNodeAnalysisContext context + ) + { + var semanticModel = context.SemanticModel; + var invocation = (InvocationExpressionSyntax)context.Node; + var methodSymbol = semanticModel.GetSymbolInfo(invocation, context.CancellationToken).Symbol as IMethodSymbol; + if (methodSymbol is null) + return; + + var appImplementationFactoryType = GetAppImplementationFactorySymbol(context); + if (appImplementationFactoryType is null) + return; + + if (!_comparer.Equals(methodSymbol.ContainingType, appImplementationFactoryType)) + return; + + var typeInfo = GetAppImplementationFactoryInvocationTypeArgument( + context, + appImplementationFactoryType, + invocation + ); + if (typeInfo is null) + return; + + if (IsMarkedWithAttribute(typeInfo)) + return; + + var diagnostic = Diagnostic.Create( + Diagnostics.NonAppImplementableServiceThroughAppImplementationFactory, + invocation.GetLocation(), + typeInfo.Name + ); + context.ReportDiagnostic(diagnostic); + } + + private static void AnalyzeAppImplementationFactoryConstructorUse(SyntaxNodeAnalysisContext context) + { + var invocation = (InvocationExpressionSyntax)context.Node; + + var methodSymbol = + context.SemanticModel.GetSymbolInfo(invocation, context.CancellationToken).Symbol as IMethodSymbol; + if (methodSymbol is null) + return; + + var appImplementationFactoryType = GetAppImplementationFactorySymbol(context); + if (appImplementationFactoryType is null) + return; + + if (!_comparer.Equals(methodSymbol.ContainingType, appImplementationFactoryType)) + return; + + var parent = context.Node.Parent; + while (parent != null) + { + // Checks if the invocation is in a constructor body or field initializer (primary constructor) + if (parent is ConstructorDeclarationSyntax or FieldDeclarationSyntax or PropertyDeclarationSyntax) + { + // If this is a property with a null initializer, it is probably an arrow property, which is lazy + if (parent is PropertyDeclarationSyntax { Initializer: null }) + break; + + var typeInfo = GetAppImplementationFactoryInvocationTypeArgument( + context, + appImplementationFactoryType, + invocation + ); + if (typeInfo is null || !IsMarkedWithAttribute(typeInfo)) + return; + var diagnostic = Diagnostic.Create( + Diagnostics.DangerousAppImplementationFactoryUse, + invocation.GetLocation(), + typeInfo.Name + ); + context.ReportDiagnostic(diagnostic); + break; + } + + if (parent is MethodDeclarationSyntax or ClassDeclarationSyntax) + { + // We can stop checking + break; + } + + parent = parent.Parent; + } } private static void AnalyzeDIServiceCalls(SyntaxNodeAnalysisContext context) @@ -91,12 +203,10 @@ private static void AnalyzeDIServiceCalls(SyntaxNodeAnalysisContext context) if (methodSymbol is null) return; - var serviceProviderType = context.SemanticModel.Compilation.GetTypeByMetadataName("System.IServiceProvider"); + var serviceProviderType = GetIServiceProviderSymbol(context); if (serviceProviderType is null) return; - var enumerableType = context - .SemanticModel.Compilation.GetTypeByMetadataName("System.Collections.Generic.IEnumerable`1") - ?.ConstructUnboundGenericType(); + var enumerableType = GetIEnumerableSymbol(context); if (enumerableType is null) return; @@ -109,19 +219,17 @@ private static void AnalyzeDIServiceCalls(SyntaxNodeAnalysisContext context) var firstArgType = context .SemanticModel.GetTypeInfo(arguments[0].Expression, context.CancellationToken) .Type; - isLongFormExtMethodCall = _symbolComparer.Equals(firstArgType, serviceProviderType); + isLongFormExtMethodCall = _comparer.Equals(firstArgType, serviceProviderType); } - if (!_symbolComparer.Equals(methodSymbol.ReceiverType, serviceProviderType) && !isLongFormExtMethodCall) + if (!_comparer.Equals(methodSymbol.ReceiverType, serviceProviderType) && !isLongFormExtMethodCall) return; } else { - if (!_symbolComparer.Equals(methodSymbol.ContainingType, serviceProviderType)) + if (!_comparer.Equals(methodSymbol.ContainingType, serviceProviderType)) return; } - // System.Diagnostics.Debugger.Launch(); - // check the generic form, e.g. GetService() TypeSyntax? typeSyntax = null; var typeArgumentList = invocation.DescendantNodes().OfType().FirstOrDefault(); @@ -141,21 +249,21 @@ private static void AnalyzeDIServiceCalls(SyntaxNodeAnalysisContext context) if (typeSyntax is null or PredefinedTypeSyntax) return; - var typeInfoSymbol = context.SemanticModel.GetTypeInfo(typeSyntax, context.CancellationToken).Type; + var typeInfoSymbol = ResolveTypeSyntaxToPotentialAppImplementableType(context, typeSyntax); if (typeInfoSymbol is not INamedTypeSymbol typeInfo) return; - if (typeInfo.IsGenericType && _symbolComparer.Equals(typeInfo.ConstructUnboundGenericType(), enumerableType)) + if (typeInfo.IsGenericType && _comparer.Equals(typeInfo.ConstructUnboundGenericType(), enumerableType)) { - if (typeInfo.TypeArguments.FirstOrDefault() is not INamedTypeSymbol innerType) + var enumerableTypeArgument = typeInfo.TypeArguments.FirstOrDefault(); + if (enumerableTypeArgument is ITypeParameterSymbol typeParameterSymbol) + enumerableTypeArgument = GetMostRelevantConstraintType(typeParameterSymbol); + if (enumerableTypeArgument is not INamedTypeSymbol innerType) return; typeInfo = innerType; } - if ( - typeInfo.TypeKind == TypeKind.Interface - && typeInfo.GetAttributes().Any(attr => attr.AttributeClass?.Name == MarkerAttributeName) - ) + if (IsMarkedWithAttribute(typeInfo)) { var diagnostic = Diagnostic.Create( Diagnostics.DangerousServiceProviderServiceResolution, @@ -181,9 +289,7 @@ private static void AnalyzeConstructors(SyntaxNodeAnalysisContext context) if (typeIdentifier is null) return; - var enumerableType = context - .SemanticModel.Compilation.GetTypeByMetadataName("System.Collections.Generic.IEnumerable`1") - ?.ConstructUnboundGenericType(); + var enumerableType = GetIEnumerableSymbol(context); if (enumerableType is null) return; @@ -248,11 +354,12 @@ INamedTypeSymbol enumerableType context.SemanticModel.GetSymbolInfo(syntax, context.CancellationToken).Symbol as INamedTypeSymbol; if (typeInfo is null) return; - if ( - typeInfo.IsGenericType && _symbolComparer.Equals(typeInfo.ConstructUnboundGenericType(), enumerableType) - ) + if (typeInfo.IsGenericType && _comparer.Equals(typeInfo.ConstructUnboundGenericType(), enumerableType)) { - if (typeInfo.TypeArguments.FirstOrDefault() is not INamedTypeSymbol innerType) + var enumerableTypeArgument = typeInfo.TypeArguments.FirstOrDefault(); + if (enumerableTypeArgument is ITypeParameterSymbol typeParameterSymbol) + enumerableTypeArgument = GetMostRelevantConstraintType(typeParameterSymbol); + if (enumerableTypeArgument is not INamedTypeSymbol innerType) return; typeInfo = innerType; } @@ -261,10 +368,125 @@ INamedTypeSymbol enumerableType var key = syntax.GetLocation(); if (typesReferenced.ContainsKey(key)) return; - if (!typeInfo.GetAttributes().Any(attr => attr.AttributeClass?.Name == MarkerAttributeName)) + if (!IsMarkedWithAttribute(typeInfo)) return; typesReferenced.Add(key, (typeInfo, syntax)); } } + + private static bool IsMarkedWithAttribute(ITypeSymbol symbol) + { + var attributes = symbol.GetAttributes(); + foreach (var attribute in attributes) + { + if (attribute.AttributeClass?.Name == MarkerAttributeName) + return true; + } + + return false; + } + + private static INamedTypeSymbol? GetAppImplementationFactoryInvocationTypeArgument( + SyntaxNodeAnalysisContext context, + INamedTypeSymbol appImplementationFactoryType, + // Some method invocation syntax on the `AppImplementationFactory` type + // e.g. `GetRequired()` + InvocationExpressionSyntax invocation + ) + { + TypeArgumentListSyntax? typeArgumentList = null; + var typeArgumentLists = invocation.DescendantNodes().OfType().ToArray(); + if (typeArgumentLists.Length == 0) + return null; + if (typeArgumentLists.Length == 1) + { + typeArgumentList = typeArgumentLists[0]; + } + else + { + var factoryMethods = appImplementationFactoryType + .GetMembers() + .Where(m => + m is IMethodSymbol method && !method.IsStatic && !method.IsAbstract && method.IsGenericMethod + ) + .Cast() + .ToArray(); + foreach (var typeArgumentListSyntax in typeArgumentLists) + { + if (typeArgumentListSyntax.Parent is not NameSyntax methodName) + continue; + if (context.SemanticModel.GetSymbolInfo(methodName).Symbol is not IMethodSymbol method) + continue; + if (!factoryMethods.Contains(method.OriginalDefinition, _comparer)) + continue; + + typeArgumentList = typeArgumentListSyntax; + break; + } + } + + if (typeArgumentList is null) + return null; + var typeSyntax = typeArgumentList.Arguments.FirstOrDefault(); + if (typeSyntax is null) + return null; + return ResolveTypeSyntaxToPotentialAppImplementableType(context, typeSyntax); + } + + private static INamedTypeSymbol? ResolveTypeSyntaxToPotentialAppImplementableType( + SyntaxNodeAnalysisContext context, + TypeSyntax typeSyntax + ) + { + var typeInfoSymbol = context.SemanticModel.GetTypeInfo(typeSyntax, context.CancellationToken).Type; + if (typeInfoSymbol is ITypeParameterSymbol typeParameterSymbol) + typeInfoSymbol = GetMostRelevantConstraintType(typeParameterSymbol); + + if (typeInfoSymbol is not INamedTypeSymbol symbol) + return null; + + return symbol; + } + + private static ITypeSymbol? GetMostRelevantConstraintType(ITypeParameterSymbol typeParameterSymbol) + { + // If the type is a type parameter (T), we just take the first constraint type + // that is marked by the attribute or fallback to the first constraint type + // if none is marked with the attribute. + if (typeParameterSymbol.ConstraintTypes.Length == 0) + return null; + + for (int i = 0; i < typeParameterSymbol.ConstraintTypes.Length; i++) + { + var constraintType = typeParameterSymbol.ConstraintTypes[i]; + if (IsMarkedWithAttribute(constraintType)) + return constraintType; + } + + return typeParameterSymbol.ConstraintTypes[0]; + } + + private static INamedTypeSymbol? GetIEnumerableSymbol(SyntaxNodeAnalysisContext context) + { + var enumerableType = context + .SemanticModel.Compilation.GetTypeByMetadataName("System.Collections.Generic.IEnumerable`1") + ?.ConstructUnboundGenericType(); + + return enumerableType; + } + + private static INamedTypeSymbol? GetIServiceProviderSymbol(SyntaxNodeAnalysisContext context) + { + var serviceProviderType = context.SemanticModel.Compilation.GetTypeByMetadataName("System.IServiceProvider"); + return serviceProviderType; + } + + private static INamedTypeSymbol? GetAppImplementationFactorySymbol(SyntaxNodeAnalysisContext context) + { + var appImplementationFactoryType = context.SemanticModel.Compilation.GetTypeByMetadataName( + "Altinn.App.Core.Features.AppImplementationFactory" + ); + return appImplementationFactoryType; + } }