Skip to content

Commit

Permalink
pysrc2cpg: Determine if a Field Access is a Method Ref (#2326)
Browse files Browse the repository at this point in the history
* Added another import condition that was missing a kind of import
* Added crude "addedNode" tracker to type recovery
* Using the type recovery, will determine if the field access corresponds to an internal method and will add a method ref to the parent node
  • Loading branch information
DavidBakerEffendi authored Mar 1, 2023
1 parent 143be0f commit 47ca2b6
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder
import overflowdb.traversal.Traversal

import java.io.{File => JFile}
import java.nio.file.{Path, Paths}
import java.nio.file.Paths
import java.util.regex.Matcher
import scala.collection.mutable
import scala.util.Try

class PythonTypeRecovery(cpg: Cpg) extends XTypeRecovery[File](cpg) {
Expand All @@ -23,14 +24,19 @@ class PythonTypeRecovery(cpg: Cpg) extends XTypeRecovery[File](cpg) {
override def generateRecoveryForCompilationUnitTask(
unit: File,
builder: DiffGraphBuilder
): RecoverForXCompilationUnit[File] = new RecoverForPythonFile(cpg, unit, builder, globalTable)
): RecoverForXCompilationUnit[File] = new RecoverForPythonFile(cpg, unit, builder, globalTable, addedNodes)

}

/** Performs type recovery from the root of a compilation unit level
*/
class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, globalTable: SymbolTable[GlobalKey])
extends RecoverForXCompilationUnit[File](cpg, cu, builder, globalTable) {
class RecoverForPythonFile(
cpg: Cpg,
cu: File,
builder: DiffGraphBuilder,
globalTable: SymbolTable[GlobalKey],
addedNodes: mutable.Set[(Long, String)]
) extends RecoverForXCompilationUnit[File](cpg, cu, builder, globalTable, addedNodes) {

/** Overridden to include legacy import calls until imports are supported.
*/
Expand Down Expand Up @@ -84,7 +90,9 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global
val pyFile = absPath.map(a => Paths.get(a.toString + ".py"))
fileOrDir match {
case Some(f) if f.isDirectory && !pyFile.exists { p => better.files.File(p).exists } =>
s"${path.replaceAll("\\.", sep)}/$funcOrModule.py:<module>"
s"${path.replaceAll("\\.", sep)}${java.io.File.separator}$funcOrModule.py:<module>"
case Some(f) if f.isDirectory && (f / s"$funcOrModule.py").exists =>
s"${(f / s"$funcOrModule.py").pathAsString}:<module>"
case _ =>
s"${path.replaceAll("\\.", sep)}.py:<module>.$funcOrModule"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,36 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {

"recover its full name successfully" in {
val List(methodFullName) = cpg.call("query").methodFullName.l
methodFullName shouldBe "data/db_session.py:<module>.create_session.<returnValue>.query"
methodFullName shouldBe Seq("data", "db_session.py:<module>.create_session.<returnValue>.query").mkString(
File.separator
)
}
}

// Scenario: `urls.py` imports `views` from `student` and passes `views.add_student` (a field access, not a call) as
// an argument to `url(...)`. The assertions below expect a method ref to exist for that field access.
"recover a method ref from a field identifier" should {
lazy val cpg = code(
"""
|from django.conf.urls import url
|
|from student import views
|
|urlpatterns = [
| url(r'^addStudent/$', views.add_student)
|]
|""".stripMargin,
"urls.py"
).moreCode(
// Second fixture file: defines the `add_student` function that `urls.py` refers to.
"""
|def add_student():
| pass
|""".stripMargin,
// The fixture path is built with the platform's file separator rather than a hard-coded "/".
s"student${File.separator}views.py"
)

"recover the method full name related" in {
// The `Some(...)` pattern makes the test fail with a MatchError when no method ref with this code is found.
val Some(methodRef) = cpg.methodRef.code("views.add_student").headOption
// The expected full name is joined with File.separator, mirroring how the fixture path above is built.
methodRef.methodFullName shouldBe Seq("student", "views.py:<module>.add_student").mkString(File.separator)
// No type is expected on the method ref itself; only the method full name is checked for a real value.
methodRef.typeFullName shouldBe "<empty>"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.joern.x2cpg.passes.frontend
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.nodes._
import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames}
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators, PropertyNames}
import io.shiftleft.passes.CpgPass
import io.shiftleft.semanticcpg.language._
import io.shiftleft.semanticcpg.language.operatorextension.OpNodes
Expand Down Expand Up @@ -50,13 +50,19 @@ abstract class XTypeRecovery[CompilationUnitType <: AstNode](cpg: Cpg, iteration
*/
protected val globalTable = new SymbolTable[GlobalKey](SBKey.fromNodeToGlobalKey)

/** Tracks nodes that were added during type recovery so that later iterations do not add the same node again. Each
* entry pairs the ID of the existing node that triggered the addition with a string identifying the added node.
*/
protected val addedNodes = mutable.HashSet.empty[(Long, String)]

/** Runs the recovery `iterations` times. In every round a task is generated and forked for each compilation unit and
  * its result is awaited with `get()`. The shared tables are cleared afterwards, whether or not a task failed.
  */
override def run(builder: DiffGraphBuilder): Unit =
  try {
    (0 until iterations).foreach { _ =>
      val forkedTasks = compilationUnit
        .map(cu => generateRecoveryForCompilationUnitTask(cu, builder).fork())
      forkedTasks.foreach(_.get())
    }
  } finally {
    // State collected during this run must not leak into a later one.
    globalTable.clear()
    addedNodes.clear()
  }

/** @return
Expand Down Expand Up @@ -96,14 +102,17 @@ object XTypeRecovery {
* the graph builder
* @param globalTable
* the global symbol table.
* @param addedNodes
* new node tracking set.
* @tparam CompilationUnitType
* the [[AstNode]] type used to represent a compilation unit of the language.
*/
abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
cpg: Cpg,
cu: CompilationUnitType,
builder: DiffGraphBuilder,
globalTable: SymbolTable[GlobalKey]
globalTable: SymbolTable[GlobalKey],
addedNodes: mutable.Set[(Long, String)]
) extends RecursiveTask[Unit] {

protected val logger: Logger = LoggerFactory.getLogger(getClass)
Expand Down Expand Up @@ -691,21 +700,63 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
persistType(call, callTypes)(builder)
}
// Case 3: 'i' is the receiver for a field access on member 'f'
case (Some(call: Call), List(i: Identifier, f: FieldIdentifier)) if call.name.equals(Operators.fieldAccess) =>
case (Some(fieldAccess: Call), List(i: Identifier, f: FieldIdentifier))
if fieldAccess.name.equals(Operators.fieldAccess) =>
val iTypes = if (symbolTable.contains(i)) symbolTable.get(i) else symbolTable.get(CallAlias(i.name))
val cTypes = symbolTable.get(call)
val cTypes = symbolTable.get(fieldAccess)
persistType(i, iTypes)(builder)
persistType(call, cTypes)(builder)
Traversal.from(call.astParent).isCall.headOption match {
persistType(fieldAccess, cTypes)(builder)
Traversal.from(fieldAccess.astParent).isCall.headOption match {
case Some(callFromFieldName) if symbolTable.contains(callFromFieldName) =>
persistType(callFromFieldName, symbolTable.get(callFromFieldName))(builder)
case Some(callFromFieldName) if iTypes.nonEmpty =>
persistType(callFromFieldName, iTypes.map(it => s"$it.${f.canonicalName}"))(builder)
case _ =>
}
// This field may be a function pointer
handlePotentialFunctionPointer(fieldAccess, i.name, iTypes, f)
case _ => persistType(x, symbolTable.get(x))(builder)
}

/** Adds a method ref for a field access that resolves to a known method, i.e. a potential function pointer.
  *
  * Each base type is joined with the field's canonical name to form a candidate method full name. For every method in
  * the CPG whose full name equals a candidate, a method ref is attached to the AST parent of the field access (AST and
  * ARGUMENT edges) and linked to the method (REF edge). Nothing is attached when the parent already has a method ref
  * child with that method full name, or when this field access/method pair is already recorded in `addedNodes`.
  *
  * @param fieldAccess
  *   the field access call being inspected.
  * @param baseName
  *   the name of the receiver identifier; used to build the method ref's code.
  * @param baseTypes
  *   the types recovered for the receiver.
  * @param f
  *   the accessed field.
  */
private def handlePotentialFunctionPointer(
  fieldAccess: Call,
  baseName: String,
  baseTypes: Set[String],
  f: FieldIdentifier
): Unit = {
  val fieldName = f.canonicalName
  baseTypes
    .map(t => s"$t.$fieldName")
    .flatMap(p => cpg.method.fullNameExact(p))
    .foreach { m =>
      val mRef = NewMethodRef()
        .code(s"$baseName.$fieldName")
        .methodFullName(m.fullName)
        .lineNumber(fieldAccess.lineNumber)
        .columnNumber(fieldAccess.columnNumber)
      val key = (fieldAccess.id(), s"${mRef.label()}.${mRef.methodFullName}")
      // `add` returns false when the key is already present, so the membership test and the insertion happen in one
      // step. The set is handed to every task forked by the recovery pass, hence the lock around the access.
      val firstVisit = addedNodes.synchronized(addedNodes.add(key))
      if (firstVisit) {
        fieldAccess.astParent
          // Exact comparison: a method full name contains characters such as '<', '.' and (on Windows) '\', which
          // would be interpreted as regex syntax by the non-exact filter and could hide an existing method ref.
          .filterNot(_.astChildren.isMethodRef.methodFullNameExact(mRef.methodFullName).nonEmpty)
          .foreach { inCall =>
            // The new ref is placed after the parent's existing AST children; the index is set before the node is
            // handed to the builder so the builder never sees a half-initialised node.
            mRef.argumentIndex(inCall.astChildren.size)
            builder.addNode(mRef)
            builder.addEdge(mRef, m, EdgeTypes.REF)
            builder.addEdge(inCall, mRef, EdgeTypes.AST)
            builder.addEdge(inCall, mRef, EdgeTypes.ARGUMENT)
          }
      }
    }
}

private def persistType(x: StoredNode, types: Set[String])(implicit builder: DiffGraphBuilder): Unit =
if (types.nonEmpty)
if (types.size == 1)
Expand Down

0 comments on commit 47ca2b6

Please sign in to comment.