From 93567c7771f30152098ad31a7aeceda448b66ae7 Mon Sep 17 00:00:00 2001 From: Andrei Dreyer Date: Tue, 2 Apr 2024 15:43:44 +0200 Subject: [PATCH] [C#] Pattern Matching for `IfExpr` (#4415) * [c#] Pattern Matching Statements * Initial WIP * bump DotNetAstGen version * [c#] Added lowering for pattern expressions in if statements. Handles declaration and constant patterns * [c#] pr comments * [c#] bump dotnetastgen version * [c#] rollback version on dotnetastgen --------- Co-authored-by: David Baker Effendi --- .../src/main/resources/application.conf | 2 +- .../astcreation/AstCreatorHelper.scala | 2 +- .../AstForExpressionsCreator.scala | 68 ++++++++- .../astcreation/AstForStatementsCreator.scala | 34 ++++- .../csharpsrc2cpg/parser/DotNetJsonAst.scala | 132 ++++++++++-------- .../querying/ast/PatternMatchingTests.scala | 93 ++++++++++++ 6 files changed, 262 insertions(+), 69 deletions(-) create mode 100644 joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/PatternMatchingTests.scala diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/resources/application.conf b/joern-cli/frontends/csharpsrc2cpg/src/main/resources/application.conf index 7e3c432d0787..c0611d3e9557 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/resources/application.conf +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/resources/application.conf @@ -1,3 +1,3 @@ csharpsrc2cpg { - dotnetastgen_version: "0.28.0" + dotnetastgen_version: "0.29.0" } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala index 1c6d17114b9a..0048e2297a9c 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstCreatorHelper.scala @@ -190,7 +190,7 @@ object AstCreatorHelper { def nameFromNode(node: DotNetNodeInfo): String = { node.node match case NamespaceDeclaration | UsingDirective | FileScopedNamespaceDeclaration => nameFromNamespaceDeclaration(node) - case IdentifierName | Parameter | _: DeclarationExpr | GenericName => + case IdentifierName | Parameter | _: DeclarationExpr | GenericName | SingleVariableDesignation => nameFromIdentifier(node) case QualifiedName => nameFromQualifiedName(node) case SimpleMemberAccessExpression | MemberBindingExpression | SuppressNullableWarningExpression | Attribute => diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala index 8380e8c2b32e..2f0f468ae7e3 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -1,15 +1,14 @@ package io.joern.csharpsrc2cpg.astcreation -import io.joern.csharpsrc2cpg.CSharpOperators import io.joern.csharpsrc2cpg.datastructures.CSharpMethod import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.* import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys} -import io.joern.x2cpg.utils.NodeBuilders.{newIdentifierNode, newOperatorCallNode} +import io.joern.csharpsrc2cpg.{CSharpOperators, Constants} +import io.joern.x2cpg.utils.NodeBuilders.{newCallNode, newIdentifierNode, newOperatorCallNode} import io.joern.x2cpg.{Ast, Defines, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.{NewFieldIdentifier, NewLiteral, NewTypeRef} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import ujson.Value -import io.joern.csharpsrc2cpg.Constants import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Success, Try} @@ -564,4 +563,67 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val _annotationNode = annotationNode(attribute, attribute.code, attributeName, fullName) annotationAst(_annotationNode, argumentAsts) } + + /** Lowers a pattern expression into a condition and then a declaration if one occurs. + * @param isPatternExpression + * a pattern expression which may include a declaration. + * @return + * a condition and then (potentially) declaration. + */ + protected def astsForIsPatternExpression(isPatternExpression: DotNetNodeInfo): List[Ast] = { + val pattern = createDotNetNodeInfo(isPatternExpression.json(ParserKeys.Pattern)) + + val expressionNode = createDotNetNodeInfo(isPatternExpression.json(ParserKeys.Expression)) + val expression = astForExpression(expressionNode) + + pattern.node match { + case DeclarationPattern => + val designation = createDotNetNodeInfo(pattern.json(ParserKeys.Designation)) + val typeInfo = createDotNetNodeInfo(pattern.json(ParserKeys.Type)) + + val instanceOfCallNode = newOperatorCallNode( + Operators.instanceOf, + code(pattern), + Option(BuiltinTypes.Bool), + line(expressionNode), + column(expressionNode) + ) + + val assignmentAst = newOperatorCallNode( + Operators.assignment, + s"${typeInfo.code} ${designation.code} = ${expressionNode.code}", + Option(nodeTypeFullName(typeInfo)), + line(expressionNode), + column(expressionNode) + ) + + val designationAst = astForIdentifier(designation, nodeTypeFullName(typeInfo)) + + val typeNode = NewTypeRef() + .code(nodeTypeFullName(typeInfo)) + .lineNumber(line(expressionNode)) + .columnNumber(column(expressionNode)) + .typeFullName(nodeTypeFullName(typeInfo)) + + val conditionAst = callAst(instanceOfCallNode, expression :+ Ast(typeNode)) + val assignmentCallAst = callAst(assignmentAst, designationAst +: expression) + + List(conditionAst, assignmentCallAst) + case ConstantPattern => + val expr = createDotNetNodeInfo(pattern.json(ParserKeys.Expression)) + val exprAst = astForExpression(expr) + + val typeFullName = nodeTypeFullName(expr) + + val equalCallNode = + newOperatorCallNode(Operators.equals, code(pattern), Option(BuiltinTypes.Bool), line(expr), column(expr)) + val equalCallAst = callAst(equalCallNode, expression ++ exprAst) + + List(equalCallAst) + case x => + logger.warn(s"Unsupported pattern in pattern expression, $x") + astForExpression(pattern).toList + } + } + } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForStatementsCreator.scala index 66754487b375..9f295ad44fef 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForStatementsCreator.scala @@ -4,8 +4,8 @@ import io.joern.csharpsrc2cpg.CSharpOperators import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.* import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys} import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes} import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewControlStructure, NewIdentifier} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes} import scala.:: import scala.util.{Success, Try} @@ -16,12 +16,38 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t astForStatement(createDotNetNodeInfo(statement)) } + /** Separates the `AST` result of a conditional expression into the condition as well as any declared variables to + * prepend. + * @param conditionAst + * the condition. + * @param prependIfBody + * statements to prepend to the `if`/`then` body. + */ + final case class ConditionAstResult(conditionAst: Ast, prependIfBody: List[Ast]) + + // TODO: Use this method elsewhere on other control structures + protected def astForConditionNode(condNode: DotNetNodeInfo): ConditionAstResult = { + lazy val default = ConditionAstResult(astForNode(condNode).headOption.getOrElse(Ast()), List.empty) + condNode.node match { + case x: PatternExpr => + astsForIsPatternExpression(condNode) match { + case head :: tail => ConditionAstResult(head, tail) + case Nil => + logger.warn( + s"Unable to handle pattern expression $x in condition expression, resorting to default behaviour" + ) + default + } + case _ => default + } + } + private def astForIfStatement(ifStmt: DotNetNodeInfo): Seq[Ast] = { - val conditionNode = createDotNetNodeInfo(ifStmt.json(ParserKeys.Condition)) - val conditionAst = astForNode(conditionNode).headOption.getOrElse(Ast()) + val conditionNode = createDotNetNodeInfo(ifStmt.json(ParserKeys.Condition)) + val ConditionAstResult(conditionAst, prependIfBody) = astForConditionNode(conditionNode) val thenNode = createDotNetNodeInfo(ifStmt.json(ParserKeys.Statement)) - val thenAst: Ast = astForBlock(thenNode) + val thenAst: Ast = astForBlock(thenNode, prefixAsts = prependIfBody) val ifNode = controlStructureNode(ifStmt, ControlStructureTypes.IF, s"if (${conditionNode.code})") val elseAst = ifStmt.json(ParserKeys.Else) match diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala index 8e76a3c86782..853465531275 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/parser/DotNetJsonAst.scala @@ -81,6 +81,16 @@ object DotNetJsonAst { object ParenthesizedLambdaExpression extends BaseLambdaExpression + sealed trait PatternExpr extends BaseExpr + + object IsPatternExpression extends PatternExpr + + object DeclarationPattern extends PatternExpr + + object SingleVariableDesignation extends PatternExpr + + object Designation extends PatternExpr + sealed trait ClauseExpr extends BaseExpr object EqualsValueClause extends ClauseExpr @@ -264,64 +274,66 @@ object DotNetJsonAst { */ object ParserKeys { - val AstRoot = "AstRoot" - val Arguments = "Arguments" - val ArgumentList = "ArgumentList" - val AttributeLists = "AttributeLists" - val Attributes = "Attributes" - val BaseList = "BaseList" - val Body = "Body" - val Block = "Block" - val Catches = "Catches" - val Code = "Code" - val ColumnStart = "ColumnStart" - val ColumnEnd = "ColumnEnd" - val Condition = "Condition" - val Contents = "Contents" - val Declaration = "Declaration" - val Elements = "Elements" - val ElementType = "ElementType" - val Else = "Else" - val Expression = "Expression" - val ExpressionElement = "ExpressionElement" - val Expressions = "Expressions" - val ExpressionBody = "ExpressionBody" - val Finally = "Finally" - val FileName = "FileName" - val Identifier = "Identifier" - val Incrementors = "Incrementors" - val Initializer = "Initializer" - val Initializers = "Initializers" - val Keyword = "Keyword" - val Kind = "Kind" - val Labels = "Labels" - val Left = "Left" - val LineStart = "LineStart" - val LineEnd = "LineEnd" - val MetaData = "MetaData" - val Members = "Members" - val Modifiers = "Modifiers" - val Name = "Name" - val NameEquals = "NameEquals" - val Operand = "Operand" - val OperatorToken = "OperatorToken" - val Parameter = "Parameter" - val Parameters = "Parameters" - val ParameterList = "ParameterList" - val Pattern = "Pattern" - val Sections = "Sections" - val Statement = "Statement" - val Statements = "Statements" - val ReturnType = "ReturnType" - val Right = "Right" - val TextToken = "TextToken" - val Type = "Type" - val TypeArgumentList = "TypeArgumentList" - val Types = "Types" - val Usings = "Usings" - val Value = "Value" - val Variables = "Variables" - val WhenFalse = "WhenFalse" - val WhenNotNull = "WhenNotNull" - val WhenTrue = "WhenTrue" + val AstRoot = "AstRoot" + val Arguments = "Arguments" + val ArgumentList = "ArgumentList" + val AttributeLists = "AttributeLists" + val Attributes = "Attributes" + val BaseList = "BaseList" + val Body = "Body" + val Block = "Block" + val Catches = "Catches" + val Code = "Code" + val ColumnStart = "ColumnStart" + val ColumnEnd = "ColumnEnd" + val Condition = "Condition" + val Contents = "Contents" + val Declaration = "Declaration" + val Designation = "Designation" + val Elements = "Elements" + val ElementType = "ElementType" + val Else = "Else" + val Expression = "Expression" + val ExpressionElement = "ExpressionElement" + val Expressions = "Expressions" + val ExpressionBody = "ExpressionBody" + val Finally = "Finally" + val FileName = "FileName" + val Identifier = "Identifier" + val Incrementors = "Incrementors" + val Initializer = "Initializer" + val Initializers = "Initializers" + val Keyword = "Keyword" + val Kind = "Kind" + val Labels = "Labels" + val Left = "Left" + val LineStart = "LineStart" + val LineEnd = "LineEnd" + val MetaData = "MetaData" + val Members = "Members" + val Modifiers = "Modifiers" + val Name = "Name" + val NameEquals = "NameEquals" + val Operand = "Operand" + val OperatorToken = "OperatorToken" + val Parameter = "Parameter" + val Parameters = "Parameters" + val ParameterList = "ParameterList" + val Pattern = "Pattern" + val Sections = "Sections" + val SingleVariableDesignation = "SingleVariableDesignation" + val Statement = "Statement" + val Statements = "Statements" + val ReturnType = "ReturnType" + val Right = "Right" + val TextToken = "TextToken" + val Type = "Type" + val TypeArgumentList = "TypeArgumentList" + val Types = "Types" + val Usings = "Usings" + val Value = "Value" + val Variables = "Variables" + val WhenFalse = "WhenFalse" + val WhenNotNull = "WhenNotNull" + val WhenTrue = "WhenTrue" } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/PatternMatchingTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/PatternMatchingTests.scala new file mode 100644 index 000000000000..d0321d771061 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/PatternMatchingTests.scala @@ -0,0 +1,93 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal, TypeRef} +import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, NodeTypes, Operators} +import io.shiftleft.semanticcpg.language.* + +class PatternMatchingTests extends CSharpCode2CpgFixture { + + "Pattern matching to extract the non-null value in an if-statement" should { + val cpg = code(basicBoilerplate(""" + |int? maybe = 12; + | + |if (maybe is int number) + |{ + | Console.WriteLine($"The nullable int 'maybe' has the value {number}"); + |} + |else + |{ + | Console.WriteLine("The nullable int 'maybe' doesn't hold a value"); + |} + |""".stripMargin)) + + "lower an assignment from `maybe` to `number` as the first statement of the if-body" in { + inside(cpg.assignment.where(_.target.isIdentifier.name("number")).headOption) { + case Some(assignment) => + assignment.order shouldBe 1 + assignment.inAst.exists(_.label == NodeTypes.CONTROL_STRUCTURE) shouldBe true + + inside(assignment.argument.l) { + case (number: Identifier) :: (maybe: Identifier) :: Nil => + number.name shouldBe "number" + number.typeFullName shouldBe "System.Int32" + + maybe.name shouldBe "maybe" + maybe.typeFullName shouldBe "System.Int32" + case xs => fail(s"Expected two identifier arguments, instead got [${xs.code.mkString(",")}]") + } + + case None => fail("Expected an assignment `number = maybe`") + } + } + + "have an instanceOf-style check as the if-condition" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.headOption) { + case Some(condition: Call) => + condition.name shouldBe Operators.instanceOf + inside(condition.argument.l) { + case (maybe: Identifier) :: (intType: TypeRef) :: Nil => + maybe.name shouldBe "maybe" + maybe.typeFullName shouldBe "System.Int32" + + intType.typeFullName shouldBe "System.Int32" + case xs => + fail( + s"Expected an identifier and type ref argument to `instanceOf`, instead got [${xs.code.mkString(",")}]" + ) + } + case _ => fail("Expected an if-statement with a condition call") + + } + } + } + + "Pattern matching with null type check" should { + val cpg = code(basicBoilerplate(""" + |int? maybe = 12; + | + |if (maybe is null) + |{ + | Console.WriteLine($"The nullable int 'maybe' has the value {number}"); + |} + |""".stripMargin)) + + "have equals check in if statement" in { + inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.headOption) { + case Some(condition: Call) => + condition.name shouldBe Operators.equals + + inside(condition.argument.l) { + case (maybe: Identifier) :: (nullType: Literal) :: Nil => + maybe.name shouldBe "maybe" + maybe.typeFullName shouldBe "System.Int32" + + nullType.typeFullName shouldBe "null" + case xs => fail(s"Expect identifier and literal, instead got [${xs.code.mkString(", ")}]") + } + case _ => fail("Expected an if-statement with condition call") + } + } + } + +}