Skip to content

Commit

Permalink
[ruby] Lower Array Creation for In Pattern Match (#5229)
Browse files Browse the repository at this point in the history
Pattern match array creation is now lowered so that both reads and writes assign array indices to where each element much be read and written from.
  • Loading branch information
DavidBakerEffendi authored Jan 15, 2025
1 parent 48d26b5 commit c5c8ef5
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package io.joern.rubysrc2cpg.astcreation

import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
ArrayLiteral,
ArrayPattern,
BinaryExpression,
BreakExpression,
CaseExpression,
ControlFlowStatement,
DoWhileExpression,
DummyAst,
DynamicLiteral,
ElseClause,
ForExpression,
IfExpression,
Expand All @@ -17,10 +20,9 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
NextExpression,
OperatorAssignment,
RescueExpression,
ReturnExpression,
RubyExpression,
SimpleCall,
SimpleIdentifier,
SimpleObjectInstantiation,
SingleAssignment,
SplattingRubyNode,
StatementList,
Expand All @@ -32,18 +34,12 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
WhenClause,
WhileExpression
}
import io.joern.rubysrc2cpg.parser.RubyJsonHelpers
import io.joern.rubysrc2cpg.datastructures.BlockScope
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.RubyOperators
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewFieldIdentifier, NewLiteral, NewLocal}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators}
import io.shiftleft.codepropertygraph.generated.nodes.{
NewBlock,
NewFieldIdentifier,
NewIdentifier,
NewLiteral,
NewLocal
}

trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

Expand Down Expand Up @@ -335,16 +331,67 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo
ifElseChain.iterator.toList
}

def generatedNode: StatementList = node.expression
.map { e =>
val tmp = SimpleIdentifier(None)(e.span.spanStart(this.tmpGen.fresh))
StatementList(
List(SingleAssignment(tmp, "=", e)(e.span)) ++
goCase(Some(tmp))
)(node.span)
val caseExpr = node.expression
.map {
case arrayLiteral: ArrayLiteral =>
val tmp = SimpleIdentifier(None)(arrayLiteral.span.spanStart(this.tmpGen.fresh))
val arrayLiteralAst = DummyAst(astForTempArray(arrayLiteral))(arrayLiteral.span)
(tmp, arrayLiteralAst)
case e =>
val tmp = SimpleIdentifier(None)(e.span.spanStart(this.tmpGen.fresh))
(tmp, e)
}
.map((tmp, e) => StatementList(List(SingleAssignment(tmp, "=", e)(e.span)) ++ goCase(Some(tmp)))(node.span))
.getOrElse(StatementList(goCase(None))(node.span))
astsForStatement(generatedNode)

astsForStatement(caseExpr)
}

private def astForTempArray(node: ArrayLiteral): Ast = {
val tmp = this.tmpGen.fresh

def tmpRubyNode(tmpNode: Option[RubyExpression] = None) =
SimpleIdentifier()(tmpNode.map(_.span).getOrElse(node.span).spanStart(tmp))

def tmpAst(tmpNode: Option[RubyExpression] = None) = astForSimpleIdentifier(tmpRubyNode(tmpNode))

val block = blockNode(node, node.text, Defines.Any)
scope.pushNewScope(BlockScope(block))
val tmpLocal = NewLocal().name(tmp).code(tmp)
scope.addToScope(tmp, tmpLocal)

val arguments = if (node.text.startsWith("%")) {
val argumentsType =
if (node.isStringArray) getBuiltInType(Defines.String)
else getBuiltInType(Defines.Symbol)
node.elements.map {
case element @ StaticLiteral(_) => StaticLiteral(argumentsType)(element.span)
case element @ DynamicLiteral(_, expressions) => DynamicLiteral(argumentsType, expressions)(element.span)
case element => element
}
} else {
node.elements
}
val argumentAsts = arguments.zipWithIndex.map { case (arg, idx) =>
val indices = StaticLiteral(getBuiltInType(Defines.Integer))(arg.span.spanStart(idx.toString)) :: Nil
val base = tmpRubyNode(Option(arg))
val indexAccess = IndexAccess(base, indices)(arg.span.spanStart(s"${base.text}[$idx]"))
val assignment = SingleAssignment(indexAccess, "=", arg)(arg.span.spanStart(s"${indexAccess.text} = ${arg.text}"))
astForExpression(assignment)
}

val arrayInitCall = {
val base = SimpleIdentifier()(node.span.spanStart(Defines.Array))
astForExpression(SimpleObjectInstantiation(base, Nil)(node.span))
}

val assignment =
callNode(node, code(node), Operators.assignment, Operators.assignment, DispatchTypes.STATIC_DISPATCH)
val tmpAssignment = callAst(assignment, tmpAst() :: arrayInitCall :: Nil)
val tmpRetAst = tmpAst(node.elements.lastOption)

scope.popScope()
blockAst(block, tmpAssignment +: argumentAsts :+ tmpRetAst)
}

private def astForOperatorAssignmentExpression(node: OperatorAssignment): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
case node: AccessModifier => astForSimpleIdentifier(node.toSimpleIdentifier)
case node: ArrayPattern => astForArrayPattern(node)
case node: DummyNode => Ast(node.node)
case node: DummyAst => node.ast
case node: Unknown => astForUnknown(node)
case x =>
logger.warn(s"Unhandled expression of type ${x.getClass.getSimpleName}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.joern.rubysrc2cpg.astcreation

import io.joern.rubysrc2cpg.passes.{Defines, GlobalTypes}
import io.joern.x2cpg.Ast
import io.shiftleft.codepropertygraph.generated.nodes.NewNode

import java.util.Objects
Expand Down Expand Up @@ -621,6 +622,10 @@ object RubyIntermediateAst {
*/
final case class DummyNode(node: NewNode)(span: TextSpan) extends RubyExpression(span)

/** A dummy class for wrapping around `Ast` and allowing it to integrate with RubyNode classes.
*/
final case class DummyAst(ast: Ast)(span: TextSpan) extends RubyExpression(span)

final case class UnaryExpression(op: String, expression: RubyExpression)(span: TextSpan) extends RubyExpression(span)

final case class BinaryExpression(lhs: RubyExpression, op: String, rhs: RubyExpression)(span: TextSpan)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.joern.rubysrc2cpg.querying

import io.joern.rubysrc2cpg.passes.Defines.RubyOperators
import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.nodes.*
Expand Down Expand Up @@ -115,11 +114,11 @@ class CaseTests extends RubyCode2CpgFixture {

val block @ List(_) = cpg.method.name("class_for").block.astChildren.isBlock.l

val List(assign) = block.astChildren.assignment.l;
val assign = block.astChildren.assignment.head
val List(lhs, rhs) = assign.argument.l

lhs.start.isIdentifier.name.l shouldBe List("<tmp-0>")
rhs.start.isCall.code.l shouldBe List("[type, location]")
rhs.start.isBlock.code.l shouldBe List("[type, location]") // array lowering

val headIf @ List(_) = block.astChildren.isControlStructure.l
val ifStmts @ List(_, _, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l;
Expand Down Expand Up @@ -165,11 +164,11 @@ class CaseTests extends RubyCode2CpgFixture {

val block @ List(_) = cpg.method.name("class_for").block.astChildren.isBlock.l

val List(assign, _, _) = block.astChildren.assignment.l;
val List(lhs, rhs) = assign.argument.l
val assign = block.astChildren.assignment.head
val List(lhs, rhs) = assign.argument.l

lhs.start.isIdentifier.name.l shouldBe List("<tmp-0>")
rhs.start.isCall.code.l shouldBe List("[type, location]")
rhs.start.isBlock.code.l shouldBe List("[type, location]") // where the array lowering happens

val headIf @ List(_) = block.astChildren.isControlStructure.l
val ifStmts @ List(_, _) = headIf.repeat(_.astChildren.order(3).astChildren.isControlStructure)(_.emit).l;
Expand Down

0 comments on commit c5c8ef5

Please sign in to comment.