Skip to content

Commit 2f1669e

Browse files
committed
draft
1 parent fef1b23 commit 2f1669e

File tree

8 files changed

+238
-95
lines changed

8 files changed

+238
-95
lines changed

Diff for: sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala

+9
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,15 @@ object StaticSQLConf {
210210
.checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in (0,128].")
211211
.createWithDefault(16)
212212

213+
val RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD =
214+
buildStaticConf("spark.sql.resultQueryStage.maxThreadThreshold")
215+
.internal()
216+
.doc("The maximum degree of parallelism to execute ResultQueryStageExec in AQE")
217+
.version("4.0.0")
218+
.intConf
219+
.checkValue(thres => thres > 0 && thres <= 1024, "The threshold must be in (0,1024].")
220+
.createWithDefault(1024)
221+
213222
val SQL_EVENT_TRUNCATE_LENGTH = buildStaticConf("spark.sql.event.truncate.length")
214223
.doc("Threshold of SQL length beyond which it will be truncated before adding to " +
215224
"event. Defaults to no truncation. If set to 0, callsite will be logged instead.")

Diff for: sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
20+
import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, ExecutorService}
2121
import java.util.concurrent.atomic.AtomicLong
2222

2323
import scala.jdk.CollectionConverters._
@@ -301,15 +301,15 @@ object SQLExecution extends Logging {
301301
* SparkContext local properties are forwarded to execution thread
302302
*/
303303
def withThreadLocalCaptured[T](
304-
sparkSession: SparkSession, exec: ExecutorService) (body: => T): JFuture[T] = {
304+
sparkSession: SparkSession, exec: ExecutorService) (body: => T): CompletableFuture[T] = {
305305
val activeSession = sparkSession
306306
val sc = sparkSession.sparkContext
307307
val localProps = Utils.cloneProperties(sc.getLocalProperties)
308308
// `getCurrentJobArtifactState` will return a state only in Spark Connect mode. In non-Connect
309309
// mode, we default back to the resources of the current Spark session.
310310
val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse(
311311
activeSession.artifactManager.state)
312-
exec.submit(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
312+
CompletableFuture.supplyAsync(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
313313
val originalSession = SparkSession.getActiveSession
314314
val originalLocalProps = sc.getLocalProperties
315315
SparkSession.setActiveSession(activeSession)
@@ -326,6 +326,6 @@ object SQLExecution extends Logging {
326326
SparkSession.clearActiveSession()
327327
}
328328
res
329-
})
329+
}, exec)
330330
}
331331
}

Diff for: sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

+112-59
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,14 @@ case class AdaptiveSparkPlanExec(
236236

237237
@volatile private var currentPhysicalPlan = initialPlan
238238

239+
// Use inputPlan logicalLink here in case some top level physical nodes may be removed
240+
// during `initialPlan`
241+
@transient @volatile private var currentLogicalPlan: LogicalPlan = {
242+
inputPlan.logicalLink.get
243+
}
244+
245+
val stagesToReplace = mutable.ArrayBuffer.empty[QueryStageExec]
246+
239247
@volatile private var _isFinalPlan = false
240248

241249
private var currentStageId = 0
@@ -289,26 +297,24 @@ case class AdaptiveSparkPlanExec(
289297

290298
def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity)
291299

292-
private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
293-
if (isFinalPlan) return currentPhysicalPlan
294-
300+
/**
301+
* Run `fun` on the finalized physical plan.
302+
*/
303+
def withFinalPlanUpdate[T](fun: SparkPlan => T): T = lock.synchronized {
304+
_isFinalPlan = false
295305
// In case of this adaptive plan being executed out of `withActive` scoped functions, e.g.,
296306
// `plan.queryExecution.rdd`, we need to set active session here as new plan nodes can be
297307
// created in the middle of the execution.
298308
context.session.withActive {
299309
val executionId = getExecutionId
300-
// Use inputPlan logicalLink here in case some top level physical nodes may be removed
301-
// during `initialPlan`
302-
var currentLogicalPlan = inputPlan.logicalLink.get
303-
var result = createQueryStages(currentPhysicalPlan)
310+
var result = createQueryStages(fun, currentPhysicalPlan, true)
304311
val events = new LinkedBlockingQueue[StageMaterializationEvent]()
305312
val errors = new mutable.ArrayBuffer[Throwable]()
306-
var stagesToReplace = Seq.empty[QueryStageExec]
307313
while (!result.allChildStagesMaterialized) {
308314
ruleContext.clearConfigs()
309315
currentPhysicalPlan = result.newPlan
310316
if (result.newStages.nonEmpty) {
311-
stagesToReplace = result.newStages ++ stagesToReplace
317+
stagesToReplace ++= result.newStages
312318
executionId.foreach(onUpdatePlan(_, result.newStages.map(_.plan)))
313319

314320
// SPARK-33933: we should submit tasks of broadcast stages first, to avoid waiting
@@ -366,50 +372,44 @@ case class AdaptiveSparkPlanExec(
366372
if (errors.nonEmpty) {
367373
cleanUpAndThrowException(errors.toSeq, None)
368374
}
369-
370-
// Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
371-
// than that of the current plan; otherwise keep the current physical plan together with
372-
// the current logical plan since the physical plan's logical links point to the logical
373-
// plan it has originated from.
374-
// Meanwhile, we keep a list of the query stages that have been created since last plan
375-
// update, which stands for the "semantic gap" between the current logical and physical
376-
// plans. And each time before re-planning, we replace the corresponding nodes in the
377-
// current logical plan with logical query stages to make it semantically in sync with
378-
// the current physical plan. Once a new plan is adopted and both logical and physical
379-
// plans are updated, we can clear the query stage list because at this point the two plans
380-
// are semantically and physically in sync again.
381-
val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
382-
val afterReOptimize = reOptimize(logicalPlan)
383-
if (afterReOptimize.isDefined) {
384-
val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
385-
val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
386-
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
387-
if (newCost < origCost ||
388-
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
389-
lazy val plans =
390-
sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
391-
logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
392-
cleanUpTempTags(newPhysicalPlan)
393-
currentPhysicalPlan = newPhysicalPlan
394-
currentLogicalPlan = newLogicalPlan
395-
stagesToReplace = Seq.empty[QueryStageExec]
375+
if (!currentPhysicalPlan.isInstanceOf[ResultQueryStageExec]) {
376+
// Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
377+
// than that of the current plan; otherwise keep the current physical plan together with
378+
// the current logical plan since the physical plan's logical links point to the logical
379+
// plan it has originated from.
380+
// Meanwhile, we keep a list of the query stages that have been created since last plan
381+
// update, which stands for the "semantic gap" between the current logical and physical
382+
// plans. And each time before re-planning, we replace the corresponding nodes in the
383+
// current logical plan with logical query stages to make it semantically in sync with
384+
// the current physical plan. Once a new plan is adopted and both logical and physical
385+
// plans are updated, we can clear the query stage list because at this point the two
386+
// plans are semantically and physically in sync again.
387+
val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan,
388+
stagesToReplace.toSeq)
389+
val afterReOptimize = reOptimize(logicalPlan)
390+
if (afterReOptimize.isDefined) {
391+
val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
392+
val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
393+
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
394+
if (newCost < origCost ||
395+
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
396+
lazy val plans = sideBySide(
397+
currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
398+
logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
399+
cleanUpTempTags(newPhysicalPlan)
400+
currentPhysicalPlan = newPhysicalPlan
401+
currentLogicalPlan = newLogicalPlan
402+
stagesToReplace.clear()
403+
}
396404
}
405+
// Now that some stages have finished, we can try creating new stages.
406+
result = createQueryStages(fun, currentPhysicalPlan, false)
397407
}
398-
// Now that some stages have finished, we can try creating new stages.
399-
result = createQueryStages(currentPhysicalPlan)
400408
}
401-
402-
ruleContext = ruleContext.withFinalStage(isFinalStage = true)
403-
// Run the final plan when there's no more unfinished stages.
404-
currentPhysicalPlan = applyPhysicalRulesWithRuleContext(
405-
optimizeQueryStage(result.newPlan, isFinalStage = true),
406-
postStageCreationRules(supportsColumnar),
407-
Some((planChangeLogger, "AQE Post Stage Creation")))
408-
ruleContext.clearConfigs()
409-
_isFinalPlan = true
410-
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
411-
currentPhysicalPlan
412409
}
410+
_isFinalPlan = true
411+
finalPlanUpdate
412+
currentPhysicalPlan.asInstanceOf[ResultQueryStageExec].resultOption.get().get.asInstanceOf[T]
413413
}
414414

415415
// Use a lazy val to avoid this being called more than once.
@@ -450,13 +450,6 @@ case class AdaptiveSparkPlanExec(
450450
}
451451
}
452452

453-
private def withFinalPlanUpdate[T](fun: SparkPlan => T): T = {
454-
val plan = getFinalPhysicalPlan()
455-
val result = fun(plan)
456-
finalPlanUpdate
457-
result
458-
}
459-
460453
protected override def stringArgs: Iterator[Any] = Iterator(s"isFinalPlan=$isFinalPlan")
461454

462455
override def generateTreeString(
@@ -545,6 +538,66 @@ case class AdaptiveSparkPlanExec(
545538
this.inputPlan == obj.asInstanceOf[AdaptiveSparkPlanExec].inputPlan
546539
}
547540

541+
/**
542+
* This method is a wrapper around `createQueryStagesInternal` that also handles result stage creation.
543+
*/
544+
private def createQueryStages(
545+
resultHandler: SparkPlan => Any,
546+
plan: SparkPlan,
547+
firstRun: Boolean): CreateStageResult = {
548+
plan match {
549+
case resultStage@ResultQueryStageExec(_, optimizedPlan, _) =>
550+
return if (firstRun) {
551+
// A ResultQueryStage was already created by a previous `withFinalPlanUpdate` call
552+
val newResultStage = ResultQueryStageExec(currentStageId, optimizedPlan, resultHandler)
553+
currentStageId += 1
554+
setLogicalLinkForNewQueryStage(newResultStage, optimizedPlan)
555+
stagesToReplace.append(newResultStage)
556+
CreateStageResult(newPlan = newResultStage,
557+
allChildStagesMaterialized = false,
558+
newStages = Seq(newResultStage))
559+
} else {
560+
// result stage already created, do nothing
561+
CreateStageResult(newPlan = plan,
562+
allChildStagesMaterialized = resultStage.isMaterialized,
563+
newStages = Seq.empty)
564+
}
565+
case _ =>
566+
}
567+
val result = createQueryStagesInternal(plan)
568+
var allNewStages = result.newStages
569+
var newPlan = result.newPlan
570+
var allChildStagesMaterialized = result.allChildStagesMaterialized
571+
// Create result stage
572+
if (allNewStages.isEmpty && allChildStagesMaterialized) {
573+
val resultStage = createResultQueryStage(resultHandler, newPlan)
574+
stagesToReplace.append(resultStage)
575+
newPlan = resultStage
576+
allChildStagesMaterialized = false
577+
allNewStages :+= resultStage
578+
}
579+
CreateStageResult(
580+
newPlan = newPlan,
581+
allChildStagesMaterialized = allChildStagesMaterialized,
582+
newStages = allNewStages)
583+
}
584+
585+
private def createResultQueryStage(
586+
resultHandler: SparkPlan => Any,
587+
plan: SparkPlan): ResultQueryStageExec = {
588+
ruleContext = ruleContext.withFinalStage(isFinalStage = true)
589+
// Run the final plan when there are no more unfinished stages.
590+
val optimizedRootPlan = applyPhysicalRulesWithRuleContext(
591+
optimizeQueryStage(plan, isFinalStage = true),
592+
postStageCreationRules(supportsColumnar),
593+
Some((planChangeLogger, "AQE Post Stage Creation")))
594+
ruleContext.clearConfigs()
595+
val resultStage = ResultQueryStageExec(currentStageId, optimizedRootPlan, resultHandler)
596+
currentStageId += 1
597+
setLogicalLinkForNewQueryStage(resultStage, plan)
598+
resultStage
599+
}
600+
548601
/**
549602
* This method is called recursively to traverse the plan tree bottom-up and create a new query
550603
* stage or try reusing an existing stage if the current node is an [[Exchange]] node and all of
@@ -555,7 +608,7 @@ case class AdaptiveSparkPlanExec(
555608
* 2) Whether the child query stages (if any) of the current node have all been materialized.
556609
* 3) A list of the new query stages that have been created.
557610
*/
558-
private def createQueryStages(plan: SparkPlan): CreateStageResult = plan match {
611+
private def createQueryStagesInternal(plan: SparkPlan): CreateStageResult = plan match {
559612
case e: Exchange =>
560613
// First have a quick check in the `stageCache` without having to traverse down the node.
561614
context.stageCache.get(e.canonicalized) match {
@@ -568,7 +621,7 @@ case class AdaptiveSparkPlanExec(
568621
newStages = if (isMaterialized) Seq.empty else Seq(stage))
569622

570623
case _ =>
571-
val result = createQueryStages(e.child)
624+
val result = createQueryStagesInternal(e.child)
572625
val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
573626
// Create a query stage only when all the child query stages are ready.
574627
if (result.allChildStagesMaterialized) {
@@ -612,7 +665,7 @@ case class AdaptiveSparkPlanExec(
612665
if (plan.children.isEmpty) {
613666
CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, newStages = Seq.empty)
614667
} else {
615-
val results = plan.children.map(createQueryStages)
668+
val results = plan.children.map(createQueryStagesInternal)
616669
CreateStageResult(
617670
newPlan = plan.withNewChildren(results.map(_.newPlan)),
618671
allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),

Diff for: sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,12 @@ trait AdaptiveSparkPlanHelper {
129129
}
130130

131131
/**
132-
* Strip the executePlan of AdaptiveSparkPlanExec leaf node.
132+
* Strip the top [[AdaptiveSparkPlanExec]] and [[ResultQueryStageExec]] nodes off
133+
* the [[SparkPlan]].
133134
*/
134135
def stripAQEPlan(p: SparkPlan): SparkPlan = p match {
135-
case a: AdaptiveSparkPlanExec => a.executedPlan
136+
case a: AdaptiveSparkPlanExec => stripAQEPlan(a.executedPlan)
137+
case ResultQueryStageExec(_, plan, _) => plan
136138
case other => other
137139
}
138140
}

Diff for: sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala

+45
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.adaptive
1919

2020
import java.util.concurrent.atomic.AtomicReference
2121

22+
import scala.concurrent.ExecutionContext
2223
import scala.concurrent.Future
24+
import scala.concurrent.Promise
2325

2426
import org.apache.spark.{MapOutputStatistics, SparkException}
2527
import org.apache.spark.broadcast.Broadcast
@@ -32,7 +34,10 @@ import org.apache.spark.sql.columnar.CachedBatch
3234
import org.apache.spark.sql.execution._
3335
import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
3436
import org.apache.spark.sql.execution.exchange._
37+
import org.apache.spark.sql.internal.SQLConf
38+
import org.apache.spark.sql.internal.StaticSQLConf
3539
import org.apache.spark.sql.vectorized.ColumnarBatch
40+
import org.apache.spark.util.ThreadUtils
3641

3742
/**
3843
* A query stage is an independent subgraph of the query plan. AQE framework will materialize its
@@ -303,3 +308,43 @@ case class TableCacheQueryStageExec(
303308

304309
override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics
305310
}
311+
312+
case class ResultQueryStageExec(
313+
override val id: Int,
314+
override val plan: SparkPlan,
315+
resultHandler: SparkPlan => Any) extends QueryStageExec {
316+
317+
override def resetMetrics(): Unit = {
318+
plan.resetMetrics()
319+
}
320+
321+
override protected def doMaterialize(): Future[Any] = {
322+
val javaFuture = SQLExecution.withThreadLocalCaptured(
323+
session,
324+
ResultQueryStageExec.executionContext) {
325+
resultHandler(plan)
326+
}
327+
val scalaPromise: Promise[Any] = Promise()
328+
javaFuture.whenComplete { (result: Any, exception: Throwable) =>
329+
if (exception != null) {
330+
scalaPromise.failure(exception match {
331+
case completionException: java.util.concurrent.CompletionException =>
332+
completionException.getCause
333+
case ex => ex
334+
})
335+
} else {
336+
scalaPromise.success(result)
337+
}
338+
}
339+
scalaPromise.future
340+
}
341+
342+
// The result stage could be any SparkPlan, so we don't have specific runtime statistics for it.
343+
override def getRuntimeStatistics: Statistics = Statistics(sizeInBytes = 0, rowCount = None)
344+
}
345+
346+
object ResultQueryStageExec {
347+
private[execution] val executionContext = ExecutionContext.fromExecutorService(
348+
ThreadUtils.newDaemonCachedThreadPool("ResultQueryStageExecution",
349+
SQLConf.get.getConf(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD)))
350+
}

0 commit comments

Comments
 (0)