diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala index c56d637a3294..7f127f38e035 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -1,6 +1,5 @@ package io.joern.csharpsrc2cpg.astcreation -import io.joern.csharpsrc2cpg.astcreation.AstParseLevel.FULL_AST import io.joern.csharpsrc2cpg.astcreation.BuiltinTypes.DotNetTypeMap import io.joern.csharpsrc2cpg.datastructures.{CSharpMethod, FieldDecl} import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.* @@ -768,54 +767,80 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { ) } - private def astForConditionalAccessExpression( - condAccExpr: DotNetNodeInfo, - baseType: Option[String] = None - ): Seq[Ast] = { - val baseNode = createDotNetNodeInfo(condAccExpr.json(ParserKeys.Expression)) - val baseTypeFullName = - baseType.orElse(Some(getTypeFullNameFromAstNode(astForNode(baseNode)))).filterNot(_.equals(Defines.Any)) + private def makeMemberAccess(expression: DotNetNodeInfo, name: DotNetNodeInfo): DotNetNodeInfo = { + val json = ujson.Obj() + json(ParserKeys.Expression) = expression.json.transform(ujson.Value) + json(ParserKeys.Name) = name.json.transform(ujson.Value) + json(ParserKeys.MetaData) = expression.json(ParserKeys.MetaData).transform(ujson.Value) + json(ParserKeys.MetaData)(ParserKeys.Kind) = "ast.SimpleMemberAccessExpression" - Try(createDotNetNodeInfo(condAccExpr.json(ParserKeys.WhenNotNull))).toOption match { - case Some(node) => - node.node match { - case ConditionalAccessExpression => astForConditionalAccessExpression(node, baseTypeFullName) - case MemberBindingExpression => astForMemberBindingExpression(node, baseTypeFullName) - case _ => astForNode(node) - } - case None => Seq.empty[Ast] + expression.copy(node = SimpleMemberAccessExpression, json = json) + } + + private def makeInvocation(expression: DotNetNodeInfo, args: DotNetNodeInfo): DotNetNodeInfo = { + val json = ujson.Obj() + json(ParserKeys.Expression) = expression.json.transform(ujson.Value) + json(ParserKeys.ArgumentList) = args.json.transform(ujson.Value) + json(ParserKeys.MetaData) = expression.json(ParserKeys.MetaData).transform(ujson.Value) + json(ParserKeys.MetaData)(ParserKeys.Kind) = "ast.InvocationExpression" + + expression.copy(node = InvocationExpression, json = json) + } + + /** Traverses the "spine" of a chained `?.`/`.` expression. For instance, `x?.y.z?.w` becomes [x, y, z, w]. Notice + * that, whereas `.` is left-associative, `?.` is right-associative. + */ + private def traverseConditionalAccessSpine(expr: DotNetNodeInfo): Seq[DotNetNodeInfo] = { + expr.node match { + case ConditionalAccessExpression => + val lhs = createDotNetNodeInfo(expr.json(ParserKeys.Expression)) + val rhs = createDotNetNodeInfo(expr.json(ParserKeys.WhenNotNull)) + lhs +: traverseConditionalAccessSpine(rhs) + case SimpleMemberAccessExpression => + val lhs = createDotNetNodeInfo(expr.json(ParserKeys.Expression)) + val rhs = createDotNetNodeInfo(expr.json(ParserKeys.Name)) + traverseConditionalAccessSpine(lhs) :+ rhs + case _ => + expr :: Nil } } - private def astForSuppressNullableWarningExpression(suppressNullableExpr: DotNetNodeInfo): Seq[Ast] = { - val _identifierNode = createDotNetNodeInfo(suppressNullableExpr.json(ParserKeys.Operand)) - Seq(astForIdentifier(_identifierNode)) + /** Given a sequence of nodes [x, y, z, w], creates the corresponding [[DotNetNodeInfo]] for `x.y.z.w`. + */ + private def rebuildSpineAsMemberAccesses(spine: Seq[DotNetNodeInfo]): Option[DotNetNodeInfo] = { + def combine(lhs: DotNetNodeInfo, rhs: DotNetNodeInfo): DotNetNodeInfo = rhs.node match { + case MemberBindingExpression => + val name = createDotNetNodeInfo(rhs.json(ParserKeys.Name)) + makeMemberAccess(lhs, name) + case InvocationExpression => + val name = createDotNetNodeInfo(rhs.json(ParserKeys.Expression)(ParserKeys.Name)) + val args = createDotNetNodeInfo(rhs.json(ParserKeys.ArgumentList)) + makeInvocation(makeMemberAccess(lhs, name), args) + case SimpleMemberAccessExpression => + val expr = createDotNetNodeInfo(rhs.json(ParserKeys.Expression)) + val name = createDotNetNodeInfo(rhs.json(ParserKeys.Name)) + makeMemberAccess(makeMemberAccess(lhs, expr), name) + case _ => + makeMemberAccess(lhs, rhs) + } + + spine.foldLeft(None: Option[DotNetNodeInfo]) { case (lhsOpt, rhs) => lhsOpt.map(combine(_, rhs)).orElse(Some(rhs)) } } - private def astForMemberBindingExpression( - memberBindingExpr: DotNetNodeInfo, - baseTypeFullName: Option[String] = None - ): Seq[Ast] = { - val typ = scope - .tryResolveFieldAccess(nameFromNode(memberBindingExpr), baseTypeFullName) - .map(_.typeName) - .map(f => scope.tryResolveTypeReference(f).map(_.name).orElse(Option(f))) - .getOrElse(Option(Defines.Any)) - - val fieldIdentifier = fieldIdentifierNode(memberBindingExpr, memberBindingExpr.code, memberBindingExpr.code) - - val identifier = newIdentifierNode(memberBindingExpr.code, baseTypeFullName.getOrElse(Defines.Any)) - val fieldAccess = - newOperatorCallNode( - Operators.fieldAccess, - memberBindingExpr.code, - typ, - memberBindingExpr.lineNumber, - memberBindingExpr.columnNumber - ) - val fieldIdentifierAst = Ast(fieldIdentifier) + /** Handles `x?.y` expressions, by rewriting ConditionalAccessExpressions into SimpleMemberAccessExpresions, i.e. + * handling them as if they were `x.y`. + */ + private def astForConditionalAccessExpression(condAccExpr: DotNetNodeInfo): Seq[Ast] = + rebuildSpineAsMemberAccesses(traverseConditionalAccessSpine(condAccExpr)) match { + case None => + logger.warn(s"Failed to rewrite ${code(condAccExpr)}. Skipping") + Nil + case Some(rewritten) => astForNode(rewritten) + } - Seq(callAst(fieldAccess, Seq(Ast(identifier)) ++ Seq(fieldIdentifierAst))) + private def astForSuppressNullableWarningExpression(suppressNullableExpr: DotNetNodeInfo): Seq[Ast] = { + val _identifierNode = createDotNetNodeInfo(suppressNullableExpr.json(ParserKeys.Operand)) + Seq(astForIdentifier(_identifierNode)) } protected def astForAttributeLists(attributeList: DotNetNodeInfo): Seq[Ast] = { diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ConditionalAccessTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ConditionalAccessTests.scala new file mode 100644 index 000000000000..e05b8381a2e6 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ConditionalAccessTests.scala @@ -0,0 +1,233 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier} +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} +import io.shiftleft.semanticcpg.language.* + +class ConditionalAccessTests extends CSharpCode2CpgFixture { + + "`this?.Bar` assigned to a variable" should { + val cpg = code(""" + |class Foo + |{ + | int Bar = 1; + | void DoStuff() + | { + | var x = this?.Bar; + | } + |} + | + |""".stripMargin) + + "be lowered as a field access `this.Bar`" in { + inside(cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("Bar")).l) { + case fieldAccess :: Nil => + fieldAccess.methodFullName shouldBe Operators.fieldAccess + fieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + fieldAccess.referencedMember.l shouldBe cpg.member("Bar").l + case _ => fail(s"Expected single field access to `Bar`") + } + } + + "assigned variable has correct properties" in { + inside(cpg.assignment.where(_.target.isIdentifier.nameExact("x")).l) { + case assign :: Nil => + assign.code shouldBe "x = this?.Bar" + assign.typeFullName shouldBe "System.Int32" + assign.target.start.isIdentifier.typeFullName.headOption shouldBe Some("System.Int32") + case _ => fail(s"Expected single assignment to `x`") + } + } + } + + "`this?.Bar?.Baz` assigned to a variable" should { + val cpg = code(""" + |class Foo + |{ + | Foo Bar; + | Foo Baz; + | void DoStuff() + | { + | var x = this?.Bar?.Baz; + | } + |} + | + |""".stripMargin) + + "be lowered as a field access `this.Bar.Baz`" in { + inside(cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("Baz")).l) { + case bazFieldAccess :: Nil => + bazFieldAccess.methodFullName shouldBe Operators.fieldAccess + bazFieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + bazFieldAccess.referencedMember.l shouldBe cpg.member("Baz").l + + inside(bazFieldAccess.start.argument(1).fieldAccess.l) { + case barFieldAccess :: Nil => + barFieldAccess.methodFullName shouldBe Operators.fieldAccess + barFieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + barFieldAccess.referencedMember.l shouldBe cpg.member("Bar").l + case _ => fail(s"Expected single field access to `Baz`") + } + case _ => fail(s"Expected single field access to `Bar`") + } + } + + "assigned variable has correct properties" in { + inside(cpg.assignment.where(_.target.isIdentifier.nameExact("x")).l) { + case assign :: Nil => + assign.code shouldBe "x = this?.Bar?.Baz" + assign.typeFullName shouldBe "Foo" + assign.target.start.isIdentifier.typeFullName.headOption shouldBe Some("Foo") + case _ => fail(s"Expected single assignment to `x`") + } + } + } + + "`this?.Bar()`" should { + val cpg = code(""" + |class Foo + |{ + | int Bar() { return 1; } + | void DoStuff() + | { + | this?.Bar(); + | } + |} + |""".stripMargin) + + "be lowered as a call to `Bar()` with receiver `this`" in { + inside(cpg.call.nameExact("Bar").l) { + case bar :: Nil => + bar.methodFullName shouldBe "Foo.Bar:System.Int32()" + bar.receiver.l shouldBe bar.argument.argumentIndex(0).l + bar.argument.argumentIndexGt(0) shouldBe empty + inside(bar.argument(0)) { + case thisArg: Identifier => + thisArg.code shouldBe "this" + thisArg.typeFullName shouldBe "Foo" + case xs => fail(s"Expected single identifier argument to Bar, but got $xs") + } + case xs => fail(s"Expected single call to Bar, but got $xs") + } + } + } + + "`this?.Bar()?.Baz()`" should { + val cpg = code(""" + |class Foo + |{ + | Foo Bar() { return null; } + | Foo Baz() { return null; } + | void DoStuff() + | { + | this?.Bar()?.Baz(); + | } + |} + |""".stripMargin) + + "have correct properties and arguments to `Baz()`" in { + inside(cpg.call.nameExact("Baz").l) { + case baz :: Nil => + baz.methodFullName shouldBe "Foo.Baz:Foo()" + baz.receiver.l shouldBe baz.argument.argumentIndex(0).l + baz.argument.argumentIndexGt(0) shouldBe empty + baz.argument(0).start.isCall.methodFullName.l shouldBe List("Foo.Bar:Foo()") + case xs => fail(s"Expected single call to Baz, but got $xs") + } + } + } + + "`this?.Bar.Baz` assigned to a variable" should { + val cpg = code(""" + |class Foo + |{ + | Foo Bar; + | Foo Baz; + | void DoStuff() + | { + | var x = this?.Bar.Baz; + | } + |} + |""".stripMargin) + + "be lowered as a field access `this.Bar.Baz`" in { + inside(cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("Baz")).l) { + case bazFieldAccess :: Nil => + bazFieldAccess.methodFullName shouldBe Operators.fieldAccess + bazFieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + bazFieldAccess.referencedMember.l shouldBe cpg.member("Baz").l + + inside(bazFieldAccess.start.argument(1).fieldAccess.l) { + case barFieldAccess :: Nil => + barFieldAccess.methodFullName shouldBe Operators.fieldAccess + barFieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + barFieldAccess.referencedMember.l shouldBe cpg.member("Bar").l + case _ => fail(s"Expected single field access to `Baz`") + } + case _ => fail(s"Expected single field access to `Bar`") + } + } + + "assigned variable has correct properties" in { + inside(cpg.assignment.where(_.target.isIdentifier.nameExact("x")).l) { + case assign :: Nil => + assign.code shouldBe "x = this?.Bar.Baz" + assign.typeFullName shouldBe "Foo" + assign.target.start.isIdentifier.typeFullName.headOption shouldBe Some("Foo") + case _ => fail(s"Expected single assignment to `x`") + } + } + } + + "`this?.Bar(0)?.Baz()?.Quux()`" should { + val cpg = code(""" + |class Foo + |{ + | Foo Bar(int x) { return null; } + | Foo Baz() { return null; } + | Foo Quux() { return null; } + | void DoStuff() + | { + | this?.Bar(0)?.Baz()?.Quux(); + | } + |} + | + |""".stripMargin) + + "have correct properties and arguments to `Quux()`" in { + inside(cpg.call.nameExact("Quux").l) { + case quux :: Nil => + quux.methodFullName shouldBe "Foo.Quux:Foo()" + quux.receiver.l shouldBe quux.argument.argumentIndex(0).l + quux.argument.argumentIndexGt(0) shouldBe empty + quux.argument(0) shouldBe cpg.call.nameExact("Baz").head + case _ => fail("Expected single call to `Quux`") + } + } + + "have correct properties and arguments to `Baz()`" in { + inside(cpg.call.nameExact("Baz").l) { + case baz :: Nil => + baz.methodFullName shouldBe "Foo.Baz:Foo()" + baz.receiver.l shouldBe baz.argument.argumentIndex(0).l + baz.argument.argumentIndexGt(0) shouldBe empty + baz.argument(0) shouldBe cpg.call.nameExact("Bar").head + case _ => fail("Expected single call to `Baz`") + } + } + + "have correct properties and arguments to `Bar()`" in { + inside(cpg.call.nameExact("Bar").l) { + case bar :: Nil => + bar.methodFullName shouldBe "Foo.Bar:Foo(System.Int32)" + bar.receiver.l shouldBe bar.argument.argumentIndex(0).l + bar.argument.argumentIndexGt(0).size shouldBe 1 + bar.argument(1) shouldBe cpg.literal("0").head + bar.argument(0) shouldBe cpg.identifier("this").head + case _ => fail("Expected single call to `Bar`") + } + } + } + +} diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala index 449465e0feef..fe28baa0c655 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/MemberAccessTests.scala @@ -7,11 +7,7 @@ import io.shiftleft.semanticcpg.language.* class MemberAccessTests extends CSharpCode2CpgFixture { - // TODO: This test-case relies on the usage of getters, that are currently being - // reworked to be METHODs instead of MEMBERs. In particular, `bar?.Qux` should - // resemble `bar.get_Qux()`. We need to adapt astForMemberBindingExpression - // to accommodate this. - "conditional property access expressions" ignore { + "conditional property access expressions" should { val cpg = code(""" |namespace Foo { | public class Baz { @@ -119,6 +115,7 @@ class MemberAccessTests extends CSharpCode2CpgFixture { } } + "conditional method access expressions" should { val cpg = code(""" |namespace Foo {