Skip to content

Commit 485cfd7

Browse files
Using ForkJoin pass instead of ConcurrentWriter
1 parent 424ba5a commit 485cfd7

File tree

4 files changed

+115
-98
lines changed

4 files changed

+115
-98
lines changed

astcreator/src/main/scala/com/github/plume/oss/passes/IncrementalKeyPool.scala

-38
This file was deleted.

astcreator/src/main/scala/com/github/plume/oss/passes/PlumeConcurrentWriterPass.scala

-58
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package com.github.plume.oss.passes
2+
3+
import com.github.plume.oss.drivers.IDriver
4+
import io.shiftleft.SerializedCpg
5+
import io.shiftleft.codepropertygraph.generated.Cpg
6+
import io.shiftleft.utils.ExecutionContextProvider
7+
import io.shiftleft.codepropertygraph.generated.nodes.AbstractNode
8+
import io.shiftleft.passes.CpgPassBase
9+
import overflowdb.BatchedUpdate.DiffGraphBuilder
10+
11+
import java.util.function.*
12+
import scala.annotation.nowarn
13+
import scala.collection.mutable
14+
import scala.concurrent.duration.Duration
15+
import scala.concurrent.{Await, ExecutionContext, Future}
16+
17+
/** A CPG pass that generates independent parts, builds a diff graph for each part on the
  * common ForkJoin pool (via a parallel Java stream), merges the partial diffs, and commits
  * the result to the given [[IDriver]] in one bulk transaction.
  *
  * @param driver  the database driver the final diff graph is committed to.
  * @param outName unused; kept for interface compatibility with upstream pass constructors.
  * @tparam T the type of a single work part handed to [[runOnPart]].
  */
abstract class PlumeForkJoinParallelCpgPass[T <: AnyRef](driver: IDriver, @nowarn outName: String = "")
    extends CpgPassBase {

  /** Generates the array of parts that can be processed in parallel. */
  def generateParts(): Array[? <: AnyRef]

  /** Sets up large data structures and acquires external resources. */
  def init(): Unit = {}

  /** Releases large data structures and external resources. Always called, even on failure. */
  def finish(): Unit = {}

  /** Main function: adds the desired changes for a single part to the given builder. */
  def runOnPart(builder: DiffGraphBuilder, part: T): Unit

  /** Override this to disable parallelism of passes. Useful for debugging. */
  def isParallel: Boolean = true

  override def createAndApply(): Unit = createApplySerializeAndStore(null)

  /** Runs the pass, accumulating all changes into `externalBuilder`.
    *
    * @return the number of parts processed.
    */
  override def runWithBuilder(externalBuilder: DiffGraphBuilder): Int = {
    try {
      init()
      val parts  = generateParts()
      val nParts = parts.length
      nParts match {
        case 0 => // nothing to do
        case 1 =>
          // Single part: write directly into the caller's builder, no merge needed.
          runOnPart(externalBuilder, parts(0).asInstanceOf[T])
        case _ =>
          val stream =
            if (!isParallel)
              java.util.Arrays
                .stream(parts)
                .sequential()
            else
              java.util.Arrays
                .stream(parts)
                .parallel()
          // Mutable reduction: each worker accumulates into its own DiffGraphBuilder
          // (supplier), parts are applied via runOnPart (accumulator), and partial
          // builders are merged pairwise with absorb (combiner).
          val diff = stream.collect(
            new Supplier[DiffGraphBuilder] {
              override def get(): DiffGraphBuilder =
                Cpg.newDiffGraphBuilder
            },
            new BiConsumer[DiffGraphBuilder, AnyRef] {
              override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit =
                runOnPart(builder, part.asInstanceOf[T])
            },
            new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] {
              override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit =
                leftBuilder.absorb(rightBuilder)
            }
          )
          externalBuilder.absorb(diff)
      }
      nParts
    } finally {
      finish()
    }
  }

  /** Runs the pass, commits the resulting diff through the driver, and logs timing
    * plus change counts. `serializedCpg`/`prefix` are only consulted for the log
    * message; no serialization is performed here — commit goes through `driver.bulkTx`.
    */
  override def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit = {
    baseLogger.info(s"Start of pass: $name")
    val nanosStart = System.nanoTime()
    var nParts     = 0
    var nanosBuilt = -1L
    var nDiff      = -1
    var nDiffT     = -1
    try {
      val diffGraph = Cpg.newDiffGraphBuilder
      nParts = runWithBuilder(diffGraph)
      nanosBuilt = System.nanoTime()
      nDiff = diffGraph.size
      driver.bulkTx(diffGraph)
      // Fix: nDiffT was previously never assigned, so the log's "+${nDiffT - nDiff}"
      // term always printed a bogus negative count. The driver commits the diff as-is,
      // so the transitive modification count equals the diff size here.
      nDiffT = nDiff
    } catch {
      case exc: Exception =>
        baseLogger.error(s"Pass ${name} failed", exc)
        throw exc
    } finally {
      try {
        finish()
      } finally {
        // the nested finally is somewhat ugly -- but we promised to clean up with finish(), we want to include finish()
        // in the reported timings, and we must have our final log message if finish() throws
        val nanosStop = System.nanoTime()
        val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1)
        val serializationString = if (serializedCpg != null && !serializedCpg.isEmpty) {
          " Diff serialized and stored."
        } else ""
        baseLogger.info(
          f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes committed from ${nParts}%d parts.${serializationString}%s"
        )
      }
    }
  }
}

astcreator/src/main/scala/com/github/plume/oss/passes/base/AstCreationPass.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package com.github.plume.oss.passes.base
33
import better.files.File
44
import com.github.plume.oss.JimpleAst2Database
55
import com.github.plume.oss.drivers.IDriver
6-
import com.github.plume.oss.passes.PlumeConcurrentWriterPass
6+
import com.github.plume.oss.passes.PlumeForkJoinParallelCpgPass
77
import io.joern.x2cpg.ValidationMode
88
import io.joern.x2cpg.datastructures.Global
99
import org.slf4j.LoggerFactory
@@ -15,7 +15,7 @@ import java.nio.file.Paths
1515
/** Creates the AST layer from the given class file and stores all types in the given global parameter.
1616
*/
1717
class AstCreationPass(filenames: List[String], driver: IDriver, unpackingRoot: File)
18-
extends PlumeConcurrentWriterPass[String](driver) {
18+
extends PlumeForkJoinParallelCpgPass[String](driver) {
1919

2020
val global: Global = new Global()
2121
private val logger = LoggerFactory.getLogger(classOf[AstCreationPass])

0 commit comments

Comments
 (0)