Skip to content

Commit

Permalink
Fix handling of python imports from directories (#2315)
Browse files Browse the repository at this point in the history
* Add failing test case.

* Detect directories

* Fix handling of directories

* Fix for windows

---------

Co-authored-by: Fabian Yamaguchi <[email protected]>
  • Loading branch information
fabsx00 and Fabian Yamaguchi authored Feb 27, 2023
1 parent a702527 commit c86ad75
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder
import overflowdb.traversal.Traversal

import java.io.{File => JFile}
import java.nio.file.{Path, Paths}
import java.util.regex.Matcher
import scala.util.Try

Expand Down Expand Up @@ -50,6 +51,7 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global
case x =>
logger.warn(s"Unknown import pattern: ${x.map(_.label).mkString(", ")}")
}
case _ =>
}
}

Expand All @@ -75,9 +77,17 @@ class RecoverForPythonFile(cpg: Cpg, cu: File, builder: DiffGraphBuilder, global
// Case 2: import of a module: import foo => (foo.py or foo/__init.py)
s"$funcOrModule.py:<module>"
case _ =>
// TODO: This assumes importing from modules and never importing nested method
// Case 3: Import from module using alias, e.g. import bar from foo as faz
s"${path.replaceAll("\\.", sep)}.py:<module>.$funcOrModule"
val rootDirectory = cpg.metaData.root.headOption
val absPath = rootDirectory.map(r => Paths.get(r, path))
val fileOrDir = absPath.map(a => better.files.File(a))
val pyFile = absPath.map(a => Paths.get(a.toString + ".py"))
fileOrDir match {
case Some(f) if f.isDirectory && !pyFile.exists { p => better.files.File(p).exists } =>
s"${path.replaceAll("\\.", sep)}/$funcOrModule.py:<module>"
case _ =>
s"${path.replaceAll("\\.", sep)}.py:<module>.$funcOrModule"
}
}

/** The two ways that this procedure could be resolved to in Python. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,31 @@ class TypeRecoveryPassTests extends PySrc2CpgFixture(withOssDataflow = false) {
booleanFieldConstructor.methodFullName shouldBe Seq("django", "db.py:<module>.models.BooleanField.<init>")
.mkString(File.separator)
}
}

"a call made via an import from a directory" should {
lazy val cpg = code("""
|from data import db_session
|
|def foo():
| db_sess = db_session.create_session()
| x = db_sess.query(foo, bar)
|""".stripMargin)
.moreCode(
"""
|from sqlalchemy.orm import Session
|
|def create_session() -> Session:
| global __factory
| return __factory()
|""".stripMargin,
"data/db_session.py"
)

"recover its full name successfully" in {
val List(methodFullName) = cpg.call("query").methodFullName.l
methodFullName shouldBe "data/db_session.py:<module>.create_session.<returnValue>.query"
}
}

}

0 comments on commit c86ad75

Please sign in to comment.