Skip to content

Commit

Permalink
[C#] Pattern Matching for IfExpr (#4415)
Browse files Browse the repository at this point in the history
* [c#] Pattern Matching Statements

* Initial WIP

* bump DotNetAstGen version

* [c#] Added lowering for pattern expressions in if statements. Handles declaration and constant patterns

* [c#] pr comments

* [c#] bump dotnetastgen version

* [c#] rollback version on dotnetastgen

---------

Co-authored-by: David Baker Effendi <[email protected]>
  • Loading branch information
AndreiDreyer and DavidBakerEffendi authored Apr 2, 2024
1 parent cdcf180 commit 93567c7
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
csharpsrc2cpg {
dotnetastgen_version: "0.28.0"
dotnetastgen_version: "0.29.0"
}
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ object AstCreatorHelper {
def nameFromNode(node: DotNetNodeInfo): String = {
node.node match
case NamespaceDeclaration | UsingDirective | FileScopedNamespaceDeclaration => nameFromNamespaceDeclaration(node)
case IdentifierName | Parameter | _: DeclarationExpr | GenericName =>
case IdentifierName | Parameter | _: DeclarationExpr | GenericName | SingleVariableDesignation =>
nameFromIdentifier(node)
case QualifiedName => nameFromQualifiedName(node)
case SimpleMemberAccessExpression | MemberBindingExpression | SuppressNullableWarningExpression | Attribute =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package io.joern.csharpsrc2cpg.astcreation

import io.joern.csharpsrc2cpg.CSharpOperators
import io.joern.csharpsrc2cpg.datastructures.CSharpMethod
import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.*
import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys}
import io.joern.x2cpg.utils.NodeBuilders.{newIdentifierNode, newOperatorCallNode}
import io.joern.csharpsrc2cpg.{CSharpOperators, Constants}
import io.joern.x2cpg.utils.NodeBuilders.{newCallNode, newIdentifierNode, newOperatorCallNode}
import io.joern.x2cpg.{Ast, Defines, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{NewFieldIdentifier, NewLiteral, NewTypeRef}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
import ujson.Value
import io.joern.csharpsrc2cpg.Constants

import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}
Expand Down Expand Up @@ -564,4 +563,67 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val _annotationNode = annotationNode(attribute, attribute.code, attributeName, fullName)
annotationAst(_annotationNode, argumentAsts)
}

/** Lowers an `is`-pattern expression (e.g. `maybe is int number` or `maybe is null`) into a condition and, when the
  * pattern declares a variable, an assignment that initialises that variable.
  *
  * @param isPatternExpression
  *   a pattern expression which may include a declaration.
  * @return
  *   the condition first, then (potentially) the declaration assignment. For a pattern kind that is not handled here a
  *   warning is logged and the result of `astForExpression` on the pattern itself is returned.
  */
protected def astsForIsPatternExpression(isPatternExpression: DotNetNodeInfo): List[Ast] = {
  val pattern = createDotNetNodeInfo(isPatternExpression.json(ParserKeys.Pattern))

  val expressionNode = createDotNetNodeInfo(isPatternExpression.json(ParserKeys.Expression))
  // NOTE(review): these ASTs are handed to both the `instanceOf` call and the assignment call below, so the same
  // nodes end up as arguments of two calls. Confirm that sharing them is intended.
  val expression = astForExpression(expressionNode)

  pattern.node match {
    case DeclarationPattern =>
      val designation = createDotNetNodeInfo(pattern.json(ParserKeys.Designation))
      val typeInfo    = createDotNetNodeInfo(pattern.json(ParserKeys.Type))
      // Resolved once and reused for the assignment type, the declared identifier and the type reference.
      val declaredTypeFullName = nodeTypeFullName(typeInfo)

      // Condition: `<expression> instanceOf <declared type>`
      val instanceOfCallNode = newOperatorCallNode(
        Operators.instanceOf,
        code(pattern),
        Option(BuiltinTypes.Bool),
        line(expressionNode),
        column(expressionNode)
      )

      // Declaration: `<type> <designation> = <expression>`
      val assignmentCallNode = newOperatorCallNode(
        Operators.assignment,
        s"${typeInfo.code} ${designation.code} = ${expressionNode.code}",
        Option(declaredTypeFullName),
        line(expressionNode),
        column(expressionNode)
      )

      val designationAst = astForIdentifier(designation, declaredTypeFullName)

      val typeNode = NewTypeRef()
        .code(declaredTypeFullName)
        .lineNumber(line(expressionNode))
        .columnNumber(column(expressionNode))
        .typeFullName(declaredTypeFullName)

      val conditionAst      = callAst(instanceOfCallNode, expression :+ Ast(typeNode))
      val assignmentCallAst = callAst(assignmentCallNode, designationAst +: expression)

      List(conditionAst, assignmentCallAst)
    case ConstantPattern =>
      // Condition only: `<expression> == <constant>`; a constant pattern declares nothing.
      val expr    = createDotNetNodeInfo(pattern.json(ParserKeys.Expression))
      val exprAst = astForExpression(expr)

      val equalCallNode =
        newOperatorCallNode(Operators.equals, code(pattern), Option(BuiltinTypes.Bool), line(expr), column(expr))
      val equalCallAst = callAst(equalCallNode, expression ++ exprAst)

      List(equalCallAst)
    case x =>
      logger.warn(s"Unsupported pattern in pattern expression, $x")
      astForExpression(pattern).toList
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import io.joern.csharpsrc2cpg.CSharpOperators
import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.*
import io.joern.csharpsrc2cpg.parser.{DotNetNodeInfo, ParserKeys}
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes}
import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewControlStructure, NewIdentifier}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes}

import scala.::
import scala.util.{Success, Try}
Expand All @@ -16,12 +16,38 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
astForStatement(createDotNetNodeInfo(statement))
}

/** The outcome of lowering a conditional expression: the condition itself, kept apart from any statements (such as
  * variables declared by the condition) that have to be placed ahead of the body.
  *
  * @param conditionAst
  *   the condition.
  * @param prependIfBody
  *   statements to prepend to the `if`/`then` body.
  */
final case class ConditionAstResult(
  conditionAst: Ast,
  prependIfBody: List[Ast]
)

// TODO: Use this method elsewhere on other control structures
/** Builds the AST for a condition node. Pattern expressions are lowered via `astsForIsPatternExpression`: the first
  * resulting AST becomes the condition and the remaining ones are returned as statements to prepend to the body. Any
  * other node uses the first AST produced by `astForNode` (or an empty AST if there is none) and prepends nothing.
  */
protected def astForConditionNode(condNode: DotNetNodeInfo): ConditionAstResult = {
  // Only evaluated on the paths that actually need the plain, non-pattern lowering.
  def fallback: ConditionAstResult = {
    val plainCondition = astForNode(condNode).headOption.getOrElse(Ast())
    ConditionAstResult(plainCondition, List.empty)
  }

  condNode.node match {
    case x: PatternExpr =>
      val lowered = astsForIsPatternExpression(condNode)
      lowered.headOption match {
        case Some(condition) =>
          ConditionAstResult(condition, lowered.tail)
        case None =>
          logger.warn(
            s"Unable to handle pattern expression $x in condition expression, resorting to default behaviour"
          )
          fallback
      }
    case _ => fallback
  }
}

private def astForIfStatement(ifStmt: DotNetNodeInfo): Seq[Ast] = {
val conditionNode = createDotNetNodeInfo(ifStmt.json(ParserKeys.Condition))
val conditionAst = astForNode(conditionNode).headOption.getOrElse(Ast())
val conditionNode = createDotNetNodeInfo(ifStmt.json(ParserKeys.Condition))
val ConditionAstResult(conditionAst, prependIfBody) = astForConditionNode(conditionNode)

val thenNode = createDotNetNodeInfo(ifStmt.json(ParserKeys.Statement))
val thenAst: Ast = astForBlock(thenNode)
val thenAst: Ast = astForBlock(thenNode, prefixAsts = prependIfBody)
val ifNode =
controlStructureNode(ifStmt, ControlStructureTypes.IF, s"if (${conditionNode.code})")
val elseAst = ifStmt.json(ParserKeys.Else) match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ object DotNetJsonAst {

object ParenthesizedLambdaExpression extends BaseLambdaExpression

/** Groups the node kinds that take part in pattern expressions. */
sealed trait PatternExpr extends BaseExpr

/** An `is`-pattern expression, e.g. `maybe is int number`. */
object IsPatternExpression extends PatternExpr

/** A pattern that declares a variable; the AST creator reads a `Type` and a `Designation` child from it. */
object DeclarationPattern extends PatternExpr

/** The variable introduced by a declaration pattern; its name is resolved the same way as an identifier's. */
object SingleVariableDesignation extends PatternExpr

// NOTE(review): presumably the general designation node kind; its exact role is not visible here, please confirm.
object Designation extends PatternExpr

sealed trait ClauseExpr extends BaseExpr

object EqualsValueClause extends ClauseExpr
Expand Down Expand Up @@ -264,64 +274,66 @@ object DotNetJsonAst {
*/
object ParserKeys {

val AstRoot = "AstRoot"
val Arguments = "Arguments"
val ArgumentList = "ArgumentList"
val AttributeLists = "AttributeLists"
val Attributes = "Attributes"
val BaseList = "BaseList"
val Body = "Body"
val Block = "Block"
val Catches = "Catches"
val Code = "Code"
val ColumnStart = "ColumnStart"
val ColumnEnd = "ColumnEnd"
val Condition = "Condition"
val Contents = "Contents"
val Declaration = "Declaration"
val Elements = "Elements"
val ElementType = "ElementType"
val Else = "Else"
val Expression = "Expression"
val ExpressionElement = "ExpressionElement"
val Expressions = "Expressions"
val ExpressionBody = "ExpressionBody"
val Finally = "Finally"
val FileName = "FileName"
val Identifier = "Identifier"
val Incrementors = "Incrementors"
val Initializer = "Initializer"
val Initializers = "Initializers"
val Keyword = "Keyword"
val Kind = "Kind"
val Labels = "Labels"
val Left = "Left"
val LineStart = "LineStart"
val LineEnd = "LineEnd"
val MetaData = "MetaData"
val Members = "Members"
val Modifiers = "Modifiers"
val Name = "Name"
val NameEquals = "NameEquals"
val Operand = "Operand"
val OperatorToken = "OperatorToken"
val Parameter = "Parameter"
val Parameters = "Parameters"
val ParameterList = "ParameterList"
val Pattern = "Pattern"
val Sections = "Sections"
val Statement = "Statement"
val Statements = "Statements"
val ReturnType = "ReturnType"
val Right = "Right"
val TextToken = "TextToken"
val Type = "Type"
val TypeArgumentList = "TypeArgumentList"
val Types = "Types"
val Usings = "Usings"
val Value = "Value"
val Variables = "Variables"
val WhenFalse = "WhenFalse"
val WhenNotNull = "WhenNotNull"
val WhenTrue = "WhenTrue"
val AstRoot = "AstRoot"
val Arguments = "Arguments"
val ArgumentList = "ArgumentList"
val AttributeLists = "AttributeLists"
val Attributes = "Attributes"
val BaseList = "BaseList"
val Body = "Body"
val Block = "Block"
val Catches = "Catches"
val Code = "Code"
val ColumnStart = "ColumnStart"
val ColumnEnd = "ColumnEnd"
val Condition = "Condition"
val Contents = "Contents"
val Declaration = "Declaration"
val Designation = "Designation"
val Elements = "Elements"
val ElementType = "ElementType"
val Else = "Else"
val Expression = "Expression"
val ExpressionElement = "ExpressionElement"
val Expressions = "Expressions"
val ExpressionBody = "ExpressionBody"
val Finally = "Finally"
val FileName = "FileName"
val Identifier = "Identifier"
val Incrementors = "Incrementors"
val Initializer = "Initializer"
val Initializers = "Initializers"
val Keyword = "Keyword"
val Kind = "Kind"
val Labels = "Labels"
val Left = "Left"
val LineStart = "LineStart"
val LineEnd = "LineEnd"
val MetaData = "MetaData"
val Members = "Members"
val Modifiers = "Modifiers"
val Name = "Name"
val NameEquals = "NameEquals"
val Operand = "Operand"
val OperatorToken = "OperatorToken"
val Parameter = "Parameter"
val Parameters = "Parameters"
val ParameterList = "ParameterList"
val Pattern = "Pattern"
val Sections = "Sections"
val SingleVariableDesignation = "SingleVariableDesignation"
val Statement = "Statement"
val Statements = "Statements"
val ReturnType = "ReturnType"
val Right = "Right"
val TextToken = "TextToken"
val Type = "Type"
val TypeArgumentList = "TypeArgumentList"
val Types = "Types"
val Usings = "Usings"
val Value = "Value"
val Variables = "Variables"
val WhenFalse = "WhenFalse"
val WhenNotNull = "WhenNotNull"
val WhenTrue = "WhenTrue"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package io.joern.csharpsrc2cpg.querying.ast

import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal, TypeRef}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, NodeTypes, Operators}
import io.shiftleft.semanticcpg.language.*

/** Checks how `is`-pattern expressions used as `if` conditions are lowered: a declaration pattern becomes an
  * `instanceOf` condition plus an assignment inside the control structure, and a constant pattern becomes an `equals`
  * condition.
  */
class PatternMatchingTests extends CSharpCode2CpgFixture {

  "Pattern matching to extract the non-null value in an if-statement" should {
    val cpg = code(basicBoilerplate("""
        |int? maybe = 12;
        |
        |if (maybe is int number)
        |{
        | Console.WriteLine($"The nullable int 'maybe' has the value {number}");
        |}
        |else
        |{
        | Console.WriteLine("The nullable int 'maybe' doesn't hold a value");
        |}
        |""".stripMargin))

    "lower an assignment from `maybe` to `number` as the first statement of the if-body" in {
      inside(cpg.assignment.where(_.target.isIdentifier.name("number")).headOption) {
        case Some(assignment) =>
          assignment.order shouldBe 1
          assignment.inAst.exists(_.label == NodeTypes.CONTROL_STRUCTURE) shouldBe true

          inside(assignment.argument.l) {
            case (number: Identifier) :: (maybe: Identifier) :: Nil =>
              number.name shouldBe "number"
              number.typeFullName shouldBe "System.Int32"

              maybe.name shouldBe "maybe"
              maybe.typeFullName shouldBe "System.Int32"
            case xs => fail(s"Expected two identifier arguments, instead got [${xs.code.mkString(",")}]")
          }

        case None => fail("Expected an assignment `number = maybe`")
      }
    }

    "have an instanceOf-style check as the if-condition" in {
      inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.headOption) {
        case Some(condition: Call) =>
          condition.name shouldBe Operators.instanceOf
          inside(condition.argument.l) {
            case (maybe: Identifier) :: (intType: TypeRef) :: Nil =>
              maybe.name shouldBe "maybe"
              maybe.typeFullName shouldBe "System.Int32"

              intType.typeFullName shouldBe "System.Int32"
            case xs =>
              fail(
                s"Expected an identifier and type ref argument to `instanceOf`, instead got [${xs.code.mkString(",")}]"
              )
          }
        case _ => fail("Expected an if-statement with a condition call")

      }
    }
  }

  "Pattern matching with null type check" should {
    // The body must not mention `number`: a constant pattern declares no variable, so the name would be undefined.
    val cpg = code(basicBoilerplate("""
        |int? maybe = 12;
        |
        |if (maybe is null)
        |{
        | Console.WriteLine("The nullable int 'maybe' doesn't hold a value");
        |}
        |""".stripMargin))

    "have equals check in if statement" in {
      inside(cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).condition.headOption) {
        case Some(condition: Call) =>
          condition.name shouldBe Operators.equals

          inside(condition.argument.l) {
            case (maybe: Identifier) :: (nullLiteral: Literal) :: Nil =>
              maybe.name shouldBe "maybe"
              maybe.typeFullName shouldBe "System.Int32"

              nullLiteral.typeFullName shouldBe "null"
            case xs => fail(s"Expect identifier and literal, instead got [${xs.code.mkString(", ")}]")
          }
        case _ => fail("Expected an if-statement with condition call")
      }
    }
  }

}

0 comments on commit 93567c7

Please sign in to comment.