Skip to content

Commit

Permalink
pysrc2cpg: Type Hint Support for Parameters and Returns (#2322)
Browse files Browse the repository at this point in the history
* Simplified methodParameterNode to replace overridden parameters with their optional equivalents
* Parse type hints for parameters and returns and add them to `dynamicTypeHintFullName`
* If hint is detected as a builtin or from the typing package, the full name is given instead of the symbol alone
  • Loading branch information
DavidBakerEffendi authored Feb 28, 2023
1 parent 44d5f43 commit b8ed0de
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.joern.pysrc2cpg

import io.joern.pysrc2cpg.PythonAstVisitor.{allBuiltinClasses, builtinPrefix, typingClassesV3, typingPrefix}
import io.joern.pythonparser.ast
import io.joern.x2cpg.Defines
import io.joern.x2cpg.utils.NodeBuilders
import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EvaluationStrategies, nodes}
Expand Down Expand Up @@ -118,39 +120,39 @@ class NodeBuilder(diffGraph: DiffGraphBuilder) {

def methodParameterNode(
name: String,
index: Int,
isVariadic: Boolean,
lineAndColumn: LineAndColumn
lineAndColumn: LineAndColumn,
index: Option[Int] = None,
typeHint: Option[ast.iexpr] = None
): nodes.NewMethodParameterIn = {
val methodParameterNode = nodes
.NewMethodParameterIn()
.name(name)
.code(name)
.index(index)
.evaluationStrategy(EvaluationStrategies.BY_SHARING)
.typeFullName(Constants.ANY)
.isVariadic(isVariadic)
.lineNumber(lineAndColumn.line)
.columnNumber(lineAndColumn.column)
index.foreach(idx => methodParameterNode.index(idx))
methodParameterNode.dynamicTypeHintFullName(extractTypesFromHint(typeHint))
addNodeToDiff(methodParameterNode)
}

def methodParameterNode(
name: String,
isVariadic: Boolean,
lineAndColumn: LineAndColumn
): nodes.NewMethodParameterIn = {
val methodParameterNode = nodes
.NewMethodParameterIn()
.name(name)
.code(name)
.evaluationStrategy(EvaluationStrategies.BY_SHARING)
.typeFullName(Constants.ANY)
.isVariadic(isVariadic)
.lineNumber(lineAndColumn.line)
.columnNumber(lineAndColumn.column)
addNodeToDiff(methodParameterNode)
}
def extractTypesFromHint(typeHint: Option[ast.iexpr] = None): Seq[String] =
typeHint
.collect {
case n: ast.Name => n.id
// TODO: Definitely a place for follow up handling of generics - currently only take the polymorphic type
// without type args. To see the type arguments, see ast.Subscript.slice
case n: ast.Subscript if n.value.isInstanceOf[ast.Name] => n.value.asInstanceOf[ast.Name].id
}
.map { typeName =>
if (allBuiltinClasses.contains(typeName)) s"$builtinPrefix$typeName"
else if (typingClassesV3.contains(typeName)) s"$typingPrefix$typeName"
else typeName
}
.toSeq

def methodReturnNode(dynamicTypeHintFullName: Option[String], lineAndColumn: LineAndColumn): nodes.NewMethodReturn = {
val methodReturnNode = NodeBuilders
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.joern.pysrc2cpg.PythonAstVisitor.{builtinPrefix, metaClassSuffix}
import io.joern.pysrc2cpg.memop._
import io.joern.pythonparser.ast
import io.shiftleft.codepropertygraph.generated._
import io.shiftleft.codepropertygraph.generated.nodes.{NewIdentifier, NewMember, NewMethod, NewNode, NewTypeDecl}
import io.shiftleft.codepropertygraph.generated.nodes.{NewMethod, NewNode, NewTypeDecl}
import overflowdb.BatchedUpdate.DiffGraphBuilder

import scala.collection.mutable
Expand Down Expand Up @@ -353,6 +353,7 @@ class PythonAstVisitor(
}

val methodReturnNode = nodeBuilder.methodReturnNode(returnTypeHint, lineAndColumn)
methodReturnNode.dynamicTypeHintFullName(nodeBuilder.extractTypesFromHint(returns))
edgeBuilder.astEdge(methodReturnNode, methodNode, 2)

val bodyOrder = new AutoIncIndex(1)
Expand Down Expand Up @@ -553,7 +554,7 @@ class PythonAstVisitor(
parameterProvider = () => {
MethodParameters(
0,
nodeBuilder.methodParameterNode("cls", 0, isVariadic = false, lineAndColumn) :: Nil ++
nodeBuilder.methodParameterNode("cls", isVariadic = false, lineAndColumn, Option(0)) :: Nil ++
convert(parameters, 1)
)
},
Expand Down Expand Up @@ -692,7 +693,7 @@ class PythonAstVisitor(
parameterProvider = () => {
MethodParameters(
0,
nodeBuilder.methodParameterNode("cls", 0, isVariadic = false, lineAndColumn) :: Nil ++
nodeBuilder.methodParameterNode("cls", isVariadic = false, lineAndColumn, Some(0)) :: Nil ++
convert(parametersWithoutSelf, 1)
)
},
Expand Down Expand Up @@ -1905,15 +1906,27 @@ class PythonAstVisitor(
// will all be slightly different in the future when we can represent the
// different types in the cpg.
def convertPosOnlyArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = {
nodeBuilder.methodParameterNode(arg.arg, index.getAndInc, isVariadic = false, lineAndColOf(arg))
nodeBuilder.methodParameterNode(
arg.arg,
isVariadic = false,
lineAndColOf(arg),
Option(index.getAndInc),
arg.annotation
)
}

def convertNormalArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = {
nodeBuilder.methodParameterNode(arg.arg, index.getAndInc, isVariadic = false, lineAndColOf(arg))
nodeBuilder.methodParameterNode(
arg.arg,
isVariadic = false,
lineAndColOf(arg),
Option(index.getAndInc),
arg.annotation
)
}

def convertVarArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = {
nodeBuilder.methodParameterNode(arg.arg, index.getAndInc, isVariadic = true, lineAndColOf(arg))
nodeBuilder.methodParameterNode(arg.arg, isVariadic = true, lineAndColOf(arg), Option(index.getAndInc))
}

def convertKeywordOnlyArg(arg: ast.Arg): nodes.NewMethodParameterIn = {
Expand Down Expand Up @@ -1942,6 +1955,7 @@ class PythonAstVisitor(

object PythonAstVisitor {
val builtinPrefix = "__builtin."
val typingPrefix = "typing."
val metaClassSuffix = "<meta>"

// This list contains all functions from https://docs.python.org/3/library/functions.html#built-in-funcs
Expand Down Expand Up @@ -2107,4 +2121,86 @@ object PythonAstVisitor {
"unicode",
"xrange"
)

lazy val allBuiltinClasses: Set[String] = (builtinClassesV2 ++ builtinClassesV3).toSet

lazy val typingClassesV3: Set[String] = Set(
"Annotated",
"Any",
"Callable",
"ClassVar",
"Final",
"ForwardRef",
"Generic",
"Literal",
"Optional",
"Protocol",
"Tuple",
"Type",
"TypeVar",
"Union",
"AbstractSet",
"ByteString",
"Container",
"ContextManager",
"Hashable",
"ItemsView",
"Iterable",
"Iterator",
"KeysView",
"Mapping",
"MappingView",
"MutableMapping",
"MutableSequence",
"MutableSet",
"Sequence",
"Sized",
"ValuesView",
"Awaitable",
"AsyncIterator",
"AsyncIterable",
"Coroutine",
"Collection",
"AsyncGenerator",
"AsyncContextManager",
"Reversible",
"SupportsAbs",
"SupportsBytes",
"SupportsComplex",
"SupportsFloat",
"SupportsIndex",
"SupportsInt",
"SupportsRound",
"ChainMap",
"Counter",
"Deque",
"Dict",
"DefaultDict",
"List",
"OrderedDict",
"Set",
"FrozenSet",
"NamedTuple",
"TypedDict",
"Generator",
"BinaryIO",
"IO",
"Match",
"Pattern",
"TextIO",
"AnyStr",
"cast",
"final",
"get_args",
"get_origin",
"get_type_hints",
"NewType",
"no_type_check",
"no_type_check_decorator",
"NoReturn",
"overload",
"runtime_checkable",
"Text",
"TYPE_CHECKING"
)
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.joern.pysrc2cpg.cpg

import io.joern.pysrc2cpg.Constants
import io.joern.pysrc2cpg.Py2CpgTestContext
import io.joern.pysrc2cpg.{Constants, Py2CpgTestContext}
import io.shiftleft.semanticcpg.language._
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers
Expand Down Expand Up @@ -136,4 +135,53 @@ class FunctionDefCpgTests extends AnyFreeSpec with Matchers {

}

"type hinted function" - {
lazy val cpg = Py2CpgTestContext.buildCpg("""
|from typing import List, Optional
|
|def func1(a: int, b: int) -> float:
| return a / b
|
|def func2(a: Optional[str] = None) -> List[Union[str | None]]:
| return [a]
|""".stripMargin)

"test parameter hint of method definition using built-in types" in {
cpg.method
.name("func1")
.parameter
.dynamicTypeHintFullName
.dedup
.l shouldBe Seq("__builtin.int")
}

"test parameter hint of method definition using types from 'typing'" in {
cpg.method
.name("func2")
.parameter
.dynamicTypeHintFullName
.dedup
.l shouldBe Seq("typing.Optional")
}

"test return hint of method definition using built-in types" in {
cpg.method
.name("func1")
.methodReturn
.dynamicTypeHintFullName
.dedup
.l shouldBe Seq("__builtin.float")
}

"test a return hint of method definition using types from 'typing'" in {
cpg.method
.name("func2")
.methodReturn
.dynamicTypeHintFullName
.dedup
.l shouldBe Seq("typing.List")
}

}

}

0 comments on commit b8ed0de

Please sign in to comment.