diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala
index 4946d4490a33..5f0677cb9192 100644
--- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala
+++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala
@@ -1,13 +1,12 @@
package io.joern.pysrc2cpg
+import io.joern.pysrc2cpg.PythonTypeRecovery.BUILTIN_PREFIX
import io.joern.x2cpg.Defines
import io.joern.x2cpg.passes.frontend._
import io.shiftleft.codepropertygraph.Cpg
-import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.nodes._
import io.shiftleft.semanticcpg.language._
-import io.shiftleft.semanticcpg.language.operatorextension.OpNodes
-import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess}
+import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess
import overflowdb.BatchedUpdate.DiffGraphBuilder
import overflowdb.traversal.Traversal
@@ -21,8 +20,7 @@ class PythonTypeRecovery(cpg: Cpg) extends XTypeRecovery[File](cpg) {
override def generateRecoveryForCompilationUnitTask(
unit: File,
- builder: DiffGraphBuilder,
- globalTable: SymbolTable[GlobalKey]
+ builder: DiffGraphBuilder
): RecoverForXCompilationUnit[File] = new RecoverForPythonFile(cpg, unit, builder, globalTable)
}
@@ -35,8 +33,7 @@ class PythonTypeRecovery(cpg: Cpg) extends XTypeRecovery[File](cpg) {
* @param fullName
* the full name to where this method is defined where it's assumed to be defined under a named Python file.
*/
-class ScopedPythonProcedure(callingName: String, fullName: String, isConstructor: Boolean = false)
- extends ScopedXProcedure(callingName, fullName, isConstructor) {
+case class ScopedPythonProcedure(callingName: String, fullName: String, isConstructor: Boolean = false) {
/** @return
* the full name of the procedure where it's assumed that it is defined within an __init.py__
of the
@@ -48,44 +45,41 @@ class ScopedPythonProcedure(callingName: String, fullName: String, isConstructor
* the two ways that this procedure could be resolved to in Python. This will be pruned later by comparing this to
* actual methods in the CPG.
*/
- override def possibleCalleeNames: Set[String] =
+ def possibleCalleeNames: Set[String] =
if (isConstructor)
Set(fullName.concat(s".${Defines.ConstructorMethodName}"))
else
Set(fullName, fullNameAsInit)
+ override def toString: String = s"ProcedureCalledAs(${possibleCalleeNames.mkString(", ")})"
+
}
-/** Tasks responsible for populating the symbol table with import data and method definition data.
- *
- * @param node
- * a node that references import information.
+/** Performs type recovery from the root of a compilation unit level
*/
-class SetPythonProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey]) extends SetXProcedureDefTask(node) {
+class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, globalTable: SymbolTable[GlobalKey])
+ extends RecoverForXCompilationUnit[File](cpg, cu, builder, globalTable) {
- /** Refers to the declared import information.
- *
- * @param importCall
- * the call that imports entities into this scope.
+ /** Overriden to include legacy import calls until imports are supported.
*/
- override def visitImport(importCall: Call): Unit = {
- importCall.argumentOut.l match {
+ override def importNodes(cu: AstNode): Traversal[AstNode] =
+ cu.ast.isCall.nameExact("import") ++ super.importNodes(cu)
+
+ override def visitImport(i: Call): Unit = {
+ i.argumentOut.l match {
case List(path: Literal, funcOrModule: Literal) =>
val calleeNames = extractMethodDetailsFromImport(path.code, funcOrModule.code).possibleCalleeNames
symbolTable.put(CallAlias(funcOrModule.code), calleeNames)
+ symbolTable.put(LocalVar(funcOrModule.code), calleeNames)
case List(path: Literal, funcOrModule: Literal, alias: Literal) =>
val calleeNames =
extractMethodDetailsFromImport(path.code, funcOrModule.code, Option(alias.code)).possibleCalleeNames
symbolTable.put(CallAlias(alias.code), calleeNames)
+ symbolTable.put(LocalVar(alias.code), calleeNames)
case x => logger.warn(s"Unknown import pattern: ${x.map(_.label).mkString(", ")}")
}
}
- override def visitImport(m: Method): Unit = {
- val calleeNames = new ScopedPythonProcedure(m.name, m.fullName).possibleCalleeNames
- symbolTable.put(m, calleeNames)
- }
-
/** Parses all imports and identifies their full names and how they are to be called in this scope.
*
* @param path
@@ -101,7 +95,7 @@ class SetPythonProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey
path: String,
funcOrModule: String,
maybeAlias: Option[String] = None
- ): ScopedXProcedure = {
+ ): ScopedPythonProcedure = {
val isConstructor = funcOrModule.split("\\.").last.charAt(0).isUpper
if (path.isEmpty) {
if (funcOrModule.contains(".")) {
@@ -131,29 +125,6 @@ class SetPythonProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey
}
}
-}
-
-/** Performs type recovery from the root of a compilation unit level
- *
- * @param cu
- * a compilation unit, e.g. file.
- * @param builder
- * the graph builder
- */
-class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, globalTable: SymbolTable[GlobalKey])
- extends RecoverForXCompilationUnit[File](cu, builder) {
-
- /** Adds built-in functions to expect.
- */
- override def prepopulateSymbolTable(): Unit =
- PythonTypeRecovery.BUILTINS
- .map(t => (CallAlias(t), s"${PythonTypeRecovery.BUILTIN_PREFIX}.$t"))
- .foreach { case (alias, typ) =>
- symbolTable.put(alias, typ)
- }
-
- override def importNodes(cu: AstNode): Traversal[CfgNode] = cu.ast.isCall.nameExact("import")
-
override def postVisitImports(): Unit = {
symbolTable.view.foreach { case (k, v) =>
val ms = cpg.method.fullNameExact(v.toSeq: _*).l
@@ -166,11 +137,9 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global
// This is likely external and we will ignore the init variant to be consistent
symbolTable.put(k, symbolTable(k).filterNot(_.contains("__init__.py")))
}
-
// Imports are by default used as calls, a second pass will tell us if this is not the case and we should
// check against global table
- // TODO: This is a bit of a bandaid compared to potentially having alias sensitivity. Values could be an
- // Either[SBKey, Set[String] where Left[SBKey] could point to the aliased symbol
+ // TODO: This is a bit of a bandaid compared to potentially having alias sensitivity.
def fieldVar(path: String) = FieldVar(path.stripSuffix(s".${k.identifier}"), k.identifier)
@@ -182,292 +151,87 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global
}
}
- override def generateSetProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey]): SetXProcedureDefTask =
- new SetPythonProcedureDefTask(node, symbolTable)
-
- /** Using assignment and import information (in the global symbol table), will propagate these types in the symbol
- * table.
- *
- * @param assignment
- * assignment call pointer.
+ /** Determines if a function call is a constructor by following the heuristic that Python classes are typically
+ * camel-case and start with an upper-case character.
*/
- override def visitAssignments(assignment: Assignment): Unit = {
- assignment.argumentOut.take(2).l match {
- case List(i: Identifier, c: Call) if symbolTable.contains(c) =>
- val importedTypes = symbolTable.get(c)
- setIdentifierFromFunctionType(i, c.name, c.code, importedTypes)
- case List(i: Identifier, c: CfgNode)
- if visitLiteralAssignment(i, c, symbolTable) => // if unsuccessful, then check next
- case List(i: Identifier, c: Call) if c.receiver.isCall.name.exists(_.equals(Operators.fieldAccess)) =>
- val field = c.receiver.isCall.name(Operators.fieldAccess).map(new OpNodes.FieldAccess(_)).head
- visitCallFromFieldMember(i, c, field, symbolTable)
- // Use global table knowledge (in i >= 2 iterations) or CPG to extract field types
- case List(_: Identifier, c: Call) if c.name.equals(Operators.fieldAccess) =>
- c.inCall.argument
- .flatMap {
- case n: Call if n.name.equals(Operators.fieldAccess) => new OpNodes.FieldAccess(n).argumentOut
- case n => n
- }
- .take(3)
- .l match {
- case List(assigned: Identifier, i: Identifier, f: FieldIdentifier)
- if symbolTable.contains(CallAlias(i.name)) =>
- // Get field from global table if referenced as function call
- val fieldTypes = symbolTable
- .get(CallAlias(i.name))
- .flatMap(recModule => globalTable.get(FieldVar(recModule, f.canonicalName)))
- if (fieldTypes.nonEmpty) symbolTable.append(assigned, fieldTypes)
- case List(assigned: Identifier, i: Identifier, f: FieldIdentifier)
- if symbolTable
- .contains(LocalVar(i.name)) =>
- // Get field from global table if referenced as a variable
- val localTypes = symbolTable.get(LocalVar(i.name))
- val memberTypes = localTypes
- .flatMap { t =>
- cpg.typeDecl.fullNameExact(t).member.nameExact(f.canonicalName).l ++
- cpg.typeDecl.fullNameExact(t).method.fullNameExact(t).l
- }
- .flatMap {
- case m: Member => Some(m.typeFullName)
- case m: Method => Some(m.fullName)
- case _ => None
- }
- if (memberTypes.nonEmpty)
- // First use the member type info from the CPG, if present
- symbolTable.append(assigned, memberTypes)
- else if (localTypes.nonEmpty) {
- // If not available, use a dummy variable that can be useful for call matching
- symbolTable.append(assigned, localTypes.map { t => s"$t.(${f.canonicalName})" })
- }
- case List(assigned: Identifier, i: Identifier, f: FieldIdentifier)
- if symbolTable.contains(CallAlias(s"${i.name}.${f.canonicalName}")) =>
- // In this case, if are the paths of an import, e.g. import foo.bar and it is referred to as foo.bar later
- // TODO: Does this handle foo.bar.baz ?
- val callAlias = CallAlias(s"${i.name}.${f.canonicalName}")
- val importedTypes = symbolTable.get(callAlias)
- setIdentifierFromFunctionType(assigned, callAlias.identifier, callAlias.identifier, importedTypes)
- case List(assigned: Identifier, i: Identifier, f: FieldIdentifier) =>
- // TODO: This is really tricky to find without proper object tracking, so we match name only
- val fieldTypes = globalTable.view.filter(_._1.identifier.equals(f.canonicalName)).flatMap(_._2).toSet
- if (fieldTypes.nonEmpty) symbolTable.append(assigned, fieldTypes)
- case List(assigned: Identifier, c: Call, f: FieldIdentifier) if c.name.equals(Operators.fieldAccess) =>
- // TODO: This is the step that handles foo.bar.baz, which gives the impression that there is the need to
- // handle this pattern recursively
- val baseType = c.astChildren.isFieldIdentifier.canonicalName
- .zip(c.astChildren.isIdentifier.flatMap(symbolTable.get))
- .map { case (ff, bt) => s"$bt.($ff).(${f.canonicalName})" }
- .toSet
- if (baseType.nonEmpty) symbolTable.append(assigned, baseType)
+ override def isConstructor(c: Call): Boolean =
+ c.name.nonEmpty && c.name.charAt(0).isUpper && c.code.endsWith(")")
- case _ =>
- }
- // Field load from call
- case List(fl: Call, c: Call) if fl.name.equals(Operators.fieldAccess) && symbolTable.contains(c) =>
- (fl.astChildren.l, c.astChildren.l) match {
- case (List(self: Identifier, fieldIdentifier: FieldIdentifier), args: List[_]) =>
- symbolTable.append(fieldIdentifier, symbolTable.get(c))
- globalTable.append(fieldVarName(fieldIdentifier), symbolTable.get(c))
- case _ =>
- }
- // Field load from index access
- case List(fl: Call, c: Call) if fl.name.equals(Operators.fieldAccess) && c.name.equals(Operators.indexAccess) =>
- (fl.astChildren.l, c.astChildren.l) match {
- case (List(self: Identifier, fieldIdentifier: FieldIdentifier), ::(rhsFAccess: Call, _))
- if rhsFAccess.name.equals(Operators.fieldAccess) =>
- val rhsField = rhsFAccess.fieldAccess.fieldIdentifier.head
- // TODO: Check if a type for the RHS index access is recovered
- val types = symbolTable.get(rhsField).map(t => s"$t.")
- symbolTable.append(fieldIdentifier, types)
- globalTable.append(fieldVarName(fieldIdentifier), types)
- case _ =>
- }
- case _ =>
+ /** If the parent method is module then it can be used as a field.
+ */
+ override def isField(i: Identifier): Boolean =
+ i.method.name.matches("(|__init__)") || super.isField(i)
+
+ override def visitIdentifierAssignedToOperator(i: Identifier, c: Call, operation: String): Set[String] = {
+ operation match {
+ case ".listLiteral" => associateTypes(i, Set(s"$BUILTIN_PREFIX.list"))
+ case ".tupleLiteral" => associateTypes(i, Set(s"$BUILTIN_PREFIX.tuple"))
+ case ".dictLiteral" => associateTypes(i, Set(s"$BUILTIN_PREFIX.dict"))
+ case _ => super.visitIdentifierAssignedToOperator(i, c, operation)
}
}
- private def fieldVarName(f: FieldIdentifier): FieldVar = {
- if (f.astSiblings.map(_.code).exists(_.contains("self"))) {
- // This will match the type decl
- FieldVar(f.method.typeDecl.fullName.head, f.canonicalName)
- } else {
- // This will typically match the
- FieldVar(f.file.method.fullName.head, f.canonicalName)
- }
+ override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = {
+ val constructorPaths = symbolTable
+ .get(c)
+ .map(_.stripSuffix(s".${Defines.ConstructorMethodName}"))
+ .map(x => (x.split("\\.").last, x))
+ .map {
+ case (x, y) => s"$y.$x"
+ case (_, z) => z
+ }
+ associateTypes(i, constructorPaths)
}
- private def setIdentifierFromFunctionType(
- i: Identifier,
- callName: String,
- callCode: String,
- importedTypes: Set[String]
- ): Unit = {
- if (!callCode.endsWith(")")) {
- // Case 1: The identifier is at the assignment to a function pointer. Lack of parenthesis should indicate this.
- setIdentifier(i, importedTypes)
- } else if (!callName.isBlank && callName.charAt(0).isUpper && callCode.endsWith(")")) {
- // Case 2: The identifier is receiving a constructor invocation, thus is now an instance of the type
- setIdentifier(
- i,
- importedTypes
- .map(_.stripSuffix(s".${Defines.ConstructorMethodName}"))
- .map(x => (x.split("\\.").last, x))
- .map {
- case (x, y) => s"$y.$x"
- case (_, z) => z
- }
- )
- } else {
- // TODO: This identifier should contain the type of the return value of 'c'.
- // e.g. x = foo(a, b) but not x = y.foo(a, b) as foo in the latter case is interpreted as a field access
- }
+ override def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = {
+ // Ignore legacy import representation
+ if (c.name.equals("import")) Set.empty
+ // Stop custom annotation representation from hitting superclass
+ else if (c.name.isBlank) Set.empty
+ else super.visitIdentifierAssignedToCall(i, c)
}
- private def setIdentifier(i: Identifier, types: Set[String]): Option[Set[String]] = {
- if (i.method.name.equals("")) globalTable.put(i, types)
- symbolTable.append(i, types)
+ override def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = {
+ val fieldParents = getFieldParents(fa)
+ fa.astChildren.l match {
+ case List(base: Identifier, fi: FieldIdentifier) if base.name.equals("self") && fieldParents.nonEmpty =>
+ val globalTypes = fieldParents.flatMap(fp => globalTable.get(FieldVar(fp, fi.canonicalName)))
+ associateTypes(i, globalTypes)
+ case _ => super.visitIdentifierAssignedToFieldLoad(i, fa)
+ }
}
- private def visitCallFromFieldMember(
- i: Identifier,
- c: Call,
- field: FieldAccess,
- symbolTable: SymbolTable[LocalKey]
- ): Unit = {
- field.astChildren.l match {
- case List(rec: Identifier, f: FieldIdentifier) if symbolTable.contains(rec) =>
- // First we ask if the call receiver is known as a variable
- val identifierFullName = symbolTable.get(rec).map(_.concat(s".${f.canonicalName}"))
- val callMethodFullName =
- if (f.canonicalName.charAt(0).isUpper)
- identifierFullName.map(_.concat(s".${Defines.ConstructorMethodName}"))
- else
- identifierFullName.map(x => (x.split("\\.").takeRight(2), x)).map {
- case (Array(x, y), z) if x.charAt(0).isUpper && !z.contains("") => s"${z.stripSuffix(y)}$x.$y"
- case (_, y) => y
- }
- symbolTable.put(i, identifierFullName)
- symbolTable.put(c, callMethodFullName)
- case List(rec: Identifier, f: FieldIdentifier) if symbolTable.contains(CallAlias(rec.name)) =>
- // Second we ask if the call receiver is known as a function pointer (imports are interpreted as functions first)
- val funcTypes = symbolTable.get(CallAlias(rec.name)).map(t => s"$t.${f.canonicalName}")
- // TODO: Look in the CPG if we can resolve the method return value
- symbolTable.put(i, funcTypes.map(t => s"$t."))
- symbolTable.put(c, funcTypes)
- case _ =>
+ override def getFieldParents(fa: FieldAccess): Set[String] = {
+ if (fa.method.name.equals("")) {
+ Set(fa.method.fullName)
+ } else if (fa.method.typeDecl.nonEmpty) {
+ val parentTypes =
+ fa.method.typeDecl.fullName.map(_.stripSuffix("")).map { t => s"$t.${t.split("\\.").last}" }.toSeq
+ val baseTypes = cpg.typeDecl.fullNameExact(parentTypes: _*).inheritsFromTypeFullName.toSeq
+ // TODO: inheritsFromTypeFullName does not give full name in pysrc2cpg
+ val baseTypeFullNames = cpg.typ.nameExact(baseTypes: _*).fullName.toSeq
+ (parentTypes ++ baseTypeFullNames)
+ .map(_.concat("."))
+ .filterNot(t => t.toLowerCase.matches("(any|object)"))
+ .toSet
+ } else {
+ super.getFieldParents(fa)
}
}
- /** Will handle literal value assignments.
- * @param lhs
- * the identifier.
- * @param rhs
- * the literal.
- * @param symbolTable
- * the symbol table.
- * @return
- * true if a literal assigment was successfully determined and added to the symbol table, false if otherwise.
- */
- private def visitLiteralAssignment(lhs: Identifier, rhs: CfgNode, symbolTable: SymbolTable[LocalKey]): Boolean = {
- ((lhs, rhs) match {
- case (i: Identifier, l: Literal) if Try(java.lang.Integer.parseInt(l.code)).isSuccess =>
- setIdentifier(i, Set("int"))
- case (i: Identifier, l: Literal) if Try(java.lang.Double.parseDouble(l.code)).isSuccess =>
- setIdentifier(i, Set("float"))
- case (i: Identifier, l: Literal) if "True".equals(l.code) || "False".equals(l.code) =>
- setIdentifier(i, Set("bool"))
- case (i: Identifier, l: Literal) if l.code.matches("^(\"|').*(\"|')$") =>
- setIdentifier(i, Set("str"))
- case (i: Identifier, c: Call) if c.name.equals(".listLiteral") =>
- setIdentifier(i, Set("list"))
- case (i: Identifier, c: Call) if c.name.equals(".tupleLiteral") =>
- setIdentifier(i, Set("tuple"))
- case (i: Identifier, b: Block)
- if b.astChildren.isCall.headOption.exists(
- _.argument.isCall.exists(_.name.equals(".dictLiteral"))
- ) =>
- setIdentifier(i, Set("dict"))
- case _ => None
- }).hasNext
+ override def getLiteralType(l: Literal): Set[String] = {
+ l match {
+ case _ if Try(java.lang.Integer.parseInt(l.code)).isSuccess => Set(s"$BUILTIN_PREFIX.int")
+ case _ if Try(java.lang.Double.parseDouble(l.code)).isSuccess => Set(s"$BUILTIN_PREFIX.float")
+ case _ if "True".equals(l.code) || "False".equals(l.code) => Set(s"$BUILTIN_PREFIX.bool")
+ case _ if l.code.matches("^(\"|').*(\"|')$") => Set(s"$BUILTIN_PREFIX.str")
+ case _ if l.code.equals("None") => Set(s"$BUILTIN_PREFIX.None")
+ case _ => Set()
+ }
}
}
object PythonTypeRecovery {
-
- /** @see
- * Python Built-in Functions
- */
- lazy val BUILTINS: Set[String] = Set(
- "abs",
- "aiter",
- "all",
- "anext",
- "ascii",
- "bin",
- "bool",
- "breakpoint",
- "bytearray",
- "bytes",
- "callable",
- "chr",
- "classmethod",
- "compile",
- "complex",
- "delattr",
- "dict",
- "dir",
- "divmod",
- "enumerate",
- "eval",
- "exec",
- "filter",
- "float",
- "format",
- "frozenset",
- "getattr",
- "globals",
- "hasattr",
- "hash",
- "help",
- "hex",
- "id",
- "input",
- "int",
- "isinstance",
- "issubclass",
- "iter",
- "len",
- "list",
- "locals",
- "map",
- "max",
- "memoryview",
- "min",
- "next",
- "object",
- "oct",
- "open",
- "ord",
- "pow",
- "print",
- "property",
- "range",
- "repr",
- "reversed",
- "round",
- "set",
- "setattr",
- "slice",
- "sorted",
- "staticmethod",
- "str",
- "sum",
- "super",
- "tuple",
- "type",
- "vars",
- "zip",
- "__import__"
- )
- def BUILTIN_PREFIX = "builtins.py:"
+ def BUILTIN_PREFIX = "__builtin"
}
diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala
index 5aae5609efd9..7ec13f219950 100644
--- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala
+++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala
@@ -203,9 +203,9 @@ class CallCpgTests extends PySrc2CpgFixture(withOssDataflow = false) {
"test that the identifiers are not set to the function pointers but rather the 'ANY' return value" in {
val List(x, y, z) = cpg.identifier.name("x", "y", "z").l
- x.typeFullName shouldBe "ANY"
- y.typeFullName shouldBe "ANY"
- z.typeFullName shouldBe "ANY"
+ x.typeFullName shouldBe "foo.py:.foo_func."
+ y.typeFullName shouldBe Seq("foo", "bar", "__init__.py:.bar_func.").mkString(File.separator)
+ z.typeFullName shouldBe "foo.py:.faz."
}
"test call node properties for normal import from module on root path" in {
diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala
index 70b4c6b0eb9a..84155b4a78d1 100644
--- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala
+++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala
@@ -22,22 +22,26 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
"resolve 'x' identifier types despite shadowing" in {
val List(xOuterScope, xInnerScope) = cpg.identifier("x").take(2).l
- xOuterScope.dynamicTypeHintFullName shouldBe Seq("int", "str")
- xInnerScope.dynamicTypeHintFullName shouldBe Seq("int", "str")
+ xOuterScope.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str")
+ xInnerScope.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str")
}
"resolve 'y' and 'z' identifier collection types" in {
val List(zDict, zList, zTuple) = cpg.identifier("z").take(3).l
- zDict.dynamicTypeHintFullName shouldBe Seq("dict", "list", "tuple")
- zList.dynamicTypeHintFullName shouldBe Seq("dict", "list", "tuple")
- zTuple.dynamicTypeHintFullName shouldBe Seq("dict", "list", "tuple")
+ zDict.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple")
+ zList.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple")
+ zTuple.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple")
}
"resolve 'z' identifier calls conservatively" in {
// TODO: These should have callee entries but the method stubs are not present here
val List(zAppend) = cpg.call("append").l
zAppend.methodFullName shouldBe Defines.DynamicCallUnknownFallName
- zAppend.dynamicTypeHintFullName shouldBe Seq("dict.append", "list.append", "tuple.append")
+ zAppend.dynamicTypeHintFullName shouldBe Seq(
+ "__builtin.dict.append",
+ "__builtin.list.append",
+ "__builtin.tuple.append"
+ )
}
}
@@ -82,6 +86,11 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
postMessage.methodFullName shouldBe "slack_sdk.py:.WebClient.WebClient.chat_postMessage"
}
+ "resolve a dummy 'send' return value from sg.send" in {
+ val List(postMessage) = cpg.identifier("response").l
+ postMessage.typeFullName shouldBe "sendgrid.py:.SendGridAPIClient.SendGridAPIClient.send."
+ }
+
}
"type recovery on class members" should {
@@ -125,7 +134,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
"resolve 'User' field types" in {
val List(id, firstname, age, address) =
cpg.identifier.nameExact("id", "firstname", "age", "address").takeRight(4).l
- id.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.SQLAlchemy.Column"
+ id.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.SQLAlchemy.Column.Column"
firstname.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.SQLAlchemy.Column.Column"
age.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.SQLAlchemy.Column.Column"
address.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.SQLAlchemy.Column.Column"
@@ -150,14 +159,14 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
"resolve 'print' and 'max' calls" in {
val Some(printCall) = cpg.call("print").headOption
- printCall.methodFullName shouldBe "builtins.py:.print"
+ printCall.methodFullName shouldBe "__builtin.print"
val Some(maxCall) = cpg.call("max").headOption
- maxCall.methodFullName shouldBe "builtins.py:.max"
+ maxCall.methodFullName shouldBe "__builtin.max"
}
- "select the imported abs over the built-in type when call is shadowed" in {
+ "conservatively present either option when an imported function uses the same name as a builtin" in {
val Some(absCall) = cpg.call("abs").headOption
- absCall.dynamicTypeHintFullName shouldBe Seq("foo.py:.abs")
+ absCall.dynamicTypeHintFullName shouldBe Seq("foo.py:.abs", "__builtin.abs")
}
}
@@ -190,9 +199,9 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
"resolve 'x' and 'y' locally under foo.py" in {
val Some(x) = cpg.file.name(".*foo.*").ast.isIdentifier.name("x").headOption
- x.typeFullName shouldBe "int"
+ x.typeFullName shouldBe "__builtin.int"
val Some(y) = cpg.file.name(".*foo.*").ast.isIdentifier.name("y").headOption
- y.typeFullName shouldBe "str"
+ y.typeFullName shouldBe "__builtin.str"
}
"resolve 'foo.x' and 'foo.y' field access primitive types correctly" in {
@@ -203,9 +212,9 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
.name("z")
.l
z1.typeFullName shouldBe "ANY"
- z1.dynamicTypeHintFullName shouldBe Seq("int", "str")
+ z1.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str")
z2.typeFullName shouldBe "ANY"
- z2.dynamicTypeHintFullName shouldBe Seq("int", "str")
+ z2.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str")
}
"resolve 'foo.d' field access object types correctly" in {
@@ -256,7 +265,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
| try:
| db.create_all()
| db.session.add(user)
- | return jsonify({"success": message})
+ | return jsonify({"success": True})
| except Exception as e:
| return 'There was an issue adding your task'
|""".stripMargin,
@@ -377,12 +386,15 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
"recover a potential type for `self.collection` using the assignment at `get_collection` as a type hint" in {
val Some(selfFindFound) = cpg.typeDecl(".*InstallationsDAO.*").ast.isCall.name("find_one").headOption
- selfFindFound.methodFullName shouldBe "pymongo.py:.MongoClient....find_one"
+ selfFindFound.dynamicTypeHintFullName shouldBe Seq(
+ "__builtin.None.find_one",
+ "pymongo.py:.MongoClient....find_one"
+ )
}
"correctly determine that, despite being unable to resolve the correct method full name, that it is an internal method" in {
val Some(selfFindFound) = cpg.typeDecl(".*InstallationsDAO.*").ast.isCall.name("find_one").headOption
- selfFindFound.callee.isExternal.headOption shouldBe Some(false)
+ selfFindFound.callee.isExternal.toSeq shouldBe Seq(true, false)
}
}
diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala
index fe3076b17af2..fd859bc96b12 100644
--- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala
+++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala
@@ -56,6 +56,10 @@ sealed class LocalKey(identifier: String) extends SBKey(identifier) {
*/
case class LocalVar(override val identifier: String) extends LocalKey(identifier)
+/** A collection object that can be accessed with potentially dynamic keys and values.
+ */
+case class CollectionVar(override val identifier: String, idx: String) extends LocalKey(identifier)
+
/** A name that refers to some kind of callee.
*/
case class CallAlias(override val identifier: String) extends LocalKey(identifier)
@@ -98,25 +102,25 @@ class SymbolTable[K <: SBKey](fromNode: AstNode => K) {
table.put(newKey, newValues)
}
- def put(sbKey: K, typeFullNames: Set[String]): Option[Set[String]] =
- table.put(sbKey, typeFullNames)
+ def put(sbKey: K, typeFullNames: Set[String]): Set[String] =
+ table.put(sbKey, typeFullNames).getOrElse(Set.empty)
- def put(sbKey: K, typeFullName: String): Option[Set[String]] =
+ def put(sbKey: K, typeFullName: String): Set[String] =
put(sbKey, Set(typeFullName))
- def put(node: AstNode, typeFullNames: Set[String]): Option[Set[String]] =
+ def put(node: AstNode, typeFullNames: Set[String]): Set[String] =
put(fromNode(node), typeFullNames)
- def append(node: AstNode, typeFullName: String): Option[Set[String]] =
+ def append(node: AstNode, typeFullName: String): Set[String] =
append(node, Set(typeFullName))
- def append(node: AstNode, typeFullNames: Set[String]): Option[Set[String]] =
+ def append(node: AstNode, typeFullNames: Set[String]): Set[String] =
append(fromNode(node), typeFullNames)
- def append(sbKey: K, typeFullNames: Set[String]): Option[Set[String]] = {
+ def append(sbKey: K, typeFullNames: Set[String]): Set[String] = {
table.get(sbKey) match {
- case Some(ts) => table.put(sbKey, ts ++ typeFullNames)
- case None => table.put(sbKey, typeFullNames)
+ case Some(ts) => put(sbKey, ts ++ typeFullNames)
+ case None => put(sbKey, typeFullNames)
}
}
diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala
index 8be04607106a..9ce804ee7e4a 100644
--- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala
+++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala
@@ -1,18 +1,19 @@
package io.joern.x2cpg.passes.frontend
+import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.nodes._
import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames}
import io.shiftleft.passes.CpgPass
import io.shiftleft.semanticcpg.language._
import io.shiftleft.semanticcpg.language.operatorextension.OpNodes
-import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment
+import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess}
import org.slf4j.{Logger, LoggerFactory}
import overflowdb.BatchedUpdate.DiffGraphBuilder
import overflowdb.traversal.Traversal
-import java.util.Objects
import java.util.concurrent.RecursiveTask
+import scala.collection.mutable
/** Based on a flow-insensitive symbol-table-style approach. This pass aims to be fast and deterministic and does not
* try to converge to some fixed point but rather iterates a fixed number of times. This will help recover:
@@ -50,9 +51,9 @@ abstract class XTypeRecovery[ComputationalUnit <: AstNode](cpg: Cpg, iterations:
protected val globalTable = new SymbolTable[GlobalKey](SBKey.fromNodeToGlobalKey)
override def run(builder: DiffGraphBuilder): Unit = try {
- for (_ <- 0 to iterations)
+ for (_ <- 0 until iterations)
computationalUnit
- .map(unit => generateRecoveryForCompilationUnitTask(unit, builder, globalTable).fork())
+ .map(unit => generateRecoveryForCompilationUnitTask(unit, builder).fork())
.foreach(_.get())
} finally {
globalTable.clear()
@@ -68,95 +69,45 @@ abstract class XTypeRecovery[ComputationalUnit <: AstNode](cpg: Cpg, iterations:
* the compilation unit.
* @param builder
* the graph builder.
- * @param globalTable
- * the global table.
* @return
* a forkable [[RecoverForXCompilationUnit]] task.
*/
def generateRecoveryForCompilationUnitTask(
unit: ComputationalUnit,
- builder: DiffGraphBuilder,
- globalTable: SymbolTable[GlobalKey]
+ builder: DiffGraphBuilder
): RecoverForXCompilationUnit[ComputationalUnit]
}
-/** Defines how a procedure is available to be called in the current scope either by it being defined in this module or
- * being imported.
- *
- * @param callingName
- * how this procedure is to be called, i.e., alias name, name with path, etc.
- * @param fullName
- * the full name to where this method is defined.
- */
-abstract class ScopedXProcedure(val callingName: String, val fullName: String, val isConstructor: Boolean = false) {
-
- /** @return
- * there are multiple ways that this procedure could be resolved in some languages. This will be pruned later by
- * comparing this to actual methods in the CPG using postVisitImport.
- */
- def possibleCalleeNames: Set[String] = Set()
-
- override def toString: String = s"ProcedureCalledAs(${possibleCalleeNames.mkString(", ")})"
-
- override def equals(obj: Any): Boolean = {
- obj match {
- case o: ScopedXProcedure =>
- callingName.equals(o.callingName) && fullName.equals(o.fullName) && isConstructor == o.isConstructor
- case _ => false
- }
- }
-
- override def hashCode(): Int = Objects.hash(callingName, fullName, isConstructor)
-
-}
-
-/** Tasks responsible for populating the symbol table with import data.
- *
- * @param node
- * a node that references import information.
- */
-abstract class SetXProcedureDefTask(node: CfgNode) extends RecursiveTask[Unit] {
-
- protected val logger: Logger = LoggerFactory.getLogger(classOf[SetXProcedureDefTask])
-
- override def compute(): Unit =
- node match {
- case x: Method => visitImport(x)
- case x: Call => visitImport(x)
- case _ =>
- }
-
- /** Refers to the declared import information.
- *
- * @param importCall
- * the call that imports entities into this scope.
- */
- def visitImport(importCall: Call): Unit
-
- /** Sets how an application method would be referred to locally.
- *
- * @param m
- * an internal method
- */
- def visitImport(m: Method): Unit
-
+object XTypeRecovery {
+ val DUMMY_RETURN_TYPE = ""
+ val DUMMY_MEMBER_LOAD = ""
+ val DUMMY_INDEX_ACCESS = ""
+ def DUMMY_MEMBER_TYPE(prefix: String, memberName: String) = s"$prefix.$DUMMY_MEMBER_LOAD($memberName)"
}
/** Performs type recovery from the root of a compilation unit level
*
+ * @param cpg
+ * the graph.
* @param cu
* a compilation unit, e.g. file, procedure, type, etc.
* @param builder
* the graph builder
+ * @param globalTable
+ * the global symbol table.
* @tparam ComputationalUnit
* the [[AstNode]] type used to represent a computational unit of the language.
*/
abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode](
+ cpg: Cpg,
cu: ComputationalUnit,
- builder: DiffGraphBuilder
+ builder: DiffGraphBuilder,
+ globalTable: SymbolTable[GlobalKey]
) extends RecursiveTask[Unit] {
+ protected val logger: Logger = LoggerFactory.getLogger(getClass)
+
/** Stores type information for local structures that live within this compilation unit, e.g. local variables.
*/
protected val symbolTable = new SymbolTable[LocalKey](SBKey.fromNodeToLocalKey)
@@ -168,12 +119,17 @@ abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode](
protected def assignments: Traversal[Assignment] =
cu.ast.isCall.name(Operators.assignment).map(new OpNodes.Assignment(_))
+ protected def members: Traversal[Member] =
+ cu.ast.isMember
+
override def compute(): Unit = try {
prepopulateSymbolTable()
// Set known aliases that point to imports for local and external methods/modules
- setImportsFromDeclaredProcedures(importNodes(cu) ++ internalMethodNodes(cu))
+ visitImports(importNodes(cu))
// Prune import names if the methods exist in the CPG
postVisitImports()
+ // Populate fields
+ members.foreach(visitMembers)
// Populate local symbol table with assignments
assignments.foreach(visitAssignments)
// Persist findings
@@ -182,51 +138,465 @@ abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode](
symbolTable.clear()
}
+ private def debugLocation(n: AstNode): String = {
+ val rootPath = cpg.metaData.root.headOption.getOrElse("")
+ val fileName = n.file.name.headOption.getOrElse("").stripPrefix(rootPath)
+ val lineNo = n.lineNumber.getOrElse("")
+ s"$fileName#L$lineNo"
+ }
+
/** Using import information and internally defined procedures, will generate a mapping between how functions and
* types are aliased and called and themselves.
*
* @param procedureDeclarations
* imports to types or functions and internally defined methods themselves.
*/
- protected def setImportsFromDeclaredProcedures(procedureDeclarations: Traversal[CfgNode]): Unit =
- procedureDeclarations.map(f => generateSetProcedureDefTask(f, symbolTable).fork()).foreach(_.get())
+ protected def visitImports(procedureDeclarations: Traversal[AstNode]): Unit = {
+ procedureDeclarations.foreach {
+ case i: Import => visitImport(i)
+ case i: Call => visitImport(i)
+ }
+ }
- /** Generates a task to create an import task.
+ /** Refers to the declared import information. This is for legacy import notation.
*
- * @param node
- * the import node or method definition node.
- * @param symbolTable
- * the local table.
- * @return
- * a forkable [[SetXProcedureDefTask]] task.
+ * @param i
+ * the call that imports entities into this scope.
*/
- def generateSetProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey]): SetXProcedureDefTask
+ protected def visitImport(i: Call): Unit
- /** @return
- * the import nodes of this computational unit.
+ /** Visits an import and stores references in the symbol table as both an identifier and call.
*/
- def importNodes(cu: AstNode): Traversal[CfgNode]
+ protected def visitImport(i: Import): Unit = {
+ val entity = i.importedEntity
+ val alias = i.importedAs
+ if (entity.isDefined && alias.isDefined) {
+ symbolTable.append(LocalVar(alias.get), Set(entity.get))
+ symbolTable.append(CallAlias(alias.get), Set(entity.get))
+ }
+ }
- /** @param cu
- * the current computational unit.
- * @return
- * the methods defined within this computational unit.
+ /** @return
+ * the import nodes of this computational unit.
*/
- def internalMethodNodes(cu: AstNode): Traversal[Method] = cu.ast.isMethod.isExternal(false)
+ protected def importNodes(cu: AstNode): Traversal[AstNode] = cu.ast.isImport
/** The initial import setting is over-approximated, so this step checks the CPG for any matches and prunes against
* these findings. If there are no findings, it will leave the table as is. The latter is significant for external
* types or methods.
*/
- def postVisitImports(): Unit = {}
+ protected def postVisitImports(): Unit = {}
+
+ /** Using member information, will propagate member information to the global and local symbol table. By default,
+ * fields in the local table will be prepended with "this".
+ */
+ protected def visitMembers(member: Member): Unit = {
+ symbolTable.put(LocalVar(member.name), Set.empty[String])
+ globalTable.put(FieldVar(member.typeDecl.fullName, member.name), Set.empty[String])
+ }
/** Using assignment and import information (in the global symbol table), will propagate these types in the symbol
* table.
*
- * @param assignment
+ * @param a
* assignment call pointer.
*/
- def visitAssignments(assignment: Assignment): Unit
+ protected def visitAssignments(a: Assignment): Set[String] = {
+ a.argumentOut.l match {
+ case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b)
+ case List(i: Identifier, c: Call) => visitIdentifierAssignedToCall(i, c)
+ case List(i: Identifier, l: Literal) => visitIdentifierAssignedToLiteral(i, l)
+ case List(i: Identifier, m: MethodRef) => visitIdentifierAssignedToMethodRef(i, m)
+ case List(i: Identifier, t: TypeRef) => visitIdentifierAssignedToTypeRef(i, t)
+ case List(c: Call, i: Identifier) => visitCallAssignedToIdentifier(c, i)
+ case List(x: Call, y: Call) => visitCallAssignedToCall(x, y)
+ case List(c: Call, l: Literal) => visitCallAssignedToLiteral(c, l)
+ case List(c: Call, m: MethodRef) => visitCallAssignedToMethodRef(c, m)
+ case List(c: Call, b: Block) => visitCallAssignedToBlock(c, b)
+ case xs =>
+ logger.warn(s"Unhandled assignment ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(a)}")
+ Set.empty
+ }
+ }
+
+ /** Visits an identifier being assigned to the result of some operation.
+ */
+ protected def visitIdentifierAssignedToBlock(i: Identifier, b: Block): Set[String] = {
+ val blockTypes = visitStatementsInBlock(b)
+ if (blockTypes.nonEmpty) associateTypes(i, blockTypes)
+ else Set.empty
+ }
+
+ /** Visits a call operation being assigned to the result of some operation.
+ */
+ protected def visitCallAssignedToBlock(c: Call, b: Block): Set[String] = {
+ val blockTypes = visitStatementsInBlock(b)
+ assignTypesToCall(c, blockTypes)
+ }
+
+ /** Process each statement but only assign the type of the last statement to the identifier
+ */
+ protected def visitStatementsInBlock(b: Block): Set[String] =
+ b.astChildren
+ .map {
+ case x: Call if x.name.equals(Operators.assignment) => visitAssignments(new Assignment(x))
+ case x: Identifier if x.astChildren.isEmpty && symbolTable.contains(x) => symbolTable.get(x)
+ case x: Call if symbolTable.contains(x) => symbolTable.get(x)
+ case x: Call if x.argument.headOption.exists(symbolTable.contains) => setCallMethodFullNameFromBase(x)
+ case x => logger.warn(s"Unhandled block element ${x.label}:${x.code} @ ${debugLocation(x)}"); Set.empty[String]
+ }
+ .lastOption
+ .getOrElse(Set.empty[String])
+
+ /** Visits an identifier being assigned to a call. This call could be an operation, function invocation, or
+ * constructor invocation.
+ */
+ protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = {
+ if (c.name.startsWith("")) {
+ visitIdentifierAssignedToOperator(i, c, c.name)
+ } else if (symbolTable.contains(c) && isConstructor(c)) {
+ visitIdentifierAssignedToConstructor(i, c)
+ } else if (symbolTable.contains(c)) {
+ visitIdentifierAssignedToCallRetVal(i, c)
+ } else if (c.argument.headOption.exists(symbolTable.contains)) {
+ setCallMethodFullNameFromBase(c)
+ // Repeat this method now that the call has a type
+ visitIdentifierAssignedToCall(i, c)
+ } else {
+ // We can try obtain a return type for this call
+ visitIdentifierAssignedToCallRetVal(i, c)
+ }
+ }
+
+ /** Will build a call full path using the call base node. This method assumes the base node is in the symbol table.
+ */
+ protected def setCallMethodFullNameFromBase(c: Call): Set[String] = {
+ val recTypes = c.argument.headOption.map(symbolTable.get).getOrElse(Set.empty[String])
+ val callTypes = recTypes.map(_.concat(s".${c.name}"))
+ symbolTable.append(c, callTypes)
+ }
+
+ /** A heuristic method to determine if a call is a constructor or not.
+ */
+ protected def isConstructor(c: Call): Boolean
+
+ /** A heuristic method to determine if an identifier may be a field or not. The result means that it would be stored
+ * in the global symbol table. By default this checks if the identifier name matches a member name.
+ */
+ protected def isField(i: Identifier): Boolean = i.method.typeDecl.member.exists(_.name.equals(i.name))
+
+ /** Associates the types with the identifier. This may sometimes be an identifier that should be considered a field
+ * which this method uses [[isField(i)]] to determine.
+ */
+ protected def associateTypes(i: Identifier, types: Set[String]): Set[String] = {
+ if (isField(i)) globalTable.put(i, types)
+ symbolTable.append(i, types)
+ }
+
+ /** Returns the appropriate field parent scope.
+ */
+ protected def getFieldParents(fa: FieldAccess): Set[String] = {
+ val fieldName = getFieldName(fa).split("\\.").last
+ cpg.typeDecl.filter(_.member.exists(_.name.equals(fieldName))).fullName.filterNot(_.contains("ANY")).toSet
+ }
+
+ /** Associates the types with the identifier. This may sometimes be an identifier that should be considered a field
+ * which this method uses [[isField(i)]] to determine.
+ */
+ protected def associateTypes(symbol: LocalVar, fa: FieldAccess, types: Set[String]): Set[String] = {
+ val fieldParents = getFieldParents(fa)
+ fa.astChildren.l match {
+ case ::(i: Identifier, _) if isField(i) && fieldParents.nonEmpty =>
+ fieldParents.foreach(fp => globalTable.put(FieldVar(fp, symbol.identifier), types))
+ case _ =>
+ }
+ symbolTable.append(symbol, types)
+ }
+
+ /** Similar to [[associateTypes()]] but used in the case where there is some kind of field load.
+ */
+ protected def associateInterproceduralTypes(
+ i: Identifier,
+ base: Identifier,
+ fi: FieldIdentifier,
+ fieldName: String,
+ baseTypes: Set[String]
+ ): Set[String] = {
+ val globalTypes = getFieldBaseType(base, fi)
+ associateInterproceduralTypes(i, fieldName, fi.canonicalName, globalTypes, baseTypes)
+ }
+
+ protected def associateInterproceduralTypes(
+ i: Identifier,
+ fieldFullName: String,
+ fieldName: String,
+ globalTypes: Set[String],
+ baseTypes: Set[String]
+ ): Set[String] = {
+ if (globalTypes.nonEmpty) {
+ // We have been able to resolve the type inter-procedurally
+ associateTypes(i, globalTypes)
+ } else if (baseTypes.nonEmpty) {
+ if (baseTypes.equals(symbolTable.get(LocalVar(fieldFullName)))) {
+ associateTypes(i, baseTypes)
+ } else {
+ // If not available, use a dummy variable that can be useful for call matching
+ associateTypes(i, baseTypes.map(t => XTypeRecovery.DUMMY_MEMBER_TYPE(t, fieldName)))
+ }
+ } else {
+ logger.warn(s"Unable to associate interprocedural type for $i = $fieldFullName @ ${debugLocation(i)}")
+ Set.empty
+ }
+ }
+
+ /** Visits an identifier being assigned to an operator call.
+ */
+ protected def visitIdentifierAssignedToOperator(i: Identifier, c: Call, operation: String): Set[String] = {
+ operation match {
+ case Operators.alloc => visitIdentifierAssignedToConstructor(i, c)
+ case Operators.fieldAccess => visitIdentifierAssignedToFieldLoad(i, new FieldAccess(c))
+ case x => logger.warn(s"Unhandled operation $x (${c.code}) @ ${debugLocation(c)}"); Set.empty
+ }
+ }
+
+ /** Visits an identifier being assigned to a constructor and attempts to speculate the constructor path.
+ */
+ protected def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = {
+ val constructorPaths = symbolTable.get(c).map(t => t.concat(s".${Defines.ConstructorMethodName}"))
+ associateTypes(i, constructorPaths)
+ }
+
+ /** Visits an identifier being assigned to a call's return value.
+ */
+ protected def visitIdentifierAssignedToCallRetVal(i: Identifier, c: Call): Set[String] = {
+ if (symbolTable.contains(c)) {
+ val callReturns = methodReturnValues(symbolTable.get(c).toSeq)
+ associateTypes(i, callReturns)
+ } else if (c.argument.exists(_.argumentIndex == 0)) {
+ val callFullNames = (c.argument(0) match {
+ case i: Identifier if symbolTable.contains(LocalVar(i.name)) => symbolTable.get(LocalVar(i.name))
+ case i: Identifier if symbolTable.contains(CallAlias(i.name)) => symbolTable.get(CallAlias(i.name))
+ case _ => Set.empty
+ }).map(_.concat(s".${c.name}")).toSeq
+ val callReturns = methodReturnValues(callFullNames)
+ associateTypes(i, callReturns)
+ } else {
+ logger.warn(s"Unable to assign identifier '${i.name} to return value for ${c.code} @ ${debugLocation(c)}")
+ Set.empty
+ }
+ }
+
+ /** Will attempt to find the return values of a method if in the CPG, otherwise will give a dummy value.
+ */
+ protected def methodReturnValues(methodFullNames: Seq[String]): Set[String] = {
+ val rs = cpg.method.fullNameExact(methodFullNames: _*).methodReturn.typeFullName.filterNot(_.equals("ANY")).toSet
+ if (rs.isEmpty) methodFullNames.map(_.concat(s".${XTypeRecovery.DUMMY_RETURN_TYPE}")).toSet
+ else rs
+ }
+
+ /** Will handle literal value assignments. Override if special handling is required.
+ */
+ protected def visitIdentifierAssignedToLiteral(i: Identifier, l: Literal): Set[String] =
+ associateTypes(i, getLiteralType(l))
+
+ /** Not all frontends populate typeFullName
for literals so we allow this to be overridden.
+ */
+ protected def getLiteralType(l: Literal): Set[String] = Set(l.typeFullName)
+
+ /** Will handle an identifier holding a function pointer.
+ */
+ protected def visitIdentifierAssignedToMethodRef(i: Identifier, m: MethodRef): Set[String] =
+ symbolTable.append(CallAlias(i.name), Set(m.methodFullName))
+
+ /** Will handle an identifier holding a type pointer.
+ */
+ protected def visitIdentifierAssignedToTypeRef(i: Identifier, t: TypeRef): Set[String] =
+ symbolTable.append(CallAlias(i.name), Set(t.typeFullName))
+
+ /** Visits a call assigned to an identifier. This is often when there are operators involved.
+ */
+ protected def visitCallAssignedToIdentifier(c: Call, i: Identifier): Set[String] = {
+ val rhsTypes = symbolTable.get(i)
+ assignTypesToCall(c, rhsTypes)
+ }
+
+ /** Visits a call assigned to the return value of a call. This is often when there are operators involved.
+ */
+ protected def visitCallAssignedToCall(x: Call, y: Call): Set[String] = {
+ val rhsTypes = y.name match {
+ case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(new FieldAccess(y))))
+ case _ if symbolTable.contains(y) => symbolTable.get(y)
+ case Operators.indexAccess => getIndexAccessTypes(y)
+ case n =>
+ logger.warn(s"Unknown RHS call type '$n' @ ${debugLocation(x)}")
+ Set.empty[String]
+ }
+ assignTypesToCall(x, rhsTypes)
+ }
+
+ /** Given a LHS call, will retrieve its symbol to the given types.
+ */
+ protected def assignTypesToCall(x: Call, types: Set[String]): Set[String] = {
+ if (types.isEmpty) return Set.empty
+ getSymbolFromCall(x) match {
+ case (lhs, globalKeys) if globalKeys.nonEmpty =>
+ globalKeys.foreach(gt => globalTable.append(gt, types))
+ symbolTable.append(lhs, types)
+ case (lhs, _) => symbolTable.append(lhs, types)
+ }
+ }
+
+ /** Will attempt to retrieve index access types otherwise will return dummy value.
+ */
+ protected def getIndexAccessTypes(ia: Call): Set[String] = {
+ indexAccessToCollectionVar(ia) match {
+ case Some(cVar) if symbolTable.contains(cVar) =>
+ symbolTable.get(cVar)
+ case Some(cVar) if symbolTable.contains(LocalVar(cVar.identifier)) =>
+ symbolTable.get(LocalVar(cVar.identifier)).map(_.concat(s".${XTypeRecovery.DUMMY_INDEX_ACCESS}"))
+ case None => Set.empty
+ }
+ }
+
+ /** Tries to identify the underlying symbol from the call operation as it is used on the LHS of an assignment. The
+ * second element is a list of any associated global keys if applicable.
+ */
+ protected def getSymbolFromCall(c: Call): (LocalKey, Set[GlobalKey]) = c.name match {
+ case Operators.fieldAccess =>
+ val fa = new FieldAccess(c)
+ val fieldName = getFieldName(fa)
+ (LocalVar(fieldName), getFieldParents(fa).map(fp => FieldVar(fp, fieldName)))
+ case Operators.indexAccess => (indexAccessToCollectionVar(c).getOrElse(LocalVar(c.name)), Set.empty)
+ case x =>
+ logger.warn(s"Unknown LHS call type '$x' @ ${debugLocation(c)}")
+ (LocalVar(c.name), Set.empty)
+ }
+
+ /** Extracts a string representation of the name of the field within this field access.
+ */
+ protected def getFieldName(fa: FieldAccess, prefix: String = "", suffix: String = ""): String = {
+ def wrapName(n: String) = {
+ val sb = new mutable.StringBuilder()
+ if (prefix.nonEmpty) sb.append(s"$prefix.")
+ sb.append(n)
+ if (suffix.nonEmpty) sb.append(s".$suffix")
+ sb.toString()
+ }
+
+ fa.astChildren.l match {
+ case List(i: Identifier, f: FieldIdentifier) if i.name.matches("(self|this)") => wrapName(f.canonicalName)
+ case List(i: Identifier, f: FieldIdentifier) => wrapName(s"${i.name}.${f.canonicalName}")
+ case List(c: Call, f: FieldIdentifier) if c.name.equals(Operators.fieldAccess) =>
+ wrapName(getFieldName(new FieldAccess(c), suffix = f.canonicalName))
+ case List(f: FieldIdentifier, c: Call) if c.name.equals(Operators.fieldAccess) =>
+ wrapName(getFieldName(new FieldAccess(c), prefix = f.canonicalName))
+ case xs =>
+ logger.warn(s"Unhandled field structure ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(fa)}")
+ wrapName("")
+ }
+ }
+
+ protected def visitCallAssignedToLiteral(c: Call, l: Literal): Set[String] = {
+ if (c.name.equals(Operators.indexAccess)) {
+ // For now, we will just handle this on a very basic level
+ c.argumentOut.l match {
+ case List(_: Identifier, _: Literal) =>
+ indexAccessToCollectionVar(c).map(cv => symbolTable.append(cv, getLiteralType(l))).getOrElse(Set.empty)
+ case List(_: Identifier, idx: Identifier) if symbolTable.contains(idx) =>
+ // Imprecise but sound!
+ indexAccessToCollectionVar(c).map(cv => symbolTable.append(cv, symbolTable.get(idx))).getOrElse(Set.empty)
+ case xs =>
+ logger.warn(s"Unhandled index access point assigned to literal ${xs.map(_.label)} @ ${debugLocation(c)}")
+ Set.empty
+ }
+ } else if (c.name.equals(Operators.fieldAccess)) {
+ val fa = new FieldAccess(c)
+ val fieldName = getFieldName(fa)
+ associateTypes(LocalVar(fieldName), fa, getLiteralType(l))
+ } else {
+ logger.warn(s"Unhandled call assigned to literal point ${c.name} @ ${debugLocation(c)}")
+ Set.empty
+ }
+ }
+
+ /** Handles a call operation assigned to a method/function pointer.
+ */
+ protected def visitCallAssignedToMethodRef(c: Call, m: MethodRef): Set[String] =
+ assignTypesToCall(c, Set(m.methodFullName))
+
+ /** Generates an identifier for collection/index-access operations in the symbol table.
+ */
+ protected def indexAccessToCollectionVar(c: Call): Option[CollectionVar] = {
+
+ def callName(x: Call) =
+ if (x.name.equals(Operators.fieldAccess))
+ getFieldName(new FieldAccess(x))
+ else if (x.name.equals(Operators.indexAccess))
+ indexAccessToCollectionVar(x)
+ .map(cv => s"${cv.identifier}[${cv.idx}]")
+ .getOrElse(XTypeRecovery.DUMMY_INDEX_ACCESS)
+ else x.name
+
+ Option(c.argumentOut.l match {
+ case List(i: Identifier, idx: Literal) => CollectionVar(i.name, idx.code)
+ case List(i: Identifier, idx: Identifier) => CollectionVar(i.name, idx.code)
+ case List(c: Call, idx: Call) => CollectionVar(callName(c), callName(idx))
+ case List(c: Call, idx: Literal) => CollectionVar(callName(c), idx.code)
+ case List(c: Call, idx: Identifier) => CollectionVar(callName(c), idx.code)
+ case xs =>
+ logger.warn(s"Unhandled index access ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(c)}")
+ null
+ })
+ }
+
+ /** Will handle an identifier being assigned to a field value.
+ */
+ protected def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = {
+ val fieldName = getFieldName(fa)
+ fa.astChildren.l match {
+ case List(base: Identifier, fi: FieldIdentifier) if symbolTable.contains(LocalVar(base.name)) =>
+ // Get field from global table if referenced as a variable
+ val localTypes = symbolTable.get(LocalVar(base.name))
+ associateInterproceduralTypes(i, base, fi, fieldName, localTypes)
+ case List(base: Identifier, fi: FieldIdentifier) if symbolTable.contains(LocalVar(fieldName)) =>
+ val localTypes = symbolTable.get(LocalVar(fieldName))
+ associateInterproceduralTypes(i, base, fi, fieldName, localTypes)
+ case List(c: Call, f: FieldIdentifier) if c.name.equals(Operators.fieldAccess) =>
+ val baseName = getFieldName(new FieldAccess(c))
+ // Build type regardless of length
+ // TODO: This is more prone to giving dummy values as it does not do global look-ups
+ // but this is okay for now
+ val buf = mutable.ArrayBuffer.empty[String]
+ for (segment <- baseName.split("\\.") ++ Array(f.canonicalName)) {
+ val types =
+ if (buf.isEmpty) symbolTable.get(LocalVar(segment))
+ else buf.flatMap(t => symbolTable.get(LocalVar(s"$t.$segment"))).toSet
+ if (types.nonEmpty) {
+ buf.clear()
+ buf.addAll(types)
+ } else {
+ val bufCopy = Array.from(buf)
+ buf.clear()
+ bufCopy.foreach(t => buf.addOne(XTypeRecovery.DUMMY_MEMBER_TYPE(t, segment)))
+ }
+ }
+ associateTypes(i, buf.toSet)
+ case _ =>
+ logger.warn(s"Unable to assign identifier '${i.name}' to field load '$fieldName' @ ${debugLocation(i)}")
+ Set.empty
+ }
+ }
+
+ protected def getFieldBaseType(base: Identifier, fi: FieldIdentifier): Set[String] =
+ getFieldBaseType(base.name, fi.canonicalName)
+
+ protected def getFieldBaseType(baseName: String, fieldName: String): Set[String] = {
+ val localTypes = symbolTable.get(LocalVar(baseName))
+ val globalTypes = localTypes
+ .map(t => FieldVar(t, fieldName))
+ .flatMap(globalTable.get)
+ globalTypes
+ }
/** Using an entry from the symbol table, will queue the CPG modification to persist the recovered type information.
*/
@@ -245,6 +615,8 @@ abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode](
}
}
+ /** TODO: Cleaning up using visitor patten
+ */
private def setTypeInformationForRecCall(x: AstNode, n: Option[Call], ms: List[AstNode]): Unit =
(n, ms) match {
// Case 1: 'call' is an assignment from some dynamic dispatch call
diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala
index 05d202a17763..549bf39c0854 100644
--- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala
+++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala
@@ -14,6 +14,8 @@ class AstNodeMethods(val node: AstNode) extends AnyVal with NodeExtension {
def isIdentifier: Boolean = node.isInstanceOf[Identifier]
+ def isImport: Boolean = node.isInstanceOf[Import]
+
def isFieldIdentifier: Boolean = node.isInstanceOf[FieldIdentifier]
def isFile: Boolean = node.isInstanceOf[File]
diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala
index f2de2a8ed707..0fb3bab658d8 100644
--- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala
+++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala
@@ -134,6 +134,11 @@ class AstNodeTraversal[A <: AstNode](val traversal: Traversal[A]) extends AnyVal
def isIdentifier: Traversal[Identifier] =
traversal.collectAll[Identifier]
+ /** Traverse only to AST nodes that are IMPORT nodes
+ */
+ def isImport: Traversal[Import] =
+ traversal.collectAll[Import]
+
/** Traverse only to FILE AST nodes
*/
def isFile: Traversal[File] =