From c5c8ef5f8ce8a552d87182019eba3a5b52eeaea5 Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Wed, 15 Jan 2025 11:27:33 +0200 Subject: [PATCH] [ruby] Lower Array Creation for In Pattern Match (#5229) Pattern match array creation is now lowered so that both reads and writes assign array indices to where each element must be read and written from. --- .../AstForControlStructuresCreator.scala | 85 ++++++++++++++----- .../AstForExpressionsCreator.scala | 1 + .../astcreation/RubyIntermediateAst.scala | 5 ++ .../rubysrc2cpg/querying/CaseTests.scala | 11 ++- 4 files changed, 77 insertions(+), 25 deletions(-) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala index 832db51de288..a9f1c0db0dc2 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForControlStructuresCreator.scala @@ -1,12 +1,15 @@ package io.joern.rubysrc2cpg.astcreation import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{ + ArrayLiteral, ArrayPattern, BinaryExpression, BreakExpression, CaseExpression, ControlFlowStatement, DoWhileExpression, + DummyAst, + DynamicLiteral, ElseClause, ForExpression, IfExpression, @@ -17,10 +20,9 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{ NextExpression, OperatorAssignment, RescueExpression, - ReturnExpression, RubyExpression, - SimpleCall, SimpleIdentifier, + SimpleObjectInstantiation, SingleAssignment, SplattingRubyNode, StatementList, @@ -32,18 +34,12 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{ WhenClause, WhileExpression } -import io.joern.rubysrc2cpg.parser.RubyJsonHelpers +import io.joern.rubysrc2cpg.datastructures.BlockScope import 
io.joern.rubysrc2cpg.passes.Defines -import io.joern.rubysrc2cpg.passes.Defines.RubyOperators +import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType import io.joern.x2cpg.{Ast, ValidationMode} +import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewFieldIdentifier, NewLiteral, NewLocal} import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators} -import io.shiftleft.codepropertygraph.generated.nodes.{ - NewBlock, - NewFieldIdentifier, - NewIdentifier, - NewLiteral, - NewLocal -} trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => @@ -335,16 +331,67 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo ifElseChain.iterator.toList } - def generatedNode: StatementList = node.expression - .map { e => - val tmp = SimpleIdentifier(None)(e.span.spanStart(this.tmpGen.fresh)) - StatementList( - List(SingleAssignment(tmp, "=", e)(e.span)) ++ - goCase(Some(tmp)) - )(node.span) + val caseExpr = node.expression + .map { + case arrayLiteral: ArrayLiteral => + val tmp = SimpleIdentifier(None)(arrayLiteral.span.spanStart(this.tmpGen.fresh)) + val arrayLiteralAst = DummyAst(astForTempArray(arrayLiteral))(arrayLiteral.span) + (tmp, arrayLiteralAst) + case e => + val tmp = SimpleIdentifier(None)(e.span.spanStart(this.tmpGen.fresh)) + (tmp, e) } + .map((tmp, e) => StatementList(List(SingleAssignment(tmp, "=", e)(e.span)) ++ goCase(Some(tmp)))(node.span)) .getOrElse(StatementList(goCase(None))(node.span)) - astsForStatement(generatedNode) + + astsForStatement(caseExpr) + } + + private def astForTempArray(node: ArrayLiteral): Ast = { + val tmp = this.tmpGen.fresh + + def tmpRubyNode(tmpNode: Option[RubyExpression] = None) = + SimpleIdentifier()(tmpNode.map(_.span).getOrElse(node.span).spanStart(tmp)) + + def tmpAst(tmpNode: Option[RubyExpression] = None) = astForSimpleIdentifier(tmpRubyNode(tmpNode)) + + val block = blockNode(node, node.text, 
Defines.Any) + scope.pushNewScope(BlockScope(block)) + val tmpLocal = NewLocal().name(tmp).code(tmp) + scope.addToScope(tmp, tmpLocal) + + val arguments = if (node.text.startsWith("%")) { + val argumentsType = + if (node.isStringArray) getBuiltInType(Defines.String) + else getBuiltInType(Defines.Symbol) + node.elements.map { + case element @ StaticLiteral(_) => StaticLiteral(argumentsType)(element.span) + case element @ DynamicLiteral(_, expressions) => DynamicLiteral(argumentsType, expressions)(element.span) + case element => element + } + } else { + node.elements + } + val argumentAsts = arguments.zipWithIndex.map { case (arg, idx) => + val indices = StaticLiteral(getBuiltInType(Defines.Integer))(arg.span.spanStart(idx.toString)) :: Nil + val base = tmpRubyNode(Option(arg)) + val indexAccess = IndexAccess(base, indices)(arg.span.spanStart(s"${base.text}[$idx]")) + val assignment = SingleAssignment(indexAccess, "=", arg)(arg.span.spanStart(s"${indexAccess.text} = ${arg.text}")) + astForExpression(assignment) + } + + val arrayInitCall = { + val base = SimpleIdentifier()(node.span.spanStart(Defines.Array)) + astForExpression(SimpleObjectInstantiation(base, Nil)(node.span)) + } + + val assignment = + callNode(node, code(node), Operators.assignment, Operators.assignment, DispatchTypes.STATIC_DISPATCH) + val tmpAssignment = callAst(assignment, tmpAst() :: arrayInitCall :: Nil) + val tmpRetAst = tmpAst(node.elements.lastOption) + + scope.popScope() + blockAst(block, tmpAssignment +: argumentAsts :+ tmpRetAst) } private def astForOperatorAssignmentExpression(node: OperatorAssignment): Ast = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index 23dad50f0729..b01501df13ee 100644 --- 
a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -63,6 +63,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { case node: AccessModifier => astForSimpleIdentifier(node.toSimpleIdentifier) case node: ArrayPattern => astForArrayPattern(node) case node: DummyNode => Ast(node.node) + case node: DummyAst => node.ast case node: Unknown => astForUnknown(node) case x => logger.warn(s"Unhandled expression of type ${x.getClass.getSimpleName}") diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala index ddbd417b900d..1889a7015437 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala @@ -1,6 +1,7 @@ package io.joern.rubysrc2cpg.astcreation import io.joern.rubysrc2cpg.passes.{Defines, GlobalTypes} +import io.joern.x2cpg.Ast import io.shiftleft.codepropertygraph.generated.nodes.NewNode import java.util.Objects @@ -621,6 +622,10 @@ object RubyIntermediateAst { */ final case class DummyNode(node: NewNode)(span: TextSpan) extends RubyExpression(span) + /** A dummy class for wrapping around `Ast` and allowing it to integrate with RubyNode classes. 
+ */ + final case class DummyAst(ast: Ast)(span: TextSpan) extends RubyExpression(span) + final case class UnaryExpression(op: String, expression: RubyExpression)(span: TextSpan) extends RubyExpression(span) final case class BinaryExpression(lhs: RubyExpression, op: String, rhs: RubyExpression)(span: TextSpan) diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala index 00bff922deaf..b2c6320162e9 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/CaseTests.scala @@ -1,6 +1,5 @@ package io.joern.rubysrc2cpg.querying -import io.joern.rubysrc2cpg.passes.Defines.RubyOperators import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes.* @@ -115,11 +114,11 @@ class CaseTests extends RubyCode2CpgFixture { val block @ List(_) = cpg.method.name("class_for").block.astChildren.isBlock.l - val List(assign) = block.astChildren.assignment.l; + val assign = block.astChildren.assignment.head val List(lhs, rhs) = assign.argument.l lhs.start.isIdentifier.name.l shouldBe List("") - rhs.start.isCall.code.l shouldBe List("[type, location]") + rhs.start.isBlock.code.l shouldBe List("[type, location]") // array lowering val headIf @ List(_) = block.astChildren.isControlStructure.l val ifStmts @ List(_, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l; @@ -165,11 +164,11 @@ class CaseTests extends RubyCode2CpgFixture { val block @ List(_) = cpg.method.name("class_for").block.astChildren.isBlock.l - val List(assign, _, _) = block.astChildren.assignment.l; - val List(lhs, rhs) = assign.argument.l + val assign = block.astChildren.assignment.head + val List(lhs, 
rhs) = assign.argument.l lhs.start.isIdentifier.name.l shouldBe List("") - rhs.start.isCall.code.l shouldBe List("[type, location]") + rhs.start.isBlock.code.l shouldBe List("[type, location]") // where the array lowering happens val headIf @ List(_) = block.astChildren.isControlStructure.l val ifStmts @ List(_, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l;