diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index d78a3a391edb6..c8ef755f8a818 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -955,7 +955,7 @@ class Dataset[T] private[sql](
 
   /** @inheritdoc */
   def reduce(func: (T, T) => T): T = withNewRDDExecutionId("reduce") {
-    rdd.reduce(func)
+    materializedRdd.reduce(func)
   }
 
   /** @inheritdoc */
@@ -1471,7 +1471,7 @@ class Dataset[T] private[sql](
 
   /** @inheritdoc */
   def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId("foreachPartition") {
-    rdd.foreachPartition(f)
+    materializedRdd.foreachPartition(f)
   }
 
   /** @inheritdoc */
@@ -1573,14 +1573,20 @@ class Dataset[T] private[sql](
     sparkSession.sessionState.executePlan(deserialized)
   }
 
-  /** @inheritdoc */
-  lazy val rdd: RDD[T] = {
+  private lazy val materializedRdd: RDD[T] = {
     val objectType = exprEnc.deserializer.dataType
     rddQueryExecution.toRdd.mapPartitions { rows =>
       rows.map(_.get(0, objectType).asInstanceOf[T])
     }
   }
 
+  /** @inheritdoc */
+  lazy val rdd: RDD[T] = {
+    withNewRDDExecutionId("rdd") {
+      materializedRdd
+    }
+  }
+
   /** @inheritdoc */
   def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD()
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9b8400f0e3a15..125aad5e52853 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2721,6 +2721,25 @@ class DataFrameSuite extends QueryTest
       parameters = Map("name" -> ".whatever")
     )
   }
+
+  test("SPARK-50994: RDD conversion is performed with execution context") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        withTempDir(dir => {
+          val dummyDF = Seq((1, 1.0), (2, 2.0), (3, 3.0), (1, 1.0)).toDF("a", "A")
+          dummyDF.write.format("parquet").mode("overwrite").save(dir.getCanonicalPath)
+
+          val df = spark.read.parquet(dir.getCanonicalPath)
+          val encoder = ExpressionEncoder(df.schema)
+          val deduplicated = df.dropDuplicates(Array("a"))
+          val df2 = deduplicated.flatMap(row => Seq(row))(encoder).rdd
+
+          val output = spark.createDataFrame(df2, df.schema)
+          checkAnswer(output, Seq(Row(1, 1.0), Row(2, 2.0), Row(3, 3.0)))
+        })
+      }
+    }
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index ecbf77e4c3c01..dcad1a1942ffb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -2704,7 +2704,7 @@ class SQLQuerySuite extends SQLQuerySuiteBase with DisableAdaptiveExecutionSuite
       checkAnswer(sql(s"SELECT id FROM $targetTable"), Row(1) :: Row(2) :: Row(3) :: Nil)
 
       spark.sparkContext.listenerBus.waitUntilEmpty()
-      assert(commands.size == 3)
+      assert(commands.size == 4)
       assert(commands.head.nodeName == "Execute CreateHiveTableAsSelectCommand")
       val v1WriteCommand = commands(1)