Skip to content

Commit

Permalink
pysrc2cpg: propagation of return types (#2328)
Browse files Browse the repository at this point in the history
* stash

* Propagate dynamic type hint

* stash

* Fix works now

* Remove `println`s

* scalafmt

---------

Co-authored-by: Fabian Yamaguchi <[email protected]>
  • Loading branch information
fabsx00 and Fabian Yamaguchi authored Mar 1, 2023
1 parent 94d69c7 commit bf75d7f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {

"recover its full name successfully" in {
val List(methodFullName) = cpg.call("query").methodFullName.l
methodFullName shouldBe Seq("data", "db_session.py:<module>.create_session.<returnValue>.query").mkString(
File.separator
)
methodFullName shouldBe "sqlalchemy.orm.Session.query"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ abstract class XTypeRecovery[CompilationUnitType <: AstNode](cpg: Cpg, iteration
compilationUnit
.map(unit => generateRecoveryForCompilationUnitTask(unit, builder).fork())
.foreach(_.get())

} finally {
globalTable.clear()
addedNodes.clear()
Expand Down Expand Up @@ -131,6 +132,26 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
protected def members: Traversal[Member] =
cu.ast.isMember

protected def visitCall(call: Call): Unit = {
symbolTable.get(call).foreach { methodFullName =>
val index = methodFullName.indexOf("<returnValue>")
if (index != -1) {
val methodToLookup = methodFullName.substring(0, index - 1)
val remainder = methodFullName.substring(index + "<returnValue>".length)
val methods = cpg.method.fullNameExact(methodToLookup).l
methods match {
case List(method) =>
val dynamicTypeHints = method.methodReturn.dynamicTypeHintFullName.toSet
val newTypes = dynamicTypeHints.map(x => x + remainder)
symbolTable.put(call, newTypes)
case List() =>
case _ =>
logger.warn(s"More than a single function matches method full name: $methodToLookup")
}
}
}
}

override def compute(): Unit = try {
prepopulateSymbolTable()
// Set known aliases that point to imports for local and external methods/modules
Expand All @@ -141,6 +162,9 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
members.foreach(visitMembers)
// Populate local symbol table with assignments
assignments.foreach(visitAssignments)
// Propagate return values
cpg.method.ast.isCall.foreach(visitCall)

// Persist findings
setTypeInformation()
} finally {
Expand Down

0 comments on commit bf75d7f

Please sign in to comment.