|
34 | 34 | HashingTF,
|
35 | 35 | IDF,
|
36 | 36 | IDFModel,
|
| 37 | + Imputer, |
| 38 | + ImputerModel, |
37 | 39 | NGram,
|
38 | 40 | RFormula,
|
39 | 41 | Tokenizer,
|
@@ -541,6 +543,46 @@ def test_word2vec(self):
|
541 | 543 | model2 = Word2VecModel.load(d)
|
542 | 544 | self.assertEqual(str(model), str(model2))
|
543 | 545 |
|
| 546 | + def test_imputer(self): |
| 547 | + spark = self.spark |
| 548 | + df = spark.createDataFrame( |
| 549 | + [ |
| 550 | + (1.0, float("nan")), |
| 551 | + (2.0, float("nan")), |
| 552 | + (float("nan"), 3.0), |
| 553 | + (4.0, 4.0), |
| 554 | + (5.0, 5.0), |
| 555 | + ], |
| 556 | + ["a", "b"], |
| 557 | + ) |
| 558 | + |
| 559 | + imputer = Imputer(strategy="mean") |
| 560 | + imputer.setInputCols(["a", "b"]) |
| 561 | + imputer.setOutputCols(["out_a", "out_b"]) |
| 562 | + |
| 563 | + self.assertEqual(imputer.getStrategy(), "mean") |
| 564 | + self.assertEqual(imputer.getInputCols(), ["a", "b"]) |
| 565 | + self.assertEqual(imputer.getOutputCols(), ["out_a", "out_b"]) |
| 566 | + |
| 567 | + model = imputer.fit(df) |
| 568 | + self.assertEqual(model.surrogateDF.columns, ["a", "b"]) |
| 569 | + self.assertEqual(model.surrogateDF.count(), 1) |
| 570 | + self.assertEqual(list(model.surrogateDF.head()), [3.0, 4.0]) |
| 571 | + |
| 572 | + output = model.transform(df) |
| 573 | + self.assertEqual(output.columns, ["a", "b", "out_a", "out_b"]) |
| 574 | + self.assertEqual(output.count(), 5) |
| 575 | + |
| 576 | + # save & load |
| 577 | + with tempfile.TemporaryDirectory(prefix="imputer") as d: |
| 578 | + imputer.write().overwrite().save(d) |
| 579 | + imputer2 = Imputer.load(d) |
| 580 | + self.assertEqual(str(imputer), str(imputer2)) |
| 581 | + |
| 582 | + model.write().overwrite().save(d) |
| 583 | + model2 = ImputerModel.load(d) |
| 584 | + self.assertEqual(str(model), str(model2)) |
| 585 | + |
544 | 586 | def test_count_vectorizer(self):
|
545 | 587 | df = self.spark.createDataFrame(
|
546 | 588 | [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])],
|
|
0 commit comments