Skip to content

Commit 4af6be6

Browse files
committed Jan 25, 2025
[SPARK-50937][ML][PYTHON][CONNECT] Support Imputer on Connect
### What changes were proposed in this pull request?
Support `Imputer` on Connect

### Why are the changes needed?
for feature parity

### Does this PR introduce _any_ user-facing change?
yes

### How was this patch tested?
added test

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49667 from zhengruifeng/ml_connect_imputer.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent ff6b1a9 commit 4af6be6

File tree

6 files changed

+48
-0
lines changed

6 files changed

+48
-0
lines changed
 

‎mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ org.apache.spark.ml.recommendation.ALS
4343
org.apache.spark.ml.fpm.FPGrowth
4444

4545
# feature
46+
org.apache.spark.ml.feature.Imputer
4647
org.apache.spark.ml.feature.StandardScaler
4748
org.apache.spark.ml.feature.MaxAbsScaler
4849
org.apache.spark.ml.feature.MinMaxScaler

‎mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ org.apache.spark.ml.recommendation.ALSModel
5656
org.apache.spark.ml.fpm.FPGrowthModel
5757

5858
# feature
59+
org.apache.spark.ml.feature.ImputerModel
5960
org.apache.spark.ml.feature.StandardScalerModel
6061
org.apache.spark.ml.feature.MaxAbsScalerModel
6162
org.apache.spark.ml.feature.MinMaxScalerModel

‎mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala

+2
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ class ImputerModel private[ml] (
246246

247247
import ImputerModel._
248248

249+
private[ml] def this() = this(Identifiable.randomUID("imputer"), null)
250+
249251
/** @group setParam */
250252
@Since("3.0.0")
251253
def setInputCol(value: String): this.type = set(inputCol, value)

‎python/pyspark/ml/feature.py

+1
Original file line numberDiff line numberDiff line change
@@ -2261,6 +2261,7 @@ def setOutputCol(self, value: str) -> "ImputerModel":
22612261

22622262
@property
22632263
@since("2.2.0")
2264+
@try_remote_attribute_relation
22642265
def surrogateDF(self) -> DataFrame:
22652266
"""
22662267
Returns a DataFrame containing inputCols and their corresponding surrogates,

‎python/pyspark/ml/tests/test_feature.py

+42
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
HashingTF,
3535
IDF,
3636
IDFModel,
37+
Imputer,
38+
ImputerModel,
3739
NGram,
3840
RFormula,
3941
Tokenizer,
@@ -541,6 +543,46 @@ def test_word2vec(self):
541543
model2 = Word2VecModel.load(d)
542544
self.assertEqual(str(model), str(model2))
543545

546+
def test_imputer(self):
547+
spark = self.spark
548+
df = spark.createDataFrame(
549+
[
550+
(1.0, float("nan")),
551+
(2.0, float("nan")),
552+
(float("nan"), 3.0),
553+
(4.0, 4.0),
554+
(5.0, 5.0),
555+
],
556+
["a", "b"],
557+
)
558+
559+
imputer = Imputer(strategy="mean")
560+
imputer.setInputCols(["a", "b"])
561+
imputer.setOutputCols(["out_a", "out_b"])
562+
563+
self.assertEqual(imputer.getStrategy(), "mean")
564+
self.assertEqual(imputer.getInputCols(), ["a", "b"])
565+
self.assertEqual(imputer.getOutputCols(), ["out_a", "out_b"])
566+
567+
model = imputer.fit(df)
568+
self.assertEqual(model.surrogateDF.columns, ["a", "b"])
569+
self.assertEqual(model.surrogateDF.count(), 1)
570+
self.assertEqual(list(model.surrogateDF.head()), [3.0, 4.0])
571+
572+
output = model.transform(df)
573+
self.assertEqual(output.columns, ["a", "b", "out_a", "out_b"])
574+
self.assertEqual(output.count(), 5)
575+
576+
# save & load
577+
with tempfile.TemporaryDirectory(prefix="imputer") as d:
578+
imputer.write().overwrite().save(d)
579+
imputer2 = Imputer.load(d)
580+
self.assertEqual(str(imputer), str(imputer2))
581+
582+
model.write().overwrite().save(d)
583+
model2 = ImputerModel.load(d)
584+
self.assertEqual(str(model), str(model2))
585+
544586
def test_count_vectorizer(self):
545587
df = self.spark.createDataFrame(
546588
[(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])],

‎sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

+1
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,7 @@ private[ml] object MLUtils {
582582
(classOf[FPGrowthModel], Set("associationRules", "freqItemsets")),
583583

584584
// Feature Models
585+
(classOf[ImputerModel], Set("surrogateDF")),
585586
(classOf[StandardScalerModel], Set("mean", "std")),
586587
(classOf[MaxAbsScalerModel], Set("maxAbs")),
587588
(classOf[MinMaxScalerModel], Set("originalMax", "originalMin")),

0 commit comments

Comments
 (0)