Skip to content

Commit

Permalink
Mild refactorings for PythonTypeRecovery (#2293)
Browse files Browse the repository at this point in the history
* Rename `ComputationalUnit` to `CompilationUnit`
* Simplify/Shorten extraction of possible callee names
* More shortening

---------

Co-authored-by: Fabian Yamaguchi <[email protected]>
  • Loading branch information
fabsx00 and Fabian Yamaguchi authored Feb 22, 2023
1 parent 7ba7b82 commit e77f1ec
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import scala.util.Try

class PythonTypeRecovery(cpg: Cpg) extends XTypeRecovery[File](cpg) {

override def computationalUnit: Traversal[File] = cpg.file
override def compilationUnit: Traversal[File] = cpg.file

override def generateRecoveryForCompilationUnitTask(
unit: File,
Expand All @@ -26,36 +26,6 @@ class PythonTypeRecovery(cpg: Cpg) extends XTypeRecovery[File](cpg) {

}

/** Defines how a procedure is available to be called in the current scope either by it being defined in this module or
* being imported.
*
* @param callingName
* how this procedure is to be called, i.e., alias name, name with path, etc.
* @param fullName
* the full name to where this method is defined where it's assumed to be defined under a named Python file.
*/
case class ScopedPythonProcedure(callingName: String, fullName: String, isConstructor: Boolean = false) {

/** @return
* the full name of the procedure where it's assumed that it is defined within an <code>__init.py__</code> of the
* module.
*/
private def fullNameAsInit: String = fullName.replace(".py", s"${JFile.separator}__init__.py")

/** @return
* the two ways that this procedure could be resolved to in Python. This will be pruned later by comparing this to
* actual methods in the CPG.
*/
def possibleCalleeNames: Set[String] =
if (isConstructor)
Set(fullName.concat(s".${Defines.ConstructorMethodName}"))
else
Set(fullName, fullNameAsInit)

override def toString: String = s"ProcedureCalledAs(${possibleCalleeNames.mkString(", ")})"

}

/** Performs type recovery from the root of a compilation unit level
*/
class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, globalTable: SymbolTable[GlobalKey])
Expand All @@ -66,64 +36,63 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global
override def importNodes(cu: AstNode): Traversal[AstNode] =
cu.ast.isCall.nameExact("import") ++ super.importNodes(cu)

override def visitImport(i: Call): Unit = {
i.argumentOut.l match {
case List(path: Literal, funcOrModule: Literal) =>
val calleeNames = extractMethodDetailsFromImport(path.code, funcOrModule.code).possibleCalleeNames
symbolTable.put(CallAlias(funcOrModule.code), calleeNames)
symbolTable.put(LocalVar(funcOrModule.code), calleeNames)
case List(path: Literal, funcOrModule: Literal, alias: Literal) =>
val calleeNames =
extractMethodDetailsFromImport(path.code, funcOrModule.code, Option(alias.code)).possibleCalleeNames
symbolTable.put(CallAlias(alias.code), calleeNames)
symbolTable.put(LocalVar(alias.code), calleeNames)
case x => logger.warn(s"Unknown import pattern: ${x.map(_.label).mkString(", ")}")
override def visitImport(importCall: Call): Unit = {
importCall.argument.l match {
case (path: Literal) :: (funcOrModule: Literal) :: alias =>
val calleeNames = extractPossibleCalleeNames(path.code, funcOrModule.code)
alias match {
case (alias: Literal) :: Nil =>
symbolTable.put(CallAlias(alias.code), calleeNames)
symbolTable.put(LocalVar(alias.code), calleeNames)
case Nil =>
symbolTable.put(CallAlias(funcOrModule.code), calleeNames)
symbolTable.put(LocalVar(funcOrModule.code), calleeNames)
case x =>
logger.warn(s"Unknown import pattern: ${x.map(_.label).mkString(", ")}")
}
}
}

/** Parses all imports and identifies their full names and how they are to be called in this scope.
/** For an import - given by its module path and the name of the imported function or module - determine the possible
* callee names.
*
* @param path
* the module path.
* @param funcOrModule
* the name of the imported entity.
* @param maybeAlias
* an optional alias given to the imported entity.
* @return
* the procedure information in this scope.
* the possible callee names
*/
private def extractMethodDetailsFromImport(
path: String,
funcOrModule: String,
maybeAlias: Option[String] = None
): ScopedPythonProcedure = {
val isConstructor = funcOrModule.split("\\.").last.charAt(0).isUpper
if (path.isEmpty) {
if (funcOrModule.contains(".")) {
// Case 1: We have imported a function using a qualified path, e.g., import foo.bar => (bar.py or bar/__init.py)
private def extractPossibleCalleeNames(path: String, funcOrModule: String): Set[String] = {
val sep = Matcher.quoteReplacement(JFile.separator)
val procedureName = path match {
case "" if funcOrModule.contains(".") =>
// Case 1: Qualified path: import foo.bar => (bar.py or bar/__init.py)
val splitFunc = funcOrModule.split("\\.")
val name = splitFunc.tail.mkString(".")
ScopedPythonProcedure(name, s"${splitFunc(0)}.py:<module>.$name", isConstructor)
} else {
// Case 2: We have imported a module, e.g., import foo => (foo.py or foo/__init.py)
ScopedPythonProcedure(funcOrModule, s"$funcOrModule.py:<module>", isConstructor)
}
} else {
val sep = Matcher.quoteReplacement(JFile.separator)
maybeAlias match {
s"${splitFunc(0)}.py:<module>.$name"
case "" =>
// Case 2: import of a module: import foo => (foo.py or foo/__init.py)
s"$funcOrModule.py:<module>"
case _ =>
// TODO: This assumes importing from modules and never importing nested method
// Case 3: We have imported a function from a module using an alias, e.g. import bar from foo as faz
case Some(alias) =>
ScopedPythonProcedure(alias, s"${path.replaceAll("\\.", sep)}.py:<module>.$funcOrModule", isConstructor)
// Case 4: We have imported a function from a module, e.g. import bar from foo
case None =>
ScopedPythonProcedure(
funcOrModule,
s"${path.replaceAll("\\.", sep)}.py:<module>.$funcOrModule",
isConstructor
)
}
// Case 3: Import from module using alias, e.g. import bar from foo as faz
s"${path.replaceAll("\\.", sep)}.py:<module>.$funcOrModule"
}

/** The two ways that this procedure could be resolved to in Python. */
def possibleCalleeNames(procedureName: String, isConstructor: Boolean): Set[String] =
if (isConstructor)
Set(procedureName.concat(s".${Defines.ConstructorMethodName}"))
else
Set(procedureName, fullNameAsInit)

/** the full name of the procedure where it's assumed that it is defined within an <code>__init.py__</code> of the
* module.
*/
def fullNameAsInit: String = procedureName.replace(".py", s"${JFile.separator}__init__.py")

possibleCalleeNames(procedureName, isConstructor(funcOrModule))
}

override def postVisitImports(): Unit = {
Expand Down Expand Up @@ -158,6 +127,9 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global
override def isConstructor(c: Call): Boolean =
c.name.nonEmpty && c.name.charAt(0).isUpper && c.code.endsWith(")")

def isConstructor(funcOrModule: String): Boolean =
funcOrModule.split("\\.").lastOption.exists(_.charAt(0).isUpper)

/** If the parent method is module then it can be used as a field.
*/
override def isField(i: Identifier): Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import scala.collection.mutable
/** Based on a flow-insensitive symbol-table-style approach. This pass aims to be fast and deterministic and does not
* try to converge to some fixed point but rather iterates a fixed number of times. This will help recover:
* <ol><li>Imported call signatures from external dependencies</li><li>Dynamic type hints for mutable variables in a
* computational unit.</ol>
* compilation unit.</ol>
*
* The algorithm flows roughly as follows: <ol> <li> Scan for method signatures of methods for each compilation unit,
* either by internally defined methods or by reading import signatures. This includes looking for aliases, e.g. import
Expand All @@ -28,8 +28,8 @@ import scala.collection.mutable
* instances of where these fields and variables are used and update their type information.</li><li>If this variable
* is the receiver of a call, make sure to set the type of the call accordingly.</li></ol>
*
* In order to propagate types across computational units, but avoid the poor scalability of a fixed-point algorithm,
* the number of iterations can be configured using the [[iterations]] parameter. Note that [[iterations]] < 2 will not
* In order to propagate types across compilation units, but avoid the poor scalability of a fixed-point algorithm, the
* number of iterations can be configured using the [[iterations]] parameter. Note that [[iterations]] < 2 will not
* provide any interprocedural type recovery capabilities.
*
* The symbol tables use the [[SymbolTable]] class to track possible type information. <br> <strong>Note: Local symbols
Expand All @@ -41,28 +41,28 @@ import scala.collection.mutable
* @param iterations
* the total number of iterations through which types are to be propagated. At least 2 are recommended in order to
* propagate interprocedural types. Think of this as similar to the dataflowengineoss' 'maxCallDepth'.
* @tparam ComputationalUnit
* the [[AstNode]] type used to represent a computational unit of the language.
* @tparam CompilationUnitType
* the [[AstNode]] type used to represent a compilation unit of the language.
*/
abstract class XTypeRecovery[ComputationalUnit <: AstNode](cpg: Cpg, iterations: Int = 2) extends CpgPass(cpg) {
abstract class XTypeRecovery[CompilationUnitType <: AstNode](cpg: Cpg, iterations: Int = 2) extends CpgPass(cpg) {

/** Stores type information for global structures that persist across computational units, e.g. field identifiers.
/** Stores type information for global structures that persist across compilation units, e.g. field identifiers.
*/
protected val globalTable = new SymbolTable[GlobalKey](SBKey.fromNodeToGlobalKey)

override def run(builder: DiffGraphBuilder): Unit = try {
for (_ <- 0 until iterations)
computationalUnit
compilationUnit
.map(unit => generateRecoveryForCompilationUnitTask(unit, builder).fork())
.foreach(_.get())
} finally {
globalTable.clear()
}

/** @return
* the computational units as per how the language is compiled. e.g. file.
* the compilation units as per how the language is compiled. e.g. file.
*/
def computationalUnit: Traversal[ComputationalUnit]
def compilationUnit: Traversal[CompilationUnitType]

/** A factory method to generate a [[RecoverForXCompilationUnit]] task with the given parameters.
* @param unit
Expand All @@ -73,9 +73,9 @@ abstract class XTypeRecovery[ComputationalUnit <: AstNode](cpg: Cpg, iterations:
* a forkable [[RecoverForXCompilationUnit]] task.
*/
def generateRecoveryForCompilationUnitTask(
unit: ComputationalUnit,
unit: CompilationUnitType,
builder: DiffGraphBuilder
): RecoverForXCompilationUnit[ComputationalUnit]
): RecoverForXCompilationUnit[CompilationUnitType]

}

Expand All @@ -96,12 +96,12 @@ object XTypeRecovery {
* the graph builder
* @param globalTable
* the global symbol table.
* @tparam ComputationalUnit
* the [[AstNode]] type used to represent a computational unit of the language.
* @tparam CompilationUnitType
* the [[AstNode]] type used to represent a compilation unit of the language.
*/
abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode](
abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
cpg: Cpg,
cu: ComputationalUnit,
cu: CompilationUnitType,
builder: DiffGraphBuilder,
globalTable: SymbolTable[GlobalKey]
) extends RecursiveTask[Unit] {
Expand Down Expand Up @@ -177,7 +177,7 @@ abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode](
}

/** @return
* the import nodes of this computational unit.
* the import nodes of this compilation unit.
*/
protected def importNodes(cu: AstNode): Traversal[AstNode] = cu.ast.isImport

Expand Down

0 comments on commit e77f1ec

Please sign in to comment.