From bcf279979f440a299a1f178d93828bbf7262453e Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Wed, 15 Feb 2023 16:40:06 +0200 Subject: [PATCH] XTypeRecovery: Clean Up With Visitor-Pattern & Support New Imports (#2272) - Re-implemented `XTypeRecovery` using visitor pattern. This reduces the `PythonTypeRecovery` file by 50% and increases the language agnostic `XTypeRecovery` file by 100%. - `XTypeRecovery` can now support the new `Import` node while leaving room to override the `Call` version as still used in Python. - This change improves the soundness of the type recovery and, as such, resulted in some modifications required for some test cases. - This change also promotes re-use of code and easy diagnostics using a logger with debugging locators. Debug level used is `WARN`. Breaking changes: - Built-in types are now prefixed with `__builtin` which is what the front-end appears to use. - `__builtin.None` is now supported, which means potentially more types may be suggested if variables were initialized with `var = None`. --- .../joern/pysrc2cpg/PythonTypeRecovery.scala | 404 +++---------- .../io/joern/pysrc2cpg/cpg/CallCpgTests.scala | 6 +- .../passes/TypeRecoveryPassTests.scala | 48 +- .../x2cpg/passes/frontend/SymbolTable.scala | 22 +- .../x2cpg/passes/frontend/XTypeRecovery.scala | 554 +++++++++++++++--- .../language/nodemethods/AstNodeMethods.scala | 2 + .../generalizations/AstNodeTraversal.scala | 5 + 7 files changed, 600 insertions(+), 441 deletions(-) diff --git a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala index 4946d4490a33..5f0677cb9192 100644 --- a/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala +++ b/joern-cli/frontends/pysrc2cpg/src/main/scala/io/joern/pysrc2cpg/PythonTypeRecovery.scala @@ -1,13 +1,12 @@ package io.joern.pysrc2cpg +import io.joern.pysrc2cpg.PythonTypeRecovery.BUILTIN_PREFIX import io.joern.x2cpg.Defines import io.joern.x2cpg.passes.frontend._ import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.Operators import io.shiftleft.codepropertygraph.generated.nodes._ import io.shiftleft.semanticcpg.language._ -import io.shiftleft.semanticcpg.language.operatorextension.OpNodes -import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess} +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess import overflowdb.BatchedUpdate.DiffGraphBuilder import overflowdb.traversal.Traversal @@ -21,8 +20,7 @@ class PythonTypeRecovery(cpg: Cpg) extends XTypeRecovery[File](cpg) { override def generateRecoveryForCompilationUnitTask( unit: File, - builder: DiffGraphBuilder, - globalTable: SymbolTable[GlobalKey] + builder: DiffGraphBuilder ): RecoverForXCompilationUnit[File] = new RecoverForPythonFile(cpg, unit, builder, globalTable) } @@ -35,8 +33,7 @@ class PythonTypeRecovery(cpg: Cpg) extends XTypeRecovery[File](cpg) { * @param fullName * the full name to where this method is defined where it's assumed to be defined under a named Python file. */ -class ScopedPythonProcedure(callingName: String, fullName: String, isConstructor: Boolean = false) - extends ScopedXProcedure(callingName, fullName, isConstructor) { +case class ScopedPythonProcedure(callingName: String, fullName: String, isConstructor: Boolean = false) { /** @return * the full name of the procedure where it's assumed that it is defined within an __init.py__ of the @@ -48,44 +45,41 @@ class ScopedPythonProcedure(callingName: String, fullName: String, isConstructor * the two ways that this procedure could be resolved to in Python. This will be pruned later by comparing this to * actual methods in the CPG. */ - override def possibleCalleeNames: Set[String] = + def possibleCalleeNames: Set[String] = if (isConstructor) Set(fullName.concat(s".${Defines.ConstructorMethodName}")) else Set(fullName, fullNameAsInit) + override def toString: String = s"ProcedureCalledAs(${possibleCalleeNames.mkString(", ")})" + } -/** Tasks responsible for populating the symbol table with import data and method definition data. - * - * @param node - * a node that references import information. +/** Performs type recovery from the root of a compilation unit level */ -class SetPythonProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey]) extends SetXProcedureDefTask(node) { +class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, globalTable: SymbolTable[GlobalKey]) + extends RecoverForXCompilationUnit[File](cpg, cu, builder, globalTable) { - /** Refers to the declared import information. - * - * @param importCall - * the call that imports entities into this scope. + /** Overriden to include legacy import calls until imports are supported. */ - override def visitImport(importCall: Call): Unit = { - importCall.argumentOut.l match { + override def importNodes(cu: AstNode): Traversal[AstNode] = + cu.ast.isCall.nameExact("import") ++ super.importNodes(cu) + + override def visitImport(i: Call): Unit = { + i.argumentOut.l match { case List(path: Literal, funcOrModule: Literal) => val calleeNames = extractMethodDetailsFromImport(path.code, funcOrModule.code).possibleCalleeNames symbolTable.put(CallAlias(funcOrModule.code), calleeNames) + symbolTable.put(LocalVar(funcOrModule.code), calleeNames) case List(path: Literal, funcOrModule: Literal, alias: Literal) => val calleeNames = extractMethodDetailsFromImport(path.code, funcOrModule.code, Option(alias.code)).possibleCalleeNames symbolTable.put(CallAlias(alias.code), calleeNames) + symbolTable.put(LocalVar(alias.code), calleeNames) case x => logger.warn(s"Unknown import pattern: ${x.map(_.label).mkString(", ")}") } } - override def visitImport(m: Method): Unit = { - val calleeNames = new ScopedPythonProcedure(m.name, m.fullName).possibleCalleeNames - symbolTable.put(m, calleeNames) - } - /** Parses all imports and identifies their full names and how they are to be called in this scope. * * @param path @@ -101,7 +95,7 @@ class SetPythonProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey path: String, funcOrModule: String, maybeAlias: Option[String] = None - ): ScopedXProcedure = { + ): ScopedPythonProcedure = { val isConstructor = funcOrModule.split("\\.").last.charAt(0).isUpper if (path.isEmpty) { if (funcOrModule.contains(".")) { @@ -131,29 +125,6 @@ class SetPythonProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey } } -} - -/** Performs type recovery from the root of a compilation unit level - * - * @param cu - * a compilation unit, e.g. file. - * @param builder - * the graph builder - */ -class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, globalTable: SymbolTable[GlobalKey]) - extends RecoverForXCompilationUnit[File](cu, builder) { - - /** Adds built-in functions to expect. - */ - override def prepopulateSymbolTable(): Unit = - PythonTypeRecovery.BUILTINS - .map(t => (CallAlias(t), s"${PythonTypeRecovery.BUILTIN_PREFIX}.$t")) - .foreach { case (alias, typ) => - symbolTable.put(alias, typ) - } - - override def importNodes(cu: AstNode): Traversal[CfgNode] = cu.ast.isCall.nameExact("import") - override def postVisitImports(): Unit = { symbolTable.view.foreach { case (k, v) => val ms = cpg.method.fullNameExact(v.toSeq: _*).l @@ -166,11 +137,9 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global // This is likely external and we will ignore the init variant to be consistent symbolTable.put(k, symbolTable(k).filterNot(_.contains("__init__.py"))) } - // Imports are by default used as calls, a second pass will tell us if this is not the case and we should // check against global table - // TODO: This is a bit of a bandaid compared to potentially having alias sensitivity. Values could be an - // Either[SBKey, Set[String] where Left[SBKey] could point to the aliased symbol + // TODO: This is a bit of a bandaid compared to potentially having alias sensitivity. def fieldVar(path: String) = FieldVar(path.stripSuffix(s".${k.identifier}"), k.identifier) @@ -182,292 +151,87 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global } } - override def generateSetProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey]): SetXProcedureDefTask = - new SetPythonProcedureDefTask(node, symbolTable) - - /** Using assignment and import information (in the global symbol table), will propagate these types in the symbol - * table. - * - * @param assignment - * assignment call pointer. + /** Determines if a function call is a constructor by following the heuristic that Python classes are typically + * camel-case and start with an upper-case character. */ - override def visitAssignments(assignment: Assignment): Unit = { - assignment.argumentOut.take(2).l match { - case List(i: Identifier, c: Call) if symbolTable.contains(c) => - val importedTypes = symbolTable.get(c) - setIdentifierFromFunctionType(i, c.name, c.code, importedTypes) - case List(i: Identifier, c: CfgNode) - if visitLiteralAssignment(i, c, symbolTable) => // if unsuccessful, then check next - case List(i: Identifier, c: Call) if c.receiver.isCall.name.exists(_.equals(Operators.fieldAccess)) => - val field = c.receiver.isCall.name(Operators.fieldAccess).map(new OpNodes.FieldAccess(_)).head - visitCallFromFieldMember(i, c, field, symbolTable) - // Use global table knowledge (in i >= 2 iterations) or CPG to extract field types - case List(_: Identifier, c: Call) if c.name.equals(Operators.fieldAccess) => - c.inCall.argument - .flatMap { - case n: Call if n.name.equals(Operators.fieldAccess) => new OpNodes.FieldAccess(n).argumentOut - case n => n - } - .take(3) - .l match { - case List(assigned: Identifier, i: Identifier, f: FieldIdentifier) - if symbolTable.contains(CallAlias(i.name)) => - // Get field from global table if referenced as function call - val fieldTypes = symbolTable - .get(CallAlias(i.name)) - .flatMap(recModule => globalTable.get(FieldVar(recModule, f.canonicalName))) - if (fieldTypes.nonEmpty) symbolTable.append(assigned, fieldTypes) - case List(assigned: Identifier, i: Identifier, f: FieldIdentifier) - if symbolTable - .contains(LocalVar(i.name)) => - // Get field from global table if referenced as a variable - val localTypes = symbolTable.get(LocalVar(i.name)) - val memberTypes = localTypes - .flatMap { t => - cpg.typeDecl.fullNameExact(t).member.nameExact(f.canonicalName).l ++ - cpg.typeDecl.fullNameExact(t).method.fullNameExact(t).l - } - .flatMap { - case m: Member => Some(m.typeFullName) - case m: Method => Some(m.fullName) - case _ => None - } - if (memberTypes.nonEmpty) - // First use the member type info from the CPG, if present - symbolTable.append(assigned, memberTypes) - else if (localTypes.nonEmpty) { - // If not available, use a dummy variable that can be useful for call matching - symbolTable.append(assigned, localTypes.map { t => s"$t.(${f.canonicalName})" }) - } - case List(assigned: Identifier, i: Identifier, f: FieldIdentifier) - if symbolTable.contains(CallAlias(s"${i.name}.${f.canonicalName}")) => - // In this case, if are the paths of an import, e.g. import foo.bar and it is referred to as foo.bar later - // TODO: Does this handle foo.bar.baz ? - val callAlias = CallAlias(s"${i.name}.${f.canonicalName}") - val importedTypes = symbolTable.get(callAlias) - setIdentifierFromFunctionType(assigned, callAlias.identifier, callAlias.identifier, importedTypes) - case List(assigned: Identifier, i: Identifier, f: FieldIdentifier) => - // TODO: This is really tricky to find without proper object tracking, so we match name only - val fieldTypes = globalTable.view.filter(_._1.identifier.equals(f.canonicalName)).flatMap(_._2).toSet - if (fieldTypes.nonEmpty) symbolTable.append(assigned, fieldTypes) - case List(assigned: Identifier, c: Call, f: FieldIdentifier) if c.name.equals(Operators.fieldAccess) => - // TODO: This is the step that handles foo.bar.baz, which gives the impression that there is the need to - // handle this pattern recursively - val baseType = c.astChildren.isFieldIdentifier.canonicalName - .zip(c.astChildren.isIdentifier.flatMap(symbolTable.get)) - .map { case (ff, bt) => s"$bt.($ff).(${f.canonicalName})" } - .toSet - if (baseType.nonEmpty) symbolTable.append(assigned, baseType) + override def isConstructor(c: Call): Boolean = + c.name.nonEmpty && c.name.charAt(0).isUpper && c.code.endsWith(")") - case _ => - } - // Field load from call - case List(fl: Call, c: Call) if fl.name.equals(Operators.fieldAccess) && symbolTable.contains(c) => - (fl.astChildren.l, c.astChildren.l) match { - case (List(self: Identifier, fieldIdentifier: FieldIdentifier), args: List[_]) => - symbolTable.append(fieldIdentifier, symbolTable.get(c)) - globalTable.append(fieldVarName(fieldIdentifier), symbolTable.get(c)) - case _ => - } - // Field load from index access - case List(fl: Call, c: Call) if fl.name.equals(Operators.fieldAccess) && c.name.equals(Operators.indexAccess) => - (fl.astChildren.l, c.astChildren.l) match { - case (List(self: Identifier, fieldIdentifier: FieldIdentifier), ::(rhsFAccess: Call, _)) - if rhsFAccess.name.equals(Operators.fieldAccess) => - val rhsField = rhsFAccess.fieldAccess.fieldIdentifier.head - // TODO: Check if a type for the RHS index access is recovered - val types = symbolTable.get(rhsField).map(t => s"$t.") - symbolTable.append(fieldIdentifier, types) - globalTable.append(fieldVarName(fieldIdentifier), types) - case _ => - } - case _ => + /** If the parent method is module then it can be used as a field. + */ + override def isField(i: Identifier): Boolean = + i.method.name.matches("(|__init__)") || super.isField(i) + + override def visitIdentifierAssignedToOperator(i: Identifier, c: Call, operation: String): Set[String] = { + operation match { + case ".listLiteral" => associateTypes(i, Set(s"$BUILTIN_PREFIX.list")) + case ".tupleLiteral" => associateTypes(i, Set(s"$BUILTIN_PREFIX.tuple")) + case ".dictLiteral" => associateTypes(i, Set(s"$BUILTIN_PREFIX.dict")) + case _ => super.visitIdentifierAssignedToOperator(i, c, operation) } } - private def fieldVarName(f: FieldIdentifier): FieldVar = { - if (f.astSiblings.map(_.code).exists(_.contains("self"))) { - // This will match the type decl - FieldVar(f.method.typeDecl.fullName.head, f.canonicalName) - } else { - // This will typically match the - FieldVar(f.file.method.fullName.head, f.canonicalName) - } + override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = { + val constructorPaths = symbolTable + .get(c) + .map(_.stripSuffix(s".${Defines.ConstructorMethodName}")) + .map(x => (x.split("\\.").last, x)) + .map { + case (x, y) => s"$y.$x" + case (_, z) => z + } + associateTypes(i, constructorPaths) } - private def setIdentifierFromFunctionType( - i: Identifier, - callName: String, - callCode: String, - importedTypes: Set[String] - ): Unit = { - if (!callCode.endsWith(")")) { - // Case 1: The identifier is at the assignment to a function pointer. Lack of parenthesis should indicate this. - setIdentifier(i, importedTypes) - } else if (!callName.isBlank && callName.charAt(0).isUpper && callCode.endsWith(")")) { - // Case 2: The identifier is receiving a constructor invocation, thus is now an instance of the type - setIdentifier( - i, - importedTypes - .map(_.stripSuffix(s".${Defines.ConstructorMethodName}")) - .map(x => (x.split("\\.").last, x)) - .map { - case (x, y) => s"$y.$x" - case (_, z) => z - } - ) - } else { - // TODO: This identifier should contain the type of the return value of 'c'. - // e.g. x = foo(a, b) but not x = y.foo(a, b) as foo in the latter case is interpreted as a field access - } + override def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = { + // Ignore legacy import representation + if (c.name.equals("import")) Set.empty + // Stop custom annotation representation from hitting superclass + else if (c.name.isBlank) Set.empty + else super.visitIdentifierAssignedToCall(i, c) } - private def setIdentifier(i: Identifier, types: Set[String]): Option[Set[String]] = { - if (i.method.name.equals("")) globalTable.put(i, types) - symbolTable.append(i, types) + override def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = { + val fieldParents = getFieldParents(fa) + fa.astChildren.l match { + case List(base: Identifier, fi: FieldIdentifier) if base.name.equals("self") && fieldParents.nonEmpty => + val globalTypes = fieldParents.flatMap(fp => globalTable.get(FieldVar(fp, fi.canonicalName))) + associateTypes(i, globalTypes) + case _ => super.visitIdentifierAssignedToFieldLoad(i, fa) + } } - private def visitCallFromFieldMember( - i: Identifier, - c: Call, - field: FieldAccess, - symbolTable: SymbolTable[LocalKey] - ): Unit = { - field.astChildren.l match { - case List(rec: Identifier, f: FieldIdentifier) if symbolTable.contains(rec) => - // First we ask if the call receiver is known as a variable - val identifierFullName = symbolTable.get(rec).map(_.concat(s".${f.canonicalName}")) - val callMethodFullName = - if (f.canonicalName.charAt(0).isUpper) - identifierFullName.map(_.concat(s".${Defines.ConstructorMethodName}")) - else - identifierFullName.map(x => (x.split("\\.").takeRight(2), x)).map { - case (Array(x, y), z) if x.charAt(0).isUpper && !z.contains("") => s"${z.stripSuffix(y)}$x.$y" - case (_, y) => y - } - symbolTable.put(i, identifierFullName) - symbolTable.put(c, callMethodFullName) - case List(rec: Identifier, f: FieldIdentifier) if symbolTable.contains(CallAlias(rec.name)) => - // Second we ask if the call receiver is known as a function pointer (imports are interpreted as functions first) - val funcTypes = symbolTable.get(CallAlias(rec.name)).map(t => s"$t.${f.canonicalName}") - // TODO: Look in the CPG if we can resolve the method return value - symbolTable.put(i, funcTypes.map(t => s"$t.")) - symbolTable.put(c, funcTypes) - case _ => + override def getFieldParents(fa: FieldAccess): Set[String] = { + if (fa.method.name.equals("")) { + Set(fa.method.fullName) + } else if (fa.method.typeDecl.nonEmpty) { + val parentTypes = + fa.method.typeDecl.fullName.map(_.stripSuffix("")).map { t => s"$t.${t.split("\\.").last}" }.toSeq + val baseTypes = cpg.typeDecl.fullNameExact(parentTypes: _*).inheritsFromTypeFullName.toSeq + // TODO: inheritsFromTypeFullName does not give full name in pysrc2cpg + val baseTypeFullNames = cpg.typ.nameExact(baseTypes: _*).fullName.toSeq + (parentTypes ++ baseTypeFullNames) + .map(_.concat(".")) + .filterNot(t => t.toLowerCase.matches("(any|object)")) + .toSet + } else { + super.getFieldParents(fa) } } - /** Will handle literal value assignments. - * @param lhs - * the identifier. - * @param rhs - * the literal. - * @param symbolTable - * the symbol table. - * @return - * true if a literal assigment was successfully determined and added to the symbol table, false if otherwise. - */ - private def visitLiteralAssignment(lhs: Identifier, rhs: CfgNode, symbolTable: SymbolTable[LocalKey]): Boolean = { - ((lhs, rhs) match { - case (i: Identifier, l: Literal) if Try(java.lang.Integer.parseInt(l.code)).isSuccess => - setIdentifier(i, Set("int")) - case (i: Identifier, l: Literal) if Try(java.lang.Double.parseDouble(l.code)).isSuccess => - setIdentifier(i, Set("float")) - case (i: Identifier, l: Literal) if "True".equals(l.code) || "False".equals(l.code) => - setIdentifier(i, Set("bool")) - case (i: Identifier, l: Literal) if l.code.matches("^(\"|').*(\"|')$") => - setIdentifier(i, Set("str")) - case (i: Identifier, c: Call) if c.name.equals(".listLiteral") => - setIdentifier(i, Set("list")) - case (i: Identifier, c: Call) if c.name.equals(".tupleLiteral") => - setIdentifier(i, Set("tuple")) - case (i: Identifier, b: Block) - if b.astChildren.isCall.headOption.exists( - _.argument.isCall.exists(_.name.equals(".dictLiteral")) - ) => - setIdentifier(i, Set("dict")) - case _ => None - }).hasNext + override def getLiteralType(l: Literal): Set[String] = { + l match { + case _ if Try(java.lang.Integer.parseInt(l.code)).isSuccess => Set(s"$BUILTIN_PREFIX.int") + case _ if Try(java.lang.Double.parseDouble(l.code)).isSuccess => Set(s"$BUILTIN_PREFIX.float") + case _ if "True".equals(l.code) || "False".equals(l.code) => Set(s"$BUILTIN_PREFIX.bool") + case _ if l.code.matches("^(\"|').*(\"|')$") => Set(s"$BUILTIN_PREFIX.str") + case _ if l.code.equals("None") => Set(s"$BUILTIN_PREFIX.None") + case _ => Set() + } } } object PythonTypeRecovery { - - /** @see - * Python Built-in Functions - */ - lazy val BUILTINS: Set[String] = Set( - "abs", - "aiter", - "all", - "anext", - "ascii", - "bin", - "bool", - "breakpoint", - "bytearray", - "bytes", - "callable", - "chr", - "classmethod", - "compile", - "complex", - "delattr", - "dict", - "dir", - "divmod", - "enumerate", - "eval", - "exec", - "filter", - "float", - "format", - "frozenset", - "getattr", - "globals", - "hasattr", - "hash", - "help", - "hex", - "id", - "input", - "int", - "isinstance", - "issubclass", - "iter", - "len", - "list", - "locals", - "map", - "max", - "memoryview", - "min", - "next", - "object", - "oct", - "open", - "ord", - "pow", - "print", - "property", - "range", - "repr", - "reversed", - "round", - "set", - "setattr", - "slice", - "sorted", - "staticmethod", - "str", - "sum", - "super", - "tuple", - "type", - "vars", - "zip", - "__import__" - ) - def BUILTIN_PREFIX = "builtins.py:" + def BUILTIN_PREFIX = "__builtin" } diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala index 5aae5609efd9..7ec13f219950 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/cpg/CallCpgTests.scala @@ -203,9 +203,9 @@ class CallCpgTests extends PySrc2CpgFixture(withOssDataflow = false) { "test that the identifiers are not set to the function pointers but rather the 'ANY' return value" in { val List(x, y, z) = cpg.identifier.name("x", "y", "z").l - x.typeFullName shouldBe "ANY" - y.typeFullName shouldBe "ANY" - z.typeFullName shouldBe "ANY" + x.typeFullName shouldBe "foo.py:.foo_func." + y.typeFullName shouldBe Seq("foo", "bar", "__init__.py:.bar_func.").mkString(File.separator) + z.typeFullName shouldBe "foo.py:.faz." } "test call node properties for normal import from module on root path" in { diff --git a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala index 70b4c6b0eb9a..84155b4a78d1 100644 --- a/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala +++ b/joern-cli/frontends/pysrc2cpg/src/test/scala/io/joern/pysrc2cpg/passes/TypeRecoveryPassTests.scala @@ -22,22 +22,26 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "resolve 'x' identifier types despite shadowing" in { val List(xOuterScope, xInnerScope) = cpg.identifier("x").take(2).l - xOuterScope.dynamicTypeHintFullName shouldBe Seq("int", "str") - xInnerScope.dynamicTypeHintFullName shouldBe Seq("int", "str") + xOuterScope.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str") + xInnerScope.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str") } "resolve 'y' and 'z' identifier collection types" in { val List(zDict, zList, zTuple) = cpg.identifier("z").take(3).l - zDict.dynamicTypeHintFullName shouldBe Seq("dict", "list", "tuple") - zList.dynamicTypeHintFullName shouldBe Seq("dict", "list", "tuple") - zTuple.dynamicTypeHintFullName shouldBe Seq("dict", "list", "tuple") + zDict.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple") + zList.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple") + zTuple.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple") } "resolve 'z' identifier calls conservatively" in { // TODO: These should have callee entries but the method stubs are not present here val List(zAppend) = cpg.call("append").l zAppend.methodFullName shouldBe Defines.DynamicCallUnknownFallName - zAppend.dynamicTypeHintFullName shouldBe Seq("dict.append", "list.append", "tuple.append") + zAppend.dynamicTypeHintFullName shouldBe Seq( + "__builtin.dict.append", + "__builtin.list.append", + "__builtin.tuple.append" + ) } } @@ -82,6 +86,11 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { postMessage.methodFullName shouldBe "slack_sdk.py:.WebClient.WebClient.chat_postMessage" } + "resolve a dummy 'send' return value from sg.send" in { + val List(postMessage) = cpg.identifier("response").l + postMessage.typeFullName shouldBe "sendgrid.py:.SendGridAPIClient.SendGridAPIClient.send." + } + } "type recovery on class members" should { @@ -125,7 +134,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "resolve 'User' field types" in { val List(id, firstname, age, address) = cpg.identifier.nameExact("id", "firstname", "age", "address").takeRight(4).l - id.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.SQLAlchemy.Column" + id.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.SQLAlchemy.Column.Column" firstname.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.SQLAlchemy.Column.Column" age.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.SQLAlchemy.Column.Column" address.typeFullName shouldBe "flask_sqlalchemy.py:.SQLAlchemy.SQLAlchemy.Column.Column" @@ -150,14 +159,14 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "resolve 'print' and 'max' calls" in { val Some(printCall) = cpg.call("print").headOption - printCall.methodFullName shouldBe "builtins.py:.print" + printCall.methodFullName shouldBe "__builtin.print" val Some(maxCall) = cpg.call("max").headOption - maxCall.methodFullName shouldBe "builtins.py:.max" + maxCall.methodFullName shouldBe "__builtin.max" } - "select the imported abs over the built-in type when call is shadowed" in { + "conservatively present either option when an imported function uses the same name as a builtin" in { val Some(absCall) = cpg.call("abs").headOption - absCall.dynamicTypeHintFullName shouldBe Seq("foo.py:.abs") + absCall.dynamicTypeHintFullName shouldBe Seq("foo.py:.abs", "__builtin.abs") } } @@ -190,9 +199,9 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "resolve 'x' and 'y' locally under foo.py" in { val Some(x) = cpg.file.name(".*foo.*").ast.isIdentifier.name("x").headOption - x.typeFullName shouldBe "int" + x.typeFullName shouldBe "__builtin.int" val Some(y) = cpg.file.name(".*foo.*").ast.isIdentifier.name("y").headOption - y.typeFullName shouldBe "str" + y.typeFullName shouldBe "__builtin.str" } "resolve 'foo.x' and 'foo.y' field access primitive types correctly" in { @@ -203,9 +212,9 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { .name("z") .l z1.typeFullName shouldBe "ANY" - z1.dynamicTypeHintFullName shouldBe Seq("int", "str") + z1.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str") z2.typeFullName shouldBe "ANY" - z2.dynamicTypeHintFullName shouldBe Seq("int", "str") + z2.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str") } "resolve 'foo.d' field access object types correctly" in { @@ -256,7 +265,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { | try: | db.create_all() | db.session.add(user) - | return jsonify({"success": message}) + | return jsonify({"success": True}) | except Exception as e: | return 'There was an issue adding your task' |""".stripMargin, @@ -377,12 +386,15 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) { "recover a potential type for `self.collection` using the assignment at `get_collection` as a type hint" in { val Some(selfFindFound) = cpg.typeDecl(".*InstallationsDAO.*").ast.isCall.name("find_one").headOption - selfFindFound.methodFullName shouldBe "pymongo.py:.MongoClient....find_one" + selfFindFound.dynamicTypeHintFullName shouldBe Seq( + "__builtin.None.find_one", + "pymongo.py:.MongoClient....find_one" + ) } "correctly determine that, despite being unable to resolve the correct method full name, that it is an internal method" in { val Some(selfFindFound) = cpg.typeDecl(".*InstallationsDAO.*").ast.isCall.name("find_one").headOption - selfFindFound.callee.isExternal.headOption shouldBe Some(false) + selfFindFound.callee.isExternal.toSeq shouldBe Seq(true, false) } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala index fe3076b17af2..fd859bc96b12 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/SymbolTable.scala @@ -56,6 +56,10 @@ sealed class LocalKey(identifier: String) extends SBKey(identifier) { */ case class LocalVar(override val identifier: String) extends LocalKey(identifier) +/** A collection object that can be accessed with potentially dynamic keys and values. + */ +case class CollectionVar(override val identifier: String, idx: String) extends LocalKey(identifier) + /** A name that refers to some kind of callee. */ case class CallAlias(override val identifier: String) extends LocalKey(identifier) @@ -98,25 +102,25 @@ class SymbolTable[K <: SBKey](fromNode: AstNode => K) { table.put(newKey, newValues) } - def put(sbKey: K, typeFullNames: Set[String]): Option[Set[String]] = - table.put(sbKey, typeFullNames) + def put(sbKey: K, typeFullNames: Set[String]): Set[String] = + table.put(sbKey, typeFullNames).getOrElse(Set.empty) - def put(sbKey: K, typeFullName: String): Option[Set[String]] = + def put(sbKey: K, typeFullName: String): Set[String] = put(sbKey, Set(typeFullName)) - def put(node: AstNode, typeFullNames: Set[String]): Option[Set[String]] = + def put(node: AstNode, typeFullNames: Set[String]): Set[String] = put(fromNode(node), typeFullNames) - def append(node: AstNode, typeFullName: String): Option[Set[String]] = + def append(node: AstNode, typeFullName: String): Set[String] = append(node, Set(typeFullName)) - def append(node: AstNode, typeFullNames: Set[String]): Option[Set[String]] = + def append(node: AstNode, typeFullNames: Set[String]): Set[String] = append(fromNode(node), typeFullNames) - def append(sbKey: K, typeFullNames: Set[String]): Option[Set[String]] = { + def append(sbKey: K, typeFullNames: Set[String]): Set[String] = { table.get(sbKey) match { - case Some(ts) => table.put(sbKey, ts ++ typeFullNames) - case None => table.put(sbKey, typeFullNames) + case Some(ts) => put(sbKey, ts ++ typeFullNames) + case None => put(sbKey, typeFullNames) } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala index 8be04607106a..9ce804ee7e4a 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/passes/frontend/XTypeRecovery.scala @@ -1,18 +1,19 @@ package io.joern.x2cpg.passes.frontend +import io.joern.x2cpg.Defines import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.nodes._ import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames} import io.shiftleft.passes.CpgPass import io.shiftleft.semanticcpg.language._ import io.shiftleft.semanticcpg.language.operatorextension.OpNodes -import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment +import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess} import org.slf4j.{Logger, LoggerFactory} import overflowdb.BatchedUpdate.DiffGraphBuilder import overflowdb.traversal.Traversal -import java.util.Objects import java.util.concurrent.RecursiveTask +import scala.collection.mutable /** Based on a flow-insensitive symbol-table-style approach. This pass aims to be fast and deterministic and does not * try to converge to some fixed point but rather iterates a fixed number of times. This will help recover: @@ -50,9 +51,9 @@ abstract class XTypeRecovery[ComputationalUnit <: AstNode](cpg: Cpg, iterations: protected val globalTable = new SymbolTable[GlobalKey](SBKey.fromNodeToGlobalKey) override def run(builder: DiffGraphBuilder): Unit = try { - for (_ <- 0 to iterations) + for (_ <- 0 until iterations) computationalUnit - .map(unit => generateRecoveryForCompilationUnitTask(unit, builder, globalTable).fork()) + .map(unit => generateRecoveryForCompilationUnitTask(unit, builder).fork()) .foreach(_.get()) } finally { globalTable.clear() @@ -68,95 +69,45 @@ abstract class XTypeRecovery[ComputationalUnit <: AstNode](cpg: Cpg, iterations: * the compilation unit. * @param builder * the graph builder. - * @param globalTable - * the global table. * @return * a forkable [[RecoverForXCompilationUnit]] task. */ def generateRecoveryForCompilationUnitTask( unit: ComputationalUnit, - builder: DiffGraphBuilder, - globalTable: SymbolTable[GlobalKey] + builder: DiffGraphBuilder ): RecoverForXCompilationUnit[ComputationalUnit] } -/** Defines how a procedure is available to be called in the current scope either by it being defined in this module or - * being imported. - * - * @param callingName - * how this procedure is to be called, i.e., alias name, name with path, etc. - * @param fullName - * the full name to where this method is defined. - */ -abstract class ScopedXProcedure(val callingName: String, val fullName: String, val isConstructor: Boolean = false) { - - /** @return - * there are multiple ways that this procedure could be resolved in some languages. This will be pruned later by - * comparing this to actual methods in the CPG using postVisitImport. - */ - def possibleCalleeNames: Set[String] = Set() - - override def toString: String = s"ProcedureCalledAs(${possibleCalleeNames.mkString(", ")})" - - override def equals(obj: Any): Boolean = { - obj match { - case o: ScopedXProcedure => - callingName.equals(o.callingName) && fullName.equals(o.fullName) && isConstructor == o.isConstructor - case _ => false - } - } - - override def hashCode(): Int = Objects.hash(callingName, fullName, isConstructor) - -} - -/** Tasks responsible for populating the symbol table with import data. - * - * @param node - * a node that references import information. - */ -abstract class SetXProcedureDefTask(node: CfgNode) extends RecursiveTask[Unit] { - - protected val logger: Logger = LoggerFactory.getLogger(classOf[SetXProcedureDefTask]) - - override def compute(): Unit = - node match { - case x: Method => visitImport(x) - case x: Call => visitImport(x) - case _ => - } - - /** Refers to the declared import information. - * - * @param importCall - * the call that imports entities into this scope. - */ - def visitImport(importCall: Call): Unit - - /** Sets how an application method would be referred to locally. - * - * @param m - * an internal method - */ - def visitImport(m: Method): Unit - +object XTypeRecovery { + val DUMMY_RETURN_TYPE = "" + val DUMMY_MEMBER_LOAD = "" + val DUMMY_INDEX_ACCESS = "" + def DUMMY_MEMBER_TYPE(prefix: String, memberName: String) = s"$prefix.$DUMMY_MEMBER_LOAD($memberName)" } /** Performs type recovery from the root of a compilation unit level * + * @param cpg + * the graph. * @param cu * a compilation unit, e.g. file, procedure, type, etc. * @param builder * the graph builder + * @param globalTable + * the global symbol table. * @tparam ComputationalUnit * the [[AstNode]] type used to represent a computational unit of the language. */ abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode]( + cpg: Cpg, cu: ComputationalUnit, - builder: DiffGraphBuilder + builder: DiffGraphBuilder, + globalTable: SymbolTable[GlobalKey] ) extends RecursiveTask[Unit] { + protected val logger: Logger = LoggerFactory.getLogger(getClass) + /** Stores type information for local structures that live within this compilation unit, e.g. local variables. */ protected val symbolTable = new SymbolTable[LocalKey](SBKey.fromNodeToLocalKey) @@ -168,12 +119,17 @@ abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode]( protected def assignments: Traversal[Assignment] = cu.ast.isCall.name(Operators.assignment).map(new OpNodes.Assignment(_)) + protected def members: Traversal[Member] = + cu.ast.isMember + override def compute(): Unit = try { prepopulateSymbolTable() // Set known aliases that point to imports for local and external methods/modules - setImportsFromDeclaredProcedures(importNodes(cu) ++ internalMethodNodes(cu)) + visitImports(importNodes(cu)) // Prune import names if the methods exist in the CPG postVisitImports() + // Populate fields + members.foreach(visitMembers) // Populate local symbol table with assignments assignments.foreach(visitAssignments) // Persist findings @@ -182,51 +138,465 @@ abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode]( symbolTable.clear() } + private def debugLocation(n: AstNode): String = { + val rootPath = cpg.metaData.root.headOption.getOrElse("") + val fileName = n.file.name.headOption.getOrElse("").stripPrefix(rootPath) + val lineNo = n.lineNumber.getOrElse("") + s"$fileName#L$lineNo" + } + /** Using import information and internally defined procedures, will generate a mapping between how functions and * types are aliased and called and themselves. * * @param procedureDeclarations * imports to types or functions and internally defined methods themselves. */ - protected def setImportsFromDeclaredProcedures(procedureDeclarations: Traversal[CfgNode]): Unit = - procedureDeclarations.map(f => generateSetProcedureDefTask(f, symbolTable).fork()).foreach(_.get()) + protected def visitImports(procedureDeclarations: Traversal[AstNode]): Unit = { + procedureDeclarations.foreach { + case i: Import => visitImport(i) + case i: Call => visitImport(i) + } + } - /** Generates a task to create an import task. + /** Refers to the declared import information. This is for legacy import notation. * - * @param node - * the import node or method definition node. - * @param symbolTable - * the local table. - * @return - * a forkable [[SetXProcedureDefTask]] task. + * @param i + * the call that imports entities into this scope. */ - def generateSetProcedureDefTask(node: CfgNode, symbolTable: SymbolTable[LocalKey]): SetXProcedureDefTask + protected def visitImport(i: Call): Unit - /** @return - * the import nodes of this computational unit. + /** Visits an import and stores references in the symbol table as both an identifier and call. */ - def importNodes(cu: AstNode): Traversal[CfgNode] + protected def visitImport(i: Import): Unit = { + val entity = i.importedEntity + val alias = i.importedAs + if (entity.isDefined && alias.isDefined) { + symbolTable.append(LocalVar(alias.get), Set(entity.get)) + symbolTable.append(CallAlias(alias.get), Set(entity.get)) + } + } - /** @param cu - * the current computational unit. - * @return - * the methods defined within this computational unit. + /** @return + * the import nodes of this computational unit. */ - def internalMethodNodes(cu: AstNode): Traversal[Method] = cu.ast.isMethod.isExternal(false) + protected def importNodes(cu: AstNode): Traversal[AstNode] = cu.ast.isImport /** The initial import setting is over-approximated, so this step checks the CPG for any matches and prunes against * these findings. If there are no findings, it will leave the table as is. The latter is significant for external * types or methods. */ - def postVisitImports(): Unit = {} + protected def postVisitImports(): Unit = {} + + /** Using member information, will propagate member information to the global and local symbol table. By default, + * fields in the local table will be prepended with "this". + */ + protected def visitMembers(member: Member): Unit = { + symbolTable.put(LocalVar(member.name), Set.empty[String]) + globalTable.put(FieldVar(member.typeDecl.fullName, member.name), Set.empty[String]) + } /** Using assignment and import information (in the global symbol table), will propagate these types in the symbol * table. * - * @param assignment + * @param a * assignment call pointer. */ - def visitAssignments(assignment: Assignment): Unit + protected def visitAssignments(a: Assignment): Set[String] = { + a.argumentOut.l match { + case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b) + case List(i: Identifier, c: Call) => visitIdentifierAssignedToCall(i, c) + case List(i: Identifier, l: Literal) => visitIdentifierAssignedToLiteral(i, l) + case List(i: Identifier, m: MethodRef) => visitIdentifierAssignedToMethodRef(i, m) + case List(i: Identifier, t: TypeRef) => visitIdentifierAssignedToTypeRef(i, t) + case List(c: Call, i: Identifier) => visitCallAssignedToIdentifier(c, i) + case List(x: Call, y: Call) => visitCallAssignedToCall(x, y) + case List(c: Call, l: Literal) => visitCallAssignedToLiteral(c, l) + case List(c: Call, m: MethodRef) => visitCallAssignedToMethodRef(c, m) + case List(c: Call, b: Block) => visitCallAssignedToBlock(c, b) + case xs => + logger.warn(s"Unhandled assignment ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(a)}") + Set.empty + } + } + + /** Visits an identifier being assigned to the result of some operation. + */ + protected def visitIdentifierAssignedToBlock(i: Identifier, b: Block): Set[String] = { + val blockTypes = visitStatementsInBlock(b) + if (blockTypes.nonEmpty) associateTypes(i, blockTypes) + else Set.empty + } + + /** Visits a call operation being assigned to the result of some operation. + */ + protected def visitCallAssignedToBlock(c: Call, b: Block): Set[String] = { + val blockTypes = visitStatementsInBlock(b) + assignTypesToCall(c, blockTypes) + } + + /** Process each statement but only assign the type of the last statement to the identifier + */ + protected def visitStatementsInBlock(b: Block): Set[String] = + b.astChildren + .map { + case x: Call if x.name.equals(Operators.assignment) => visitAssignments(new Assignment(x)) + case x: Identifier if x.astChildren.isEmpty && symbolTable.contains(x) => symbolTable.get(x) + case x: Call if symbolTable.contains(x) => symbolTable.get(x) + case x: Call if x.argument.headOption.exists(symbolTable.contains) => setCallMethodFullNameFromBase(x) + case x => logger.warn(s"Unhandled block element ${x.label}:${x.code} @ ${debugLocation(x)}"); Set.empty[String] + } + .lastOption + .getOrElse(Set.empty[String]) + + /** Visits an identifier being assigned to a call. This call could be an operation, function invocation, or + * constructor invocation. + */ + protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = { + if (c.name.startsWith("")) { + visitIdentifierAssignedToOperator(i, c, c.name) + } else if (symbolTable.contains(c) && isConstructor(c)) { + visitIdentifierAssignedToConstructor(i, c) + } else if (symbolTable.contains(c)) { + visitIdentifierAssignedToCallRetVal(i, c) + } else if (c.argument.headOption.exists(symbolTable.contains)) { + setCallMethodFullNameFromBase(c) + // Repeat this method now that the call has a type + visitIdentifierAssignedToCall(i, c) + } else { + // We can try obtain a return type for this call + visitIdentifierAssignedToCallRetVal(i, c) + } + } + + /** Will build a call full path using the call base node. This method assumes the base node is in the symbol table. + */ + protected def setCallMethodFullNameFromBase(c: Call): Set[String] = { + val recTypes = c.argument.headOption.map(symbolTable.get).getOrElse(Set.empty[String]) + val callTypes = recTypes.map(_.concat(s".${c.name}")) + symbolTable.append(c, callTypes) + } + + /** A heuristic method to determine if a call is a constructor or not. + */ + protected def isConstructor(c: Call): Boolean + + /** A heuristic method to determine if an identifier may be a field or not. The result means that it would be stored + * in the global symbol table. By default this checks if the identifier name matches a member name. + */ + protected def isField(i: Identifier): Boolean = i.method.typeDecl.member.exists(_.name.equals(i.name)) + + /** Associates the types with the identifier. This may sometimes be an identifier that should be considered a field + * which this method uses [[isField(i)]] to determine. + */ + protected def associateTypes(i: Identifier, types: Set[String]): Set[String] = { + if (isField(i)) globalTable.put(i, types) + symbolTable.append(i, types) + } + + /** Returns the appropriate field parent scope. + */ + protected def getFieldParents(fa: FieldAccess): Set[String] = { + val fieldName = getFieldName(fa).split("\\.").last + cpg.typeDecl.filter(_.member.exists(_.name.equals(fieldName))).fullName.filterNot(_.contains("ANY")).toSet + } + + /** Associates the types with the identifier. This may sometimes be an identifier that should be considered a field + * which this method uses [[isField(i)]] to determine. + */ + protected def associateTypes(symbol: LocalVar, fa: FieldAccess, types: Set[String]): Set[String] = { + val fieldParents = getFieldParents(fa) + fa.astChildren.l match { + case ::(i: Identifier, _) if isField(i) && fieldParents.nonEmpty => + fieldParents.foreach(fp => globalTable.put(FieldVar(fp, symbol.identifier), types)) + case _ => + } + symbolTable.append(symbol, types) + } + + /** Similar to [[associateTypes()]] but used in the case where there is some kind of field load. + */ + protected def associateInterproceduralTypes( + i: Identifier, + base: Identifier, + fi: FieldIdentifier, + fieldName: String, + baseTypes: Set[String] + ): Set[String] = { + val globalTypes = getFieldBaseType(base, fi) + associateInterproceduralTypes(i, fieldName, fi.canonicalName, globalTypes, baseTypes) + } + + protected def associateInterproceduralTypes( + i: Identifier, + fieldFullName: String, + fieldName: String, + globalTypes: Set[String], + baseTypes: Set[String] + ): Set[String] = { + if (globalTypes.nonEmpty) { + // We have been able to resolve the type inter-procedurally + associateTypes(i, globalTypes) + } else if (baseTypes.nonEmpty) { + if (baseTypes.equals(symbolTable.get(LocalVar(fieldFullName)))) { + associateTypes(i, baseTypes) + } else { + // If not available, use a dummy variable that can be useful for call matching + associateTypes(i, baseTypes.map(t => XTypeRecovery.DUMMY_MEMBER_TYPE(t, fieldName))) + } + } else { + logger.warn(s"Unable to associate interprocedural type for $i = $fieldFullName @ ${debugLocation(i)}") + Set.empty + } + } + + /** Visits an identifier being assigned to an operator call. + */ + protected def visitIdentifierAssignedToOperator(i: Identifier, c: Call, operation: String): Set[String] = { + operation match { + case Operators.alloc => visitIdentifierAssignedToConstructor(i, c) + case Operators.fieldAccess => visitIdentifierAssignedToFieldLoad(i, new FieldAccess(c)) + case x => logger.warn(s"Unhandled operation $x (${c.code}) @ ${debugLocation(c)}"); Set.empty + } + } + + /** Visits an identifier being assigned to a constructor and attempts to speculate the constructor path. + */ + protected def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = { + val constructorPaths = symbolTable.get(c).map(t => t.concat(s".${Defines.ConstructorMethodName}")) + associateTypes(i, constructorPaths) + } + + /** Visits an identifier being assigned to a call's return value. + */ + protected def visitIdentifierAssignedToCallRetVal(i: Identifier, c: Call): Set[String] = { + if (symbolTable.contains(c)) { + val callReturns = methodReturnValues(symbolTable.get(c).toSeq) + associateTypes(i, callReturns) + } else if (c.argument.exists(_.argumentIndex == 0)) { + val callFullNames = (c.argument(0) match { + case i: Identifier if symbolTable.contains(LocalVar(i.name)) => symbolTable.get(LocalVar(i.name)) + case i: Identifier if symbolTable.contains(CallAlias(i.name)) => symbolTable.get(CallAlias(i.name)) + case _ => Set.empty + }).map(_.concat(s".${c.name}")).toSeq + val callReturns = methodReturnValues(callFullNames) + associateTypes(i, callReturns) + } else { + logger.warn(s"Unable to assign identifier '${i.name} to return value for ${c.code} @ ${debugLocation(c)}") + Set.empty + } + } + + /** Will attempt to find the return values of a method if in the CPG, otherwise will give a dummy value. + */ + protected def methodReturnValues(methodFullNames: Seq[String]): Set[String] = { + val rs = cpg.method.fullNameExact(methodFullNames: _*).methodReturn.typeFullName.filterNot(_.equals("ANY")).toSet + if (rs.isEmpty) methodFullNames.map(_.concat(s".${XTypeRecovery.DUMMY_RETURN_TYPE}")).toSet + else rs + } + + /** Will handle literal value assignments. Override if special handling is required. + */ + protected def visitIdentifierAssignedToLiteral(i: Identifier, l: Literal): Set[String] = + associateTypes(i, getLiteralType(l)) + + /** Not all frontends populate typeFullName for literals so we allow this to be overridden. + */ + protected def getLiteralType(l: Literal): Set[String] = Set(l.typeFullName) + + /** Will handle an identifier holding a function pointer. + */ + protected def visitIdentifierAssignedToMethodRef(i: Identifier, m: MethodRef): Set[String] = + symbolTable.append(CallAlias(i.name), Set(m.methodFullName)) + + /** Will handle an identifier holding a type pointer. + */ + protected def visitIdentifierAssignedToTypeRef(i: Identifier, t: TypeRef): Set[String] = + symbolTable.append(CallAlias(i.name), Set(t.typeFullName)) + + /** Visits a call assigned to an identifier. This is often when there are operators involved. + */ + protected def visitCallAssignedToIdentifier(c: Call, i: Identifier): Set[String] = { + val rhsTypes = symbolTable.get(i) + assignTypesToCall(c, rhsTypes) + } + + /** Visits a call assigned to the return value of a call. This is often when there are operators involved. + */ + protected def visitCallAssignedToCall(x: Call, y: Call): Set[String] = { + val rhsTypes = y.name match { + case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(new FieldAccess(y)))) + case _ if symbolTable.contains(y) => symbolTable.get(y) + case Operators.indexAccess => getIndexAccessTypes(y) + case n => + logger.warn(s"Unknown RHS call type '$n' @ ${debugLocation(x)}") + Set.empty[String] + } + assignTypesToCall(x, rhsTypes) + } + + /** Given a LHS call, will retrieve its symbol to the given types. + */ + protected def assignTypesToCall(x: Call, types: Set[String]): Set[String] = { + if (types.isEmpty) return Set.empty + getSymbolFromCall(x) match { + case (lhs, globalKeys) if globalKeys.nonEmpty => + globalKeys.foreach(gt => globalTable.append(gt, types)) + symbolTable.append(lhs, types) + case (lhs, _) => symbolTable.append(lhs, types) + } + } + + /** Will attempt to retrieve index access types otherwise will return dummy value. + */ + protected def getIndexAccessTypes(ia: Call): Set[String] = { + indexAccessToCollectionVar(ia) match { + case Some(cVar) if symbolTable.contains(cVar) => + symbolTable.get(cVar) + case Some(cVar) if symbolTable.contains(LocalVar(cVar.identifier)) => + symbolTable.get(LocalVar(cVar.identifier)).map(_.concat(s".${XTypeRecovery.DUMMY_INDEX_ACCESS}")) + case None => Set.empty + } + } + + /** Tries to identify the underlying symbol from the call operation as it is used on the LHS of an assignment. The + * second element is a list of any associated global keys if applicable. + */ + protected def getSymbolFromCall(c: Call): (LocalKey, Set[GlobalKey]) = c.name match { + case Operators.fieldAccess => + val fa = new FieldAccess(c) + val fieldName = getFieldName(fa) + (LocalVar(fieldName), getFieldParents(fa).map(fp => FieldVar(fp, fieldName))) + case Operators.indexAccess => (indexAccessToCollectionVar(c).getOrElse(LocalVar(c.name)), Set.empty) + case x => + logger.warn(s"Unknown LHS call type '$x' @ ${debugLocation(c)}") + (LocalVar(c.name), Set.empty) + } + + /** Extracts a string representation of the name of the field within this field access. + */ + protected def getFieldName(fa: FieldAccess, prefix: String = "", suffix: String = ""): String = { + def wrapName(n: String) = { + val sb = new mutable.StringBuilder() + if (prefix.nonEmpty) sb.append(s"$prefix.") + sb.append(n) + if (suffix.nonEmpty) sb.append(s".$suffix") + sb.toString() + } + + fa.astChildren.l match { + case List(i: Identifier, f: FieldIdentifier) if i.name.matches("(self|this)") => wrapName(f.canonicalName) + case List(i: Identifier, f: FieldIdentifier) => wrapName(s"${i.name}.${f.canonicalName}") + case List(c: Call, f: FieldIdentifier) if c.name.equals(Operators.fieldAccess) => + wrapName(getFieldName(new FieldAccess(c), suffix = f.canonicalName)) + case List(f: FieldIdentifier, c: Call) if c.name.equals(Operators.fieldAccess) => + wrapName(getFieldName(new FieldAccess(c), prefix = f.canonicalName)) + case xs => + logger.warn(s"Unhandled field structure ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(fa)}") + wrapName("") + } + } + + protected def visitCallAssignedToLiteral(c: Call, l: Literal): Set[String] = { + if (c.name.equals(Operators.indexAccess)) { + // For now, we will just handle this on a very basic level + c.argumentOut.l match { + case List(_: Identifier, _: Literal) => + indexAccessToCollectionVar(c).map(cv => symbolTable.append(cv, getLiteralType(l))).getOrElse(Set.empty) + case List(_: Identifier, idx: Identifier) if symbolTable.contains(idx) => + // Imprecise but sound! + indexAccessToCollectionVar(c).map(cv => symbolTable.append(cv, symbolTable.get(idx))).getOrElse(Set.empty) + case xs => + logger.warn(s"Unhandled index access point assigned to literal ${xs.map(_.label)} @ ${debugLocation(c)}") + Set.empty + } + } else if (c.name.equals(Operators.fieldAccess)) { + val fa = new FieldAccess(c) + val fieldName = getFieldName(fa) + associateTypes(LocalVar(fieldName), fa, getLiteralType(l)) + } else { + logger.warn(s"Unhandled call assigned to literal point ${c.name} @ ${debugLocation(c)}") + Set.empty + } + } + + /** Handles a call operation assigned to a method/function pointer. + */ + protected def visitCallAssignedToMethodRef(c: Call, m: MethodRef): Set[String] = + assignTypesToCall(c, Set(m.methodFullName)) + + /** Generates an identifier for collection/index-access operations in the symbol table. + */ + protected def indexAccessToCollectionVar(c: Call): Option[CollectionVar] = { + + def callName(x: Call) = + if (x.name.equals(Operators.fieldAccess)) + getFieldName(new FieldAccess(x)) + else if (x.name.equals(Operators.indexAccess)) + indexAccessToCollectionVar(x) + .map(cv => s"${cv.identifier}[${cv.idx}]") + .getOrElse(XTypeRecovery.DUMMY_INDEX_ACCESS) + else x.name + + Option(c.argumentOut.l match { + case List(i: Identifier, idx: Literal) => CollectionVar(i.name, idx.code) + case List(i: Identifier, idx: Identifier) => CollectionVar(i.name, idx.code) + case List(c: Call, idx: Call) => CollectionVar(callName(c), callName(idx)) + case List(c: Call, idx: Literal) => CollectionVar(callName(c), idx.code) + case List(c: Call, idx: Identifier) => CollectionVar(callName(c), idx.code) + case xs => + logger.warn(s"Unhandled index access ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(c)}") + null + }) + } + + /** Will handle an identifier being assigned to a field value. + */ + protected def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = { + val fieldName = getFieldName(fa) + fa.astChildren.l match { + case List(base: Identifier, fi: FieldIdentifier) if symbolTable.contains(LocalVar(base.name)) => + // Get field from global table if referenced as a variable + val localTypes = symbolTable.get(LocalVar(base.name)) + associateInterproceduralTypes(i, base, fi, fieldName, localTypes) + case List(base: Identifier, fi: FieldIdentifier) if symbolTable.contains(LocalVar(fieldName)) => + val localTypes = symbolTable.get(LocalVar(fieldName)) + associateInterproceduralTypes(i, base, fi, fieldName, localTypes) + case List(c: Call, f: FieldIdentifier) if c.name.equals(Operators.fieldAccess) => + val baseName = getFieldName(new FieldAccess(c)) + // Build type regardless of length + // TODO: This is more prone to giving dummy values as it does not do global look-ups + // but this is okay for now + val buf = mutable.ArrayBuffer.empty[String] + for (segment <- baseName.split("\\.") ++ Array(f.canonicalName)) { + val types = + if (buf.isEmpty) symbolTable.get(LocalVar(segment)) + else buf.flatMap(t => symbolTable.get(LocalVar(s"$t.$segment"))).toSet + if (types.nonEmpty) { + buf.clear() + buf.addAll(types) + } else { + val bufCopy = Array.from(buf) + buf.clear() + bufCopy.foreach(t => buf.addOne(XTypeRecovery.DUMMY_MEMBER_TYPE(t, segment))) + } + } + associateTypes(i, buf.toSet) + case _ => + logger.warn(s"Unable to assign identifier '${i.name}' to field load '$fieldName' @ ${debugLocation(i)}") + Set.empty + } + } + + protected def getFieldBaseType(base: Identifier, fi: FieldIdentifier): Set[String] = + getFieldBaseType(base.name, fi.canonicalName) + + protected def getFieldBaseType(baseName: String, fieldName: String): Set[String] = { + val localTypes = symbolTable.get(LocalVar(baseName)) + val globalTypes = localTypes + .map(t => FieldVar(t, fieldName)) + .flatMap(globalTable.get) + globalTypes + } /** Using an entry from the symbol table, will queue the CPG modification to persist the recovered type information. */ @@ -245,6 +615,8 @@ abstract class RecoverForXCompilationUnit[ComputationalUnit <: AstNode]( } } + /** TODO: Cleaning up using visitor patten + */ private def setTypeInformationForRecCall(x: AstNode, n: Option[Call], ms: List[AstNode]): Unit = (n, ms) match { // Case 1: 'call' is an assignment from some dynamic dispatch call diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala index 05d202a17763..549bf39c0854 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala @@ -14,6 +14,8 @@ class AstNodeMethods(val node: AstNode) extends AnyVal with NodeExtension { def isIdentifier: Boolean = node.isInstanceOf[Identifier] + def isImport: Boolean = node.isInstanceOf[Import] + def isFieldIdentifier: Boolean = node.isInstanceOf[FieldIdentifier] def isFile: Boolean = node.isInstanceOf[File] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala index f2de2a8ed707..0fb3bab658d8 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala @@ -134,6 +134,11 @@ class AstNodeTraversal[A <: AstNode](val traversal: Traversal[A]) extends AnyVal def isIdentifier: Traversal[Identifier] = traversal.collectAll[Identifier] + /** Traverse only to AST nodes that are IMPORT nodes + */ + def isImport: Traversal[Import] = + traversal.collectAll[Import] + /** Traverse only to FILE AST nodes */ def isFile: Traversal[File] =