|
| 1 | +package com.github.plume.oss.passes |
| 2 | + |
| 3 | +import com.github.plume.oss.drivers.IDriver |
| 4 | +import io.shiftleft.SerializedCpg |
| 5 | +import io.shiftleft.codepropertygraph.generated.Cpg |
| 6 | +import io.shiftleft.utils.ExecutionContextProvider |
| 7 | +import io.shiftleft.codepropertygraph.generated.nodes.AbstractNode |
| 8 | +import io.shiftleft.passes.CpgPassBase |
| 9 | +import overflowdb.BatchedUpdate.DiffGraphBuilder |
| 10 | + |
| 11 | +import java.util.function.* |
| 12 | +import scala.annotation.nowarn |
| 13 | +import scala.collection.mutable |
| 14 | +import scala.concurrent.duration.Duration |
| 15 | +import scala.concurrent.{Await, ExecutionContext, Future} |
| 16 | + |
| 17 | +abstract class PlumeForkJoinParallelCpgPass[T <: AnyRef](driver: IDriver, @nowarn outName: String = "") |
| 18 | + extends CpgPassBase { |
| 19 | + |
| 20 | + // generate Array of parts that can be processed in parallel |
| 21 | + def generateParts(): Array[? <: AnyRef] |
| 22 | + |
| 23 | + // setup large data structures, acquire external resources |
| 24 | + def init(): Unit = {} |
| 25 | + |
| 26 | + // release large data structures and external resources |
| 27 | + def finish(): Unit = {} |
| 28 | + |
| 29 | + // main function: add desired changes to builder |
| 30 | + def runOnPart(builder: DiffGraphBuilder, part: T): Unit |
| 31 | + |
| 32 | + // Override this to disable parallelism of passes. Useful for debugging. |
| 33 | + def isParallel: Boolean = true |
| 34 | + |
| 35 | + override def createAndApply(): Unit = createApplySerializeAndStore(null) |
| 36 | + |
| 37 | + override def runWithBuilder(externalBuilder: DiffGraphBuilder): Int = { |
| 38 | + try { |
| 39 | + init() |
| 40 | + val parts = generateParts() |
| 41 | + val nParts = parts.size |
| 42 | + nParts match { |
| 43 | + case 0 => |
| 44 | + case 1 => |
| 45 | + runOnPart(externalBuilder, parts(0).asInstanceOf[T]) |
| 46 | + case _ => |
| 47 | + val stream = |
| 48 | + if (!isParallel) |
| 49 | + java.util.Arrays |
| 50 | + .stream(parts) |
| 51 | + .sequential() |
| 52 | + else |
| 53 | + java.util.Arrays |
| 54 | + .stream(parts) |
| 55 | + .parallel() |
| 56 | + val diff = stream.collect( |
| 57 | + new Supplier[DiffGraphBuilder] { |
| 58 | + override def get(): DiffGraphBuilder = |
| 59 | + Cpg.newDiffGraphBuilder |
| 60 | + }, |
| 61 | + new BiConsumer[DiffGraphBuilder, AnyRef] { |
| 62 | + override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit = |
| 63 | + runOnPart(builder, part.asInstanceOf[T]) |
| 64 | + }, |
| 65 | + new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] { |
| 66 | + override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit = |
| 67 | + leftBuilder.absorb(rightBuilder) |
| 68 | + } |
| 69 | + ) |
| 70 | + externalBuilder.absorb(diff) |
| 71 | + } |
| 72 | + nParts |
| 73 | + } finally { |
| 74 | + finish() |
| 75 | + } |
| 76 | + } |
| 77 | + |
| 78 | + override def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit = { |
| 79 | + baseLogger.info(s"Start of pass: $name") |
| 80 | + val nanosStart = System.nanoTime() |
| 81 | + var nParts = 0 |
| 82 | + var nanosBuilt = -1L |
| 83 | + var nDiff = -1 |
| 84 | + var nDiffT = -1 |
| 85 | + try { |
| 86 | + val diffGraph = Cpg.newDiffGraphBuilder |
| 87 | + nParts = runWithBuilder(diffGraph) |
| 88 | + nanosBuilt = System.nanoTime() |
| 89 | + nDiff = diffGraph.size |
| 90 | + driver.bulkTx(diffGraph) |
| 91 | + } catch { |
| 92 | + case exc: Exception => |
| 93 | + baseLogger.error(s"Pass ${name} failed", exc) |
| 94 | + throw exc |
| 95 | + } finally { |
| 96 | + try { |
| 97 | + finish() |
| 98 | + } finally { |
| 99 | + // the nested finally is somewhat ugly -- but we promised to clean up with finish(), we want to include finish() |
| 100 | + // in the reported timings, and we must have our final log message if finish() throws |
| 101 | + val nanosStop = System.nanoTime() |
| 102 | + val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1) |
| 103 | + val serializationString = if (serializedCpg != null && !serializedCpg.isEmpty) { |
| 104 | + " Diff serialized and stored." |
| 105 | + } else "" |
| 106 | + baseLogger.info( |
| 107 | + f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes committed from ${nParts}%d parts.${serializationString}%s" |
| 108 | + ) |
| 109 | + } |
| 110 | + } |
| 111 | + } |
| 112 | + |
| 113 | +} |
0 commit comments