Skip to content

Commit

Permalink
XTypeRecovery: Clean Up With Visitor-Pattern & Support New Imports (#…
Browse files Browse the repository at this point in the history
…2272)

- Re-implemented `XTypeRecovery` using the visitor pattern. This reduces the size of the `PythonTypeRecovery` file by 50% and increases the size of the language-agnostic `XTypeRecovery` file by 100%.
- `XTypeRecovery` can now support the new `Import` node while leaving room to override the `Call` version as still used in Python.
- This change improves the soundness of the type recovery and, as a result, required modifications to some test cases.
- This change also promotes code reuse and easier diagnostics by using a logger with debugging locators. The debug level used is `WARN`.

Breaking changes:
- Built-in types are now prefixed with `__builtin`, which is what the front-end appears to use.
- `__builtin.None` is now supported, which means that more types may potentially be suggested for variables that are initialized with `var = None`.
  • Loading branch information
DavidBakerEffendi authored Feb 15, 2023
1 parent f7dbe52 commit bcf2799
Show file tree
Hide file tree
Showing 7 changed files with 600 additions and 441 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ class CallCpgTests extends PySrc2CpgFixture(withOssDataflow = false) {

"test that the identifiers are not set to the function pointers but rather the 'ANY' return value" in {
val List(x, y, z) = cpg.identifier.name("x", "y", "z").l
x.typeFullName shouldBe "ANY"
y.typeFullName shouldBe "ANY"
z.typeFullName shouldBe "ANY"
x.typeFullName shouldBe "foo.py:<module>.foo_func.<returnValue>"
y.typeFullName shouldBe Seq("foo", "bar", "__init__.py:<module>.bar_func.<returnValue>").mkString(File.separator)
z.typeFullName shouldBe "foo.py:<module>.faz.<returnValue>"
}

"test call node properties for normal import from module on root path" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,26 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {

"resolve 'x' identifier types despite shadowing" in {
val List(xOuterScope, xInnerScope) = cpg.identifier("x").take(2).l
xOuterScope.dynamicTypeHintFullName shouldBe Seq("int", "str")
xInnerScope.dynamicTypeHintFullName shouldBe Seq("int", "str")
xOuterScope.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str")
xInnerScope.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str")
}

"resolve 'y' and 'z' identifier collection types" in {
val List(zDict, zList, zTuple) = cpg.identifier("z").take(3).l
zDict.dynamicTypeHintFullName shouldBe Seq("dict", "list", "tuple")
zList.dynamicTypeHintFullName shouldBe Seq("dict", "list", "tuple")
zTuple.dynamicTypeHintFullName shouldBe Seq("dict", "list", "tuple")
zDict.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple")
zList.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple")
zTuple.dynamicTypeHintFullName shouldBe Seq("__builtin.dict", "__builtin.list", "__builtin.tuple")
}

"resolve 'z' identifier calls conservatively" in {
// TODO: These should have callee entries but the method stubs are not present here
val List(zAppend) = cpg.call("append").l
zAppend.methodFullName shouldBe Defines.DynamicCallUnknownFallName
zAppend.dynamicTypeHintFullName shouldBe Seq("dict.append", "list.append", "tuple.append")
zAppend.dynamicTypeHintFullName shouldBe Seq(
"__builtin.dict.append",
"__builtin.list.append",
"__builtin.tuple.append"
)
}
}

Expand Down Expand Up @@ -82,6 +86,11 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
postMessage.methodFullName shouldBe "slack_sdk.py:<module>.WebClient.WebClient<body>.chat_postMessage"
}

"resolve a dummy 'send' return value from sg.send" in {
val List(postMessage) = cpg.identifier("response").l
postMessage.typeFullName shouldBe "sendgrid.py:<module>.SendGridAPIClient.SendGridAPIClient<body>.send.<returnValue>"
}

}

"type recovery on class members" should {
Expand Down Expand Up @@ -125,7 +134,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
"resolve 'User' field types" in {
val List(id, firstname, age, address) =
cpg.identifier.nameExact("id", "firstname", "age", "address").takeRight(4).l
id.typeFullName shouldBe "flask_sqlalchemy.py:<module>.SQLAlchemy.SQLAlchemy<body>.Column"
id.typeFullName shouldBe "flask_sqlalchemy.py:<module>.SQLAlchemy.SQLAlchemy<body>.Column.Column<body>"
firstname.typeFullName shouldBe "flask_sqlalchemy.py:<module>.SQLAlchemy.SQLAlchemy<body>.Column.Column<body>"
age.typeFullName shouldBe "flask_sqlalchemy.py:<module>.SQLAlchemy.SQLAlchemy<body>.Column.Column<body>"
address.typeFullName shouldBe "flask_sqlalchemy.py:<module>.SQLAlchemy.SQLAlchemy<body>.Column.Column<body>"
Expand All @@ -150,14 +159,14 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {

"resolve 'print' and 'max' calls" in {
val Some(printCall) = cpg.call("print").headOption
printCall.methodFullName shouldBe "builtins.py:<module>.print"
printCall.methodFullName shouldBe "__builtin.print"
val Some(maxCall) = cpg.call("max").headOption
maxCall.methodFullName shouldBe "builtins.py:<module>.max"
maxCall.methodFullName shouldBe "__builtin.max"
}

"select the imported abs over the built-in type when call is shadowed" in {
"conservatively present either option when an imported function uses the same name as a builtin" in {
val Some(absCall) = cpg.call("abs").headOption
absCall.dynamicTypeHintFullName shouldBe Seq("foo.py:<module>.abs")
absCall.dynamicTypeHintFullName shouldBe Seq("foo.py:<module>.abs", "__builtin.abs")
}

}
Expand Down Expand Up @@ -190,9 +199,9 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {

"resolve 'x' and 'y' locally under foo.py" in {
val Some(x) = cpg.file.name(".*foo.*").ast.isIdentifier.name("x").headOption
x.typeFullName shouldBe "int"
x.typeFullName shouldBe "__builtin.int"
val Some(y) = cpg.file.name(".*foo.*").ast.isIdentifier.name("y").headOption
y.typeFullName shouldBe "str"
y.typeFullName shouldBe "__builtin.str"
}

"resolve 'foo.x' and 'foo.y' field access primitive types correctly" in {
Expand All @@ -203,9 +212,9 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
.name("z")
.l
z1.typeFullName shouldBe "ANY"
z1.dynamicTypeHintFullName shouldBe Seq("int", "str")
z1.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str")
z2.typeFullName shouldBe "ANY"
z2.dynamicTypeHintFullName shouldBe Seq("int", "str")
z2.dynamicTypeHintFullName shouldBe Seq("__builtin.int", "__builtin.str")
}

"resolve 'foo.d' field access object types correctly" in {
Expand Down Expand Up @@ -256,7 +265,7 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
| try:
| db.create_all()
| db.session.add(user)
| return jsonify({"success": message})
| return jsonify({"success": True})
| except Exception as e:
| return 'There was an issue adding your task'
|""".stripMargin,
Expand Down Expand Up @@ -377,12 +386,15 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {

"recover a potential type for `self.collection` using the assignment at `get_collection` as a type hint" in {
val Some(selfFindFound) = cpg.typeDecl(".*InstallationsDAO.*").ast.isCall.name("find_one").headOption
selfFindFound.methodFullName shouldBe "pymongo.py:<module>.MongoClient.<init>.<indexAccess>.<indexAccess>.find_one"
selfFindFound.dynamicTypeHintFullName shouldBe Seq(
"__builtin.None.find_one",
"pymongo.py:<module>.MongoClient.<init>.<indexAccess>.<indexAccess>.find_one"
)
}

"correctly determine that, despite being unable to resolve the correct method full name, that it is an internal method" in {
val Some(selfFindFound) = cpg.typeDecl(".*InstallationsDAO.*").ast.isCall.name("find_one").headOption
selfFindFound.callee.isExternal.headOption shouldBe Some(false)
selfFindFound.callee.isExternal.toSeq shouldBe Seq(true, false)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ sealed class LocalKey(identifier: String) extends SBKey(identifier) {
*/
case class LocalVar(override val identifier: String) extends LocalKey(identifier)

/** A collection object that can be accessed with potentially dynamic keys and values.
*/
case class CollectionVar(override val identifier: String, idx: String) extends LocalKey(identifier)

/** A name that refers to some kind of callee.
*/
case class CallAlias(override val identifier: String) extends LocalKey(identifier)
Expand Down Expand Up @@ -98,25 +102,25 @@ class SymbolTable[K <: SBKey](fromNode: AstNode => K) {
table.put(newKey, newValues)
}

def put(sbKey: K, typeFullNames: Set[String]): Option[Set[String]] =
table.put(sbKey, typeFullNames)
def put(sbKey: K, typeFullNames: Set[String]): Set[String] =
table.put(sbKey, typeFullNames).getOrElse(Set.empty)

def put(sbKey: K, typeFullName: String): Option[Set[String]] =
def put(sbKey: K, typeFullName: String): Set[String] =
put(sbKey, Set(typeFullName))

def put(node: AstNode, typeFullNames: Set[String]): Option[Set[String]] =
def put(node: AstNode, typeFullNames: Set[String]): Set[String] =
put(fromNode(node), typeFullNames)

def append(node: AstNode, typeFullName: String): Option[Set[String]] =
def append(node: AstNode, typeFullName: String): Set[String] =
append(node, Set(typeFullName))

def append(node: AstNode, typeFullNames: Set[String]): Option[Set[String]] =
def append(node: AstNode, typeFullNames: Set[String]): Set[String] =
append(fromNode(node), typeFullNames)

def append(sbKey: K, typeFullNames: Set[String]): Option[Set[String]] = {
def append(sbKey: K, typeFullNames: Set[String]): Set[String] = {
table.get(sbKey) match {
case Some(ts) => table.put(sbKey, ts ++ typeFullNames)
case None => table.put(sbKey, typeFullNames)
case Some(ts) => put(sbKey, ts ++ typeFullNames)
case None => put(sbKey, typeFullNames)
}
}

Expand Down
Loading

0 comments on commit bcf2799

Please sign in to comment.