Skip to content

Commit

Permalink
[c#] support chained ?. expressions (#5310)
Browse files Browse the repository at this point in the history
  • Loading branch information
xavierpinho authored Feb 14, 2025
1 parent ccbd2de commit b5fe810
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.joern.csharpsrc2cpg.astcreation

import io.joern.csharpsrc2cpg.astcreation.AstParseLevel.FULL_AST
import io.joern.csharpsrc2cpg.astcreation.BuiltinTypes.DotNetTypeMap
import io.joern.csharpsrc2cpg.datastructures.{CSharpMethod, FieldDecl}
import io.joern.csharpsrc2cpg.parser.DotNetJsonAst.*
Expand Down Expand Up @@ -768,54 +767,80 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
)
}

private def astForConditionalAccessExpression(
condAccExpr: DotNetNodeInfo,
baseType: Option[String] = None
): Seq[Ast] = {
val baseNode = createDotNetNodeInfo(condAccExpr.json(ParserKeys.Expression))
val baseTypeFullName =
baseType.orElse(Some(getTypeFullNameFromAstNode(astForNode(baseNode)))).filterNot(_.equals(Defines.Any))
private def makeMemberAccess(expression: DotNetNodeInfo, name: DotNetNodeInfo): DotNetNodeInfo = {
val json = ujson.Obj()
json(ParserKeys.Expression) = expression.json.transform(ujson.Value)
json(ParserKeys.Name) = name.json.transform(ujson.Value)
json(ParserKeys.MetaData) = expression.json(ParserKeys.MetaData).transform(ujson.Value)
json(ParserKeys.MetaData)(ParserKeys.Kind) = "ast.SimpleMemberAccessExpression"

Try(createDotNetNodeInfo(condAccExpr.json(ParserKeys.WhenNotNull))).toOption match {
case Some(node) =>
node.node match {
case ConditionalAccessExpression => astForConditionalAccessExpression(node, baseTypeFullName)
case MemberBindingExpression => astForMemberBindingExpression(node, baseTypeFullName)
case _ => astForNode(node)
}
case None => Seq.empty[Ast]
expression.copy(node = SimpleMemberAccessExpression, json = json)
}

private def makeInvocation(expression: DotNetNodeInfo, args: DotNetNodeInfo): DotNetNodeInfo = {
val json = ujson.Obj()
json(ParserKeys.Expression) = expression.json.transform(ujson.Value)
json(ParserKeys.ArgumentList) = args.json.transform(ujson.Value)
json(ParserKeys.MetaData) = expression.json(ParserKeys.MetaData).transform(ujson.Value)
json(ParserKeys.MetaData)(ParserKeys.Kind) = "ast.InvocationExpression"

expression.copy(node = InvocationExpression, json = json)
}

/** Traverses the "spine" of a chained `?.`/`.` expression. For instance, `x?.y.z?.w` becomes [x, y, z, w]. Notice
* that, whereas `.` is left-associative, `?.` is right-associative.
*/
private def traverseConditionalAccessSpine(expr: DotNetNodeInfo): Seq[DotNetNodeInfo] = {
expr.node match {
case ConditionalAccessExpression =>
val lhs = createDotNetNodeInfo(expr.json(ParserKeys.Expression))
val rhs = createDotNetNodeInfo(expr.json(ParserKeys.WhenNotNull))
lhs +: traverseConditionalAccessSpine(rhs)
case SimpleMemberAccessExpression =>
val lhs = createDotNetNodeInfo(expr.json(ParserKeys.Expression))
val rhs = createDotNetNodeInfo(expr.json(ParserKeys.Name))
traverseConditionalAccessSpine(lhs) :+ rhs
case _ =>
expr :: Nil
}
}

private def astForSuppressNullableWarningExpression(suppressNullableExpr: DotNetNodeInfo): Seq[Ast] = {
val _identifierNode = createDotNetNodeInfo(suppressNullableExpr.json(ParserKeys.Operand))
Seq(astForIdentifier(_identifierNode))
/** Given a sequence of nodes [x, y, z, w], creates the corresponding [[DotNetNodeInfo]] for `x.y.z.w`.
*/
private def rebuildSpineAsMemberAccesses(spine: Seq[DotNetNodeInfo]): Option[DotNetNodeInfo] = {
def combine(lhs: DotNetNodeInfo, rhs: DotNetNodeInfo): DotNetNodeInfo = rhs.node match {
case MemberBindingExpression =>
val name = createDotNetNodeInfo(rhs.json(ParserKeys.Name))
makeMemberAccess(lhs, name)
case InvocationExpression =>
val name = createDotNetNodeInfo(rhs.json(ParserKeys.Expression)(ParserKeys.Name))
val args = createDotNetNodeInfo(rhs.json(ParserKeys.ArgumentList))
makeInvocation(makeMemberAccess(lhs, name), args)
case SimpleMemberAccessExpression =>
val expr = createDotNetNodeInfo(rhs.json(ParserKeys.Expression))
val name = createDotNetNodeInfo(rhs.json(ParserKeys.Name))
makeMemberAccess(makeMemberAccess(lhs, expr), name)
case _ =>
makeMemberAccess(lhs, rhs)
}

spine.foldLeft(None: Option[DotNetNodeInfo]) { case (lhsOpt, rhs) => lhsOpt.map(combine(_, rhs)).orElse(Some(rhs)) }
}

private def astForMemberBindingExpression(
memberBindingExpr: DotNetNodeInfo,
baseTypeFullName: Option[String] = None
): Seq[Ast] = {
val typ = scope
.tryResolveFieldAccess(nameFromNode(memberBindingExpr), baseTypeFullName)
.map(_.typeName)
.map(f => scope.tryResolveTypeReference(f).map(_.name).orElse(Option(f)))
.getOrElse(Option(Defines.Any))

val fieldIdentifier = fieldIdentifierNode(memberBindingExpr, memberBindingExpr.code, memberBindingExpr.code)

val identifier = newIdentifierNode(memberBindingExpr.code, baseTypeFullName.getOrElse(Defines.Any))
val fieldAccess =
newOperatorCallNode(
Operators.fieldAccess,
memberBindingExpr.code,
typ,
memberBindingExpr.lineNumber,
memberBindingExpr.columnNumber
)
val fieldIdentifierAst = Ast(fieldIdentifier)
/** Handles `x?.y` expressions, by rewriting ConditionalAccessExpressions into SimpleMemberAccessExpresions, i.e.
* handling them as if they were `x.y`.
*/
private def astForConditionalAccessExpression(condAccExpr: DotNetNodeInfo): Seq[Ast] =
rebuildSpineAsMemberAccesses(traverseConditionalAccessSpine(condAccExpr)) match {
case None =>
logger.warn(s"Failed to rewrite ${code(condAccExpr)}. Skipping")
Nil
case Some(rewritten) => astForNode(rewritten)
}

Seq(callAst(fieldAccess, Seq(Ast(identifier)) ++ Seq(fieldIdentifierAst)))
private def astForSuppressNullableWarningExpression(suppressNullableExpr: DotNetNodeInfo): Seq[Ast] = {
val _identifierNode = createDotNetNodeInfo(suppressNullableExpr.json(ParserKeys.Operand))
Seq(astForIdentifier(_identifierNode))
}

protected def astForAttributeLists(attributeList: DotNetNodeInfo): Seq[Ast] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
package io.joern.csharpsrc2cpg.querying.ast

import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier}
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators}
import io.shiftleft.semanticcpg.language.*

class ConditionalAccessTests extends CSharpCode2CpgFixture {

"`this?.Bar` assigned to a variable" should {
val cpg = code("""
|class Foo
|{
| int Bar = 1;
| void DoStuff()
| {
| var x = this?.Bar;
| }
|}
|
|""".stripMargin)

"be lowered as a field access `this.Bar`" in {
inside(cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("Bar")).l) {
case fieldAccess :: Nil =>
fieldAccess.methodFullName shouldBe Operators.fieldAccess
fieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH
fieldAccess.referencedMember.l shouldBe cpg.member("Bar").l
case _ => fail(s"Expected single field access to `Bar`")
}
}

"assigned variable has correct properties" in {
inside(cpg.assignment.where(_.target.isIdentifier.nameExact("x")).l) {
case assign :: Nil =>
assign.code shouldBe "x = this?.Bar"
assign.typeFullName shouldBe "System.Int32"
assign.target.start.isIdentifier.typeFullName.headOption shouldBe Some("System.Int32")
case _ => fail(s"Expected single assignment to `x`")
}
}
}

"`this?.Bar?.Baz` assigned to a variable" should {
val cpg = code("""
|class Foo
|{
| Foo Bar;
| Foo Baz;
| void DoStuff()
| {
| var x = this?.Bar?.Baz;
| }
|}
|
|""".stripMargin)

"be lowered as a field access `this.Bar.Baz`" in {
inside(cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("Baz")).l) {
case bazFieldAccess :: Nil =>
bazFieldAccess.methodFullName shouldBe Operators.fieldAccess
bazFieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH
bazFieldAccess.referencedMember.l shouldBe cpg.member("Baz").l

inside(bazFieldAccess.start.argument(1).fieldAccess.l) {
case barFieldAccess :: Nil =>
barFieldAccess.methodFullName shouldBe Operators.fieldAccess
barFieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH
barFieldAccess.referencedMember.l shouldBe cpg.member("Bar").l
case _ => fail(s"Expected single field access to `Baz`")
}
case _ => fail(s"Expected single field access to `Bar`")
}
}

"assigned variable has correct properties" in {
inside(cpg.assignment.where(_.target.isIdentifier.nameExact("x")).l) {
case assign :: Nil =>
assign.code shouldBe "x = this?.Bar?.Baz"
assign.typeFullName shouldBe "Foo"
assign.target.start.isIdentifier.typeFullName.headOption shouldBe Some("Foo")
case _ => fail(s"Expected single assignment to `x`")
}
}
}

"`this?.Bar()`" should {
val cpg = code("""
|class Foo
|{
| int Bar() { return 1; }
| void DoStuff()
| {
| this?.Bar();
| }
|}
|""".stripMargin)

"be lowered as a call to `Bar()` with receiver `this`" in {
inside(cpg.call.nameExact("Bar").l) {
case bar :: Nil =>
bar.methodFullName shouldBe "Foo.Bar:System.Int32()"
bar.receiver.l shouldBe bar.argument.argumentIndex(0).l
bar.argument.argumentIndexGt(0) shouldBe empty
inside(bar.argument(0)) {
case thisArg: Identifier =>
thisArg.code shouldBe "this"
thisArg.typeFullName shouldBe "Foo"
case xs => fail(s"Expected single identifier argument to Bar, but got $xs")
}
case xs => fail(s"Expected single call to Bar, but got $xs")
}
}
}

"`this?.Bar()?.Baz()`" should {
val cpg = code("""
|class Foo
|{
| Foo Bar() { return null; }
| Foo Baz() { return null; }
| void DoStuff()
| {
| this?.Bar()?.Baz();
| }
|}
|""".stripMargin)

"have correct properties and arguments to `Baz()`" in {
inside(cpg.call.nameExact("Baz").l) {
case baz :: Nil =>
baz.methodFullName shouldBe "Foo.Baz:Foo()"
baz.receiver.l shouldBe baz.argument.argumentIndex(0).l
baz.argument.argumentIndexGt(0) shouldBe empty
baz.argument(0).start.isCall.methodFullName.l shouldBe List("Foo.Bar:Foo()")
case xs => fail(s"Expected single call to Baz, but got $xs")
}
}
}

"`this?.Bar.Baz` assigned to a variable" should {
val cpg = code("""
|class Foo
|{
| Foo Bar;
| Foo Baz;
| void DoStuff()
| {
| var x = this?.Bar.Baz;
| }
|}
|""".stripMargin)

"be lowered as a field access `this.Bar.Baz`" in {
inside(cpg.fieldAccess.where(_.fieldIdentifier.canonicalNameExact("Baz")).l) {
case bazFieldAccess :: Nil =>
bazFieldAccess.methodFullName shouldBe Operators.fieldAccess
bazFieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH
bazFieldAccess.referencedMember.l shouldBe cpg.member("Baz").l

inside(bazFieldAccess.start.argument(1).fieldAccess.l) {
case barFieldAccess :: Nil =>
barFieldAccess.methodFullName shouldBe Operators.fieldAccess
barFieldAccess.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH
barFieldAccess.referencedMember.l shouldBe cpg.member("Bar").l
case _ => fail(s"Expected single field access to `Baz`")
}
case _ => fail(s"Expected single field access to `Bar`")
}
}

"assigned variable has correct properties" in {
inside(cpg.assignment.where(_.target.isIdentifier.nameExact("x")).l) {
case assign :: Nil =>
assign.code shouldBe "x = this?.Bar.Baz"
assign.typeFullName shouldBe "Foo"
assign.target.start.isIdentifier.typeFullName.headOption shouldBe Some("Foo")
case _ => fail(s"Expected single assignment to `x`")
}
}
}

"`this?.Bar(0)?.Baz()?.Quux()`" should {
val cpg = code("""
|class Foo
|{
| Foo Bar(int x) { return null; }
| Foo Baz() { return null; }
| Foo Quux() { return null; }
| void DoStuff()
| {
| this?.Bar(0)?.Baz()?.Quux();
| }
|}
|
|""".stripMargin)

"have correct properties and arguments to `Quux()`" in {
inside(cpg.call.nameExact("Quux").l) {
case quux :: Nil =>
quux.methodFullName shouldBe "Foo.Quux:Foo()"
quux.receiver.l shouldBe quux.argument.argumentIndex(0).l
quux.argument.argumentIndexGt(0) shouldBe empty
quux.argument(0) shouldBe cpg.call.nameExact("Baz").head
case _ => fail("Expected single call to `Quux`")
}
}

"have correct properties and arguments to `Baz()`" in {
inside(cpg.call.nameExact("Baz").l) {
case baz :: Nil =>
baz.methodFullName shouldBe "Foo.Baz:Foo()"
baz.receiver.l shouldBe baz.argument.argumentIndex(0).l
baz.argument.argumentIndexGt(0) shouldBe empty
baz.argument(0) shouldBe cpg.call.nameExact("Bar").head
case _ => fail("Expected single call to `Baz`")
}
}

"have correct properties and arguments to `Bar()`" in {
inside(cpg.call.nameExact("Bar").l) {
case bar :: Nil =>
bar.methodFullName shouldBe "Foo.Bar:Foo(System.Int32)"
bar.receiver.l shouldBe bar.argument.argumentIndex(0).l
bar.argument.argumentIndexGt(0).size shouldBe 1
bar.argument(1) shouldBe cpg.literal("0").head
bar.argument(0) shouldBe cpg.identifier("this").head
case _ => fail("Expected single call to `Bar`")
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@ import io.shiftleft.semanticcpg.language.*

class MemberAccessTests extends CSharpCode2CpgFixture {

// TODO: This test-case relies on the usage of getters, that are currently being
// reworked to be METHODs instead of MEMBERs. In particular, `bar?.Qux` should
// resemble `bar.get_Qux()`. We need to adapt astForMemberBindingExpression
// to accommodate this.
"conditional property access expressions" ignore {
"conditional property access expressions" should {
val cpg = code("""
|namespace Foo {
| public class Baz {
Expand Down Expand Up @@ -119,6 +115,7 @@ class MemberAccessTests extends CSharpCode2CpgFixture {
}

}

"conditional method access expressions" should {
val cpg = code("""
|namespace Foo {
Expand Down

0 comments on commit b5fe810

Please sign in to comment.