@@ -236,6 +236,14 @@ case class AdaptiveSparkPlanExec(
236
236
237
237
@ volatile private var currentPhysicalPlan = initialPlan
238
238
239
+ // Use the logical link of `inputPlan` here, because some top-level physical nodes may have
240
+ // been removed while `initialPlan` was being built.
241
+ @ transient @ volatile private var currentLogicalPlan : LogicalPlan = {
242
+ inputPlan.logicalLink.get
243
+ }
244
+
245
+ val stagesToReplace = mutable.ArrayBuffer .empty[QueryStageExec ]
246
+
239
247
@ volatile private var _isFinalPlan = false
240
248
241
249
private var currentStageId = 0
@@ -289,26 +297,24 @@ case class AdaptiveSparkPlanExec(
289
297
290
298
def finalPhysicalPlan : SparkPlan = withFinalPlanUpdate(identity)
291
299
292
- private def getFinalPhysicalPlan (): SparkPlan = lock.synchronized {
293
- if (isFinalPlan) return currentPhysicalPlan
294
-
300
+ /**
301
+ * Runs `fun` on the finalized physical plan.
302
+ */
303
+ def withFinalPlanUpdate [T ](fun : SparkPlan => T ): T = lock.synchronized {
304
+ _isFinalPlan = false
295
305
// In case of this adaptive plan being executed out of `withActive` scoped functions, e.g.,
296
306
// `plan.queryExecution.rdd`, we need to set active session here as new plan nodes can be
297
307
// created in the middle of the execution.
298
308
context.session.withActive {
299
309
val executionId = getExecutionId
300
- // Use inputPlan logicalLink here in case some top level physical nodes may be removed
301
- // during `initialPlan`
302
- var currentLogicalPlan = inputPlan.logicalLink.get
303
- var result = createQueryStages(currentPhysicalPlan)
310
+ var result = createQueryStages(fun, currentPhysicalPlan, true )
304
311
val events = new LinkedBlockingQueue [StageMaterializationEvent ]()
305
312
val errors = new mutable.ArrayBuffer [Throwable ]()
306
- var stagesToReplace = Seq .empty[QueryStageExec ]
307
313
while (! result.allChildStagesMaterialized) {
308
314
ruleContext.clearConfigs()
309
315
currentPhysicalPlan = result.newPlan
310
316
if (result.newStages.nonEmpty) {
311
- stagesToReplace = result.newStages ++ stagesToReplace
317
+ stagesToReplace ++ = result.newStages
312
318
executionId.foreach(onUpdatePlan(_, result.newStages.map(_.plan)))
313
319
314
320
// SPARK-33933: we should submit tasks of broadcast stages first, to avoid waiting
@@ -366,50 +372,44 @@ case class AdaptiveSparkPlanExec(
366
372
if (errors.nonEmpty) {
367
373
cleanUpAndThrowException(errors.toSeq, None )
368
374
}
369
-
370
- // Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
371
- // than that of the current plan; otherwise keep the current physical plan together with
372
- // the current logical plan since the physical plan's logical links point to the logical
373
- // plan it has originated from.
374
- // Meanwhile, we keep a list of the query stages that have been created since last plan
375
- // update, which stands for the "semantic gap" between the current logical and physical
376
- // plans. And each time before re-planning, we replace the corresponding nodes in the
377
- // current logical plan with logical query stages to make it semantically in sync with
378
- // the current physical plan. Once a new plan is adopted and both logical and physical
379
- // plans are updated, we can clear the query stage list because at this point the two plans
380
- // are semantically and physically in sync again.
381
- val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
382
- val afterReOptimize = reOptimize(logicalPlan)
383
- if (afterReOptimize.isDefined) {
384
- val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
385
- val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
386
- val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
387
- if (newCost < origCost ||
388
- (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
389
- lazy val plans =
390
- sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString(" \n " )
391
- logOnLevel(log " Plan changed: \n ${MDC (QUERY_PLAN , plans)}" )
392
- cleanUpTempTags(newPhysicalPlan)
393
- currentPhysicalPlan = newPhysicalPlan
394
- currentLogicalPlan = newLogicalPlan
395
- stagesToReplace = Seq .empty[QueryStageExec ]
375
+ if (! currentPhysicalPlan.isInstanceOf [ResultQueryStageExec ]) {
376
+ // Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
377
+ // than that of the current plan; otherwise keep the current physical plan together with
378
+ // the current logical plan since the physical plan's logical links point to the logical
379
+ // plan it has originated from.
380
+ // Meanwhile, we keep a list of the query stages that have been created since last plan
381
+ // update, which stands for the "semantic gap" between the current logical and physical
382
+ // plans. And each time before re-planning, we replace the corresponding nodes in the
383
+ // current logical plan with logical query stages to make it semantically in sync with
384
+ // the current physical plan. Once a new plan is adopted and both logical and physical
385
+ // plans are updated, we can clear the query stage list because at this point the two
386
+ // plans are semantically and physically in sync again.
387
+ val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan,
388
+ stagesToReplace.toSeq)
389
+ val afterReOptimize = reOptimize(logicalPlan)
390
+ if (afterReOptimize.isDefined) {
391
+ val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
392
+ val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
393
+ val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
394
+ if (newCost < origCost ||
395
+ (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
396
+ lazy val plans = sideBySide(
397
+ currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString(" \n " )
398
+ logOnLevel(log " Plan changed: \n ${MDC (QUERY_PLAN , plans)}" )
399
+ cleanUpTempTags(newPhysicalPlan)
400
+ currentPhysicalPlan = newPhysicalPlan
401
+ currentLogicalPlan = newLogicalPlan
402
+ stagesToReplace.clear()
403
+ }
396
404
}
405
+ // Now that some stages have finished, we can try creating new stages.
406
+ result = createQueryStages(fun, currentPhysicalPlan, false )
397
407
}
398
- // Now that some stages have finished, we can try creating new stages.
399
- result = createQueryStages(currentPhysicalPlan)
400
408
}
401
-
402
- ruleContext = ruleContext.withFinalStage(isFinalStage = true )
403
- // Run the final plan when there's no more unfinished stages.
404
- currentPhysicalPlan = applyPhysicalRulesWithRuleContext(
405
- optimizeQueryStage(result.newPlan, isFinalStage = true ),
406
- postStageCreationRules(supportsColumnar),
407
- Some ((planChangeLogger, " AQE Post Stage Creation" )))
408
- ruleContext.clearConfigs()
409
- _isFinalPlan = true
410
- executionId.foreach(onUpdatePlan(_, Seq (currentPhysicalPlan)))
411
- currentPhysicalPlan
412
409
}
410
+ _isFinalPlan = true
411
+ finalPlanUpdate
412
+ currentPhysicalPlan.asInstanceOf [ResultQueryStageExec ].resultOption.get().get.asInstanceOf [T ]
413
413
}
414
414
415
415
// Use a lazy val to avoid this being called more than once.
@@ -450,13 +450,6 @@ case class AdaptiveSparkPlanExec(
450
450
}
451
451
}
452
452
453
- private def withFinalPlanUpdate [T ](fun : SparkPlan => T ): T = {
454
- val plan = getFinalPhysicalPlan()
455
- val result = fun(plan)
456
- finalPlanUpdate
457
- result
458
- }
459
-
460
453
protected override def stringArgs : Iterator [Any ] = Iterator (s " isFinalPlan= $isFinalPlan" )
461
454
462
455
override def generateTreeString (
@@ -545,6 +538,66 @@ case class AdaptiveSparkPlanExec(
545
538
this .inputPlan == obj.asInstanceOf [AdaptiveSparkPlanExec ].inputPlan
546
539
}
547
540
541
/**
 * Wrapper around `createQueryStagesInternal` that additionally deals with creation of the
 * result query stage.
 *
 * Three situations are handled:
 *  - `plan` is already a [[ResultQueryStageExec]] and `firstRun` is true: a new result stage
 *    with a fresh stage id is built over the same optimized plan, using `resultHandler`.
 *  - `plan` is already a [[ResultQueryStageExec]] and `firstRun` is false: the plan is returned
 *    unchanged and the materialization state of the existing result stage is reported.
 *  - Otherwise: the work is delegated to `createQueryStagesInternal`; if that call reports no
 *    new stages and all child stages materialized, the plan is wrapped in a new result stage.
 *
 * A newly created result stage is always reported through `newStages` together with
 * `allChildStagesMaterialized = false`. It is deliberately NOT appended to `stagesToReplace`
 * here: the loop in `withFinalPlanUpdate` already appends every entry of `newStages`, so
 * appending here as well would record the same stage twice.
 *
 * @param resultHandler function handed to the result stage
 * @param plan physical plan to create query stages for
 * @param firstRun true for the initial call made by `withFinalPlanUpdate`, false for the
 *                 calls made from inside its loop
 * @return the new plan, whether all child stages are materialized, and the new stages
 */
private def createQueryStages(
    resultHandler: SparkPlan => Any,
    plan: SparkPlan,
    firstRun: Boolean): CreateStageResult = plan match {
  case resultStage @ ResultQueryStageExec(_, optimizedPlan, _) =>
    if (firstRun) {
      // There is already an existing ResultQueryStage created in a previous
      // `withFinalPlanUpdate`; build a new one around the same optimized plan so that the
      // handler supplied for this run is the one attached to the stage.
      val newResultStage = ResultQueryStageExec(currentStageId, optimizedPlan, resultHandler)
      currentStageId += 1
      setLogicalLinkForNewQueryStage(newResultStage, optimizedPlan)
      CreateStageResult(
        newPlan = newResultStage,
        allChildStagesMaterialized = false,
        newStages = Seq(newResultStage))
    } else {
      // Result stage already created during this run, do nothing.
      CreateStageResult(
        newPlan = resultStage,
        allChildStagesMaterialized = resultStage.isMaterialized,
        newStages = Seq.empty)
    }

  case _ =>
    val result = createQueryStagesInternal(plan)
    if (result.newStages.isEmpty && result.allChildStagesMaterialized) {
      // Nothing left to materialize below the root: create the result stage. Reporting it
      // as a new, not yet materialized stage keeps the caller's loop running for it.
      val resultStage = createResultQueryStage(resultHandler, result.newPlan)
      CreateStageResult(
        newPlan = resultStage,
        allChildStagesMaterialized = false,
        newStages = Seq(resultStage))
    } else {
      result
    }
}
584
+
585
/**
 * Builds the [[ResultQueryStageExec]] for the given plan.
 *
 * Steps, in order: switch `ruleContext` to final-stage mode, run `optimizeQueryStage` with
 * `isFinalStage = true` followed by `postStageCreationRules(supportsColumnar)`, clear the
 * rule context configs, then wrap the outcome in a new result stage that takes the current
 * stage id (the id counter is incremented afterwards).
 *
 * @param resultHandler function passed to the new result stage
 * @param plan physical plan to finalize and wrap
 * @return the newly created result stage
 */
+ private def createResultQueryStage (
586
+ resultHandler : SparkPlan => Any ,
587
+ plan : SparkPlan ): ResultQueryStageExec = {
588
+ ruleContext = ruleContext.withFinalStage(isFinalStage = true )
589
+ // Run the final-stage rules on the plan now that there are no more unfinished stages.
590
+ val optimizedRootPlan = applyPhysicalRulesWithRuleContext(
591
+ optimizeQueryStage(plan, isFinalStage = true ),
592
+ postStageCreationRules(supportsColumnar),
593
+ Some ((planChangeLogger, " AQE Post Stage Creation" )))
594
+ ruleContext.clearConfigs()
595
+ val resultStage = ResultQueryStageExec (currentStageId, optimizedRootPlan, resultHandler)
596
+ currentStageId += 1
597
// NOTE: the logical link is set from the input `plan`, not from `optimizedRootPlan`.
+ setLogicalLinkForNewQueryStage(resultStage, plan)
598
+ resultStage
599
+ }
600
+
548
601
/**
549
602
* This method is called recursively to traverse the plan tree bottom-up and create a new query
550
603
* stage or try reusing an existing stage if the current node is an [[Exchange ]] node and all of
@@ -555,7 +608,7 @@ case class AdaptiveSparkPlanExec(
555
608
* 2) Whether the child query stages (if any) of the current node have all been materialized.
556
609
* 3) A list of the new query stages that have been created.
557
610
*/
558
- private def createQueryStages (plan : SparkPlan ): CreateStageResult = plan match {
611
+ private def createQueryStagesInternal (plan : SparkPlan ): CreateStageResult = plan match {
559
612
case e : Exchange =>
560
613
// First have a quick check in the `stageCache` without having to traverse down the node.
561
614
context.stageCache.get(e.canonicalized) match {
@@ -568,7 +621,7 @@ case class AdaptiveSparkPlanExec(
568
621
newStages = if (isMaterialized) Seq .empty else Seq (stage))
569
622
570
623
case _ =>
571
- val result = createQueryStages (e.child)
624
+ val result = createQueryStagesInternal (e.child)
572
625
val newPlan = e.withNewChildren(Seq (result.newPlan)).asInstanceOf [Exchange ]
573
626
// Create a query stage only when all the child query stages are ready.
574
627
if (result.allChildStagesMaterialized) {
@@ -612,7 +665,7 @@ case class AdaptiveSparkPlanExec(
612
665
if (plan.children.isEmpty) {
613
666
CreateStageResult (newPlan = plan, allChildStagesMaterialized = true , newStages = Seq .empty)
614
667
} else {
615
- val results = plan.children.map(createQueryStages )
668
+ val results = plan.children.map(createQueryStagesInternal )
616
669
CreateStageResult (
617
670
newPlan = plan.withNewChildren(results.map(_.newPlan)),
618
671
allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),
0 commit comments