diff --git a/miceforest/imputation_kernel.py b/miceforest/imputation_kernel.py index 8fd90dd..ea46126 100644 --- a/miceforest/imputation_kernel.py +++ b/miceforest/imputation_kernel.py @@ -187,7 +187,6 @@ def __init__( super().__init__( impute_data=data, - # num_datasets=num_datasets, datasets=datasets, variable_schema=variable_schema, save_all_iterations_data=save_all_iterations_data, diff --git a/tests/test_ImputationKernel.py b/tests/test_ImputationKernel.py index 9dc3b4f..ed24bb9 100644 --- a/tests/test_ImputationKernel.py +++ b/tests/test_ImputationKernel.py @@ -172,6 +172,15 @@ def make_and_test_kernel(**kwargs): # Make sure we can impute the special cases imputed_data_special_1 = kernel.impute_new_data(new_amputed_data_special_1) + + # Before we do anything else, make sure saving / loading works + new_file, filename = mkstemp() + with open(filename, "wb") as file: + dill.dump(imputed_data_special_1, file) + del imputed_data_special_1 + with open(filename, "rb") as file: + imputed_data_special_1 = dill.load(file) + imputed_data_special_2 = kernel.impute_new_data(new_amputed_data_special_2) imputed_dataset_special_1 = imputed_data_special_1.complete_data(0) imputed_dataset_special_2 = imputed_data_special_2.complete_data(0) @@ -254,6 +263,10 @@ def make_and_test_kernel(**kwargs): assert op["cat_l2"] == 0.5 assert 1 <= op["min_data_in_leaf"] <= 10 + # Test plotting + kernel.plot_imputed_distributions() + kernel.plot_feature_importance(dataset=0) + return kernel