diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala index 37c0b9380f52..0c19cd972077 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/JavaSrc2Cpg.scala @@ -1,6 +1,6 @@ package io.joern.javasrc2cpg -import io.joern.javasrc2cpg.passes.{AstCreationPass, TypeInferencePass} +import io.joern.javasrc2cpg.passes.{AstCreationPass, OuterClassRefPass, TypeInferencePass} import io.joern.x2cpg.X2Cpg.withNewEmptyCpg import io.joern.x2cpg.passes.frontend.{JavaConfigFileCreationPass, MetaDataPass, TypeNodePass} import io.joern.x2cpg.X2CpgFrontend @@ -22,6 +22,7 @@ class JavaSrc2Cpg extends X2CpgFrontend[Config] { astCreationPass.createAndApply() astCreationPass.sourceParser.cleanupDelombokOutput() astCreationPass.clearJavaParserCaches() + new OuterClassRefPass(cpg).createAndApply() JavaConfigFileCreationPass(cpg).createAndApply() if (!config.skipTypeInfPass) { TypeNodePass.withRegisteredTypes(astCreationPass.global.usedTypes.keys().asScala.toList, cpg).createAndApply() diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala index cb86444af5d5..cc15b3a09370 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/astcreation/declarations/AstForTypeDeclsCreator.scala @@ -453,8 +453,10 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => receiverAst.root.foreach(receiver => diffGraph.addEdge(initRoot, receiver, EdgeTypes.RECEIVER)) val capturesAsts = - usedCaptures.filterNot(outerClassAst.isDefined && _.name == NameConstants.OuterClass).zipWithIndex.map { - (usedCapture, index) => + usedCaptures + .filterNot(outerClassAst.isDefined && _.name == NameConstants.OuterClass) + .zipWithIndex + .map { (usedCapture, index) => val identifier = NewIdentifier() .name(usedCapture.name) .code(usedCapture.name) @@ -462,10 +464,10 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator => .lineNumber(initRoot.lineNumber) .columnNumber(initRoot.columnNumber) - diffGraph.addEdge(identifier, usedCapture.node, EdgeTypes.REF) + val refsTo = Option.when(usedCapture.name != NameConstants.OuterClass)(usedCapture.node) - Ast(identifier) - } + Ast(identifier).withRefEdges(identifier, refsTo.toList) + } (receiverAst :: args ++ outerClassAst.toList ++ capturesAsts) .map { argAst => diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/OuterClassRefPass.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/OuterClassRefPass.scala new file mode 100644 index 000000000000..1c4c6023644b --- /dev/null +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/OuterClassRefPass.scala @@ -0,0 +1,24 @@ +package io.joern.javasrc2cpg.passes; + +import io.joern.javasrc2cpg.util.NameConstants +import io.joern.x2cpg.Defines +import io.shiftleft.codepropertygraph.Cpg +import io.shiftleft.codepropertygraph.generated.EdgeTypes +import io.shiftleft.codepropertygraph.generated.nodes.{Method, TypeDecl} +import io.shiftleft.passes.ForkJoinParallelCpgPass +import io.shiftleft.semanticcpg.language.* + +class OuterClassRefPass(cpg: Cpg) extends ForkJoinParallelCpgPass[TypeDecl](cpg) { + override def generateParts(): Array[TypeDecl] = cpg.typeDecl.toArray + + override def runOnPart(diffGraph: DiffGraphBuilder, typeDecl: TypeDecl): Unit = { + typeDecl.method.nameExact(Defines.ConstructorMethodName).foreach { constructor => + constructor.ast.isIdentifier.nameExact(NameConstants.OuterClass).filter(_.refsTo.isEmpty).foreach { + outerClassIdentifier => + constructor.parameter.nameExact(NameConstants.OuterClass).foreach { outerClassParam => + diffGraph.addEdge(outerClassIdentifier, outerClassParam, EdgeTypes.REF) + } + } + } + } +} diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalClassTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalClassTests.scala index 086f70e3fd8d..e44ab3d70906 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalClassTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/LocalClassTests.scala @@ -744,6 +744,17 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { @inline def constructors = cpg.typeDecl.fullName("foo.Foo.enclosingMethod.Local").method.nameExact("").sortBy(_.parameter.size) + "not have any orphan locals or parameters" in { + cpg.local.filter(_.astIn.isEmpty).l shouldBe List() + cpg.parameter.filter(_.astIn.isEmpty).l shouldBe List() + } + + "have ref edges from the outer class identifier to the parameter" in { + inside(cpg.method.nameExact("").filter(_.parameter.name.contains("ctxParam")).l) { case List(constructor) => + constructor.ast.isIdentifier.name("outerClass").refsTo.l shouldBe constructor.parameter.name("outerClass").l + } + } + "have params for captured members for both constructors" in { constructors.head.parameter.name.l shouldBe List("this", "outerClass", "outerParam") constructors.last.parameter.name.l shouldBe List("this", "ctxParam", "outerClass", "outerParam") @@ -869,7 +880,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void(int)" call.signature shouldBe "void(int)" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -940,7 +951,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.foo.Local.:void()" call.signature shouldBe "void()" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -1001,7 +1012,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void()" call.signature shouldBe "void()" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -1055,7 +1066,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void()" call.signature shouldBe "void()" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -1108,7 +1119,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void(int)" call.signature shouldBe "void(int)" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -1168,7 +1179,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void()" call.signature shouldBe "void()" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } } @@ -1217,7 +1228,7 @@ class LocalClassTests extends JavaSrcCode2CpgFixture { case List(call) => call.methodFullName shouldBe "foo.Foo.fooMethod.Local.:void(int)" call.signature shouldBe "void(int)" - pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) + // pendingUntilFixed(call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH) case result => fail(s"Unexpected result ${result}") } }