From fcf447ee5f6294af2c915bdd5244100cd7d4a638 Mon Sep 17 00:00:00 2001 From: AnotherSamWilson Date: Thu, 1 Aug 2024 20:21:15 -0400 Subject: [PATCH] Added plot_mean_convergence method --- miceforest/imputed_data.py | 259 +++++++++++---------------------- tests/test_ImputationKernel.py | 1 + 2 files changed, 89 insertions(+), 171 deletions(-) diff --git a/miceforest/imputed_data.py b/miceforest/imputed_data.py index 121d254..4e59341 100644 --- a/miceforest/imputed_data.py +++ b/miceforest/imputed_data.py @@ -1,3 +1,4 @@ +import importlib.metadata from io import BytesIO from itertools import combinations from typing import Any, Dict, List, Optional, Union @@ -111,9 +112,6 @@ def __init__( self.initialized = False self.imputed_variable_count = len(self.imputed_variables) self.modeled_variable_count = len(self.modeled_variables) - # self.iterations = np.zeros( - # shape=(self.num_datasets, self.modeled_variable_count) - # ).astype(int) # Create a multiindexed dataframe to store our imputation values iv_multiindex = MultiIndex.from_product( @@ -132,6 +130,9 @@ def __init__( for dataset in datasets: self.iteration_tab[variable, dataset] = 0 + # Save the version of miceforest that was used to make this kernel + self.version = importlib.metadata.version("miceforest") + # Subsetting allows us to get to the imputation values: def __getitem__(self, tup): variable, iteration, dataset = tup @@ -323,72 +324,6 @@ def complete_data( if not inplace: return impute_data - # def get_means(self, datasets, variables=None): - # """ - # Return a dict containing the average imputation value - # for specified variables at each iteration. - # """ - # num_vars = self._get_num_vars(variables) - - # # For every variable, get the correlations between every dataset combination - # # at each iteration - # curr_iteration = self.iteration_count(datasets=datasets) - # if self.save_all_iterations: - # iter_range = list(range(curr_iteration + 1)) - # else: - # iter_range = [curr_iteration] - # mean_dict = { - # ds: { - # var: {itr: np.mean(self[ds, var, itr]) for itr in iter_range} - # for var in num_vars - # } - # for ds in datasets - # } - - # return mean_dict - - # def plot_mean_convergence(self, datasets=None, variables=None, **adj_args): - # """ - # Plots the average value of imputations over each iteration. - - # Parameters - # ---------- - # variables: None or list - # The variables to plot. Must be numeric. 
- # adj_args - # Passed to matplotlib.pyplot.subplots_adjust() - - # """ - - # try: - # import matplotlib.pyplot as plt - # from matplotlib import gridspec - # except ImportError: - # raise ImportError("matplotlib must be installed to plot mean convergence") - - # if self.iteration_count() < 2 or not self.save_all_iterations: - # raise ValueError("There is only one iteration.") - - # if datasets is None: - # datasets = list(range(self.dataset_count())) - # else: - # datasets = _ensure_iterable(datasets) - # num_vars = self._get_num_vars(variables) - # mean_dict = self.get_means(datasets=datasets, variables=variables) - # plots, plotrows, plotcols = self._prep_multi_plot(num_vars) - # gs = gridspec.GridSpec(plotrows, plotcols) - # fig, ax = plt.subplots(plotrows, plotcols, squeeze=False) - - # for v in range(plots): - # axr, axc = next(iter(gs[v].rowspan)), next(iter(gs[v].colspan)) - # var = num_vars[v] - # for d in mean_dict.values(): - # ax[axr, axc].plot(list(d[var].values()), color="black") - # ax[axr, axc].set_title(var) - # ax[axr, axc].set_xlabel("Iteration") - # ax[axr, axc].set_ylabel("mean") - # plt.subplots_adjust(**adj_args) - def plot_imputed_distributions( self, variables: Optional[List[str]] = None, iteration: int = -1 ): @@ -467,105 +402,87 @@ def plot_imputed_distributions( return fig - # def get_correlations( - # self, datasets: List[int], variables: Union[List[int], List[str]] - # ): - # """ - # Return the correlations between datasets for - # the specified variables. - - # Parameters - # ---------- - # variables: list[str], list[int] - # The variables to return the correlations for. - - # Returns - # ------- - # dict - # The correlations at each iteration for the specified - # variables. - - # """ - - # if self.dataset_count() < 3: - # raise ValueError( - # "Not enough datasets to calculate correlations between them" - # ) - # curr_iteration = self.iteration_count() - # var_indx = self._get_var_ind_from_list(variables) - - # # For every variable, get the correlations between every dataset combination - # # at each iteration - # correlation_dict = {} - # if self.save_all_iterations: - # iter_range = list(range(1, curr_iteration + 1)) - # else: - # # Make this iterable for code tidyness - # iter_range = [curr_iteration] - - # for var in var_indx: - # # Get a dict of variables and imputations for all datasets for this iteration - # iteration_level_imputations = { - # iteration: {ds: self[ds, var, iteration] for ds in datasets} - # for iteration in iter_range - # } - - # combination_correlations = { - # iteration: [ - # round(np.corrcoef(impcomb)[0, 1], 3) - # for impcomb in list(combinations(varimps.values(), 2)) - # ] - # for iteration, varimps in iteration_level_imputations.items() - # } - - # correlation_dict[var] = combination_correlations - - # return correlation_dict - - # def plot_correlations(self, datasets=None, variables=None, **adj_args): - # """ - # Plot the correlations between datasets. - # See get_correlations() for more details. - - # Parameters - # ---------- - # datasets: None or list[int] - # The datasets to plot. - # variables: None,list - # The variables to plot. 
-    #     adj_args
-    #         Additional arguments passed to plt.subplots_adjust()
-
-    #     """
-
-    #     try:
-    #         import matplotlib.pyplot as plt
-    #         from matplotlib import gridspec
-    #     except ImportError:
-    #         raise ImportError("matplotlib must be installed to plot importance")
-
-    #     if self.dataset_count() < 4:
-    #         raise ValueError("Not enough datasets to make box plot")
-    #     if datasets is None:
-    #         datasets = list(range(self.dataset_count()))
-    #     else:
-    #         datasets = _ensure_iterable(datasets)
-    #     var_indx = self._get_var_ind_from_list(variables)
-    #     num_vars = self._get_num_vars(var_indx)
-    #     plots, plotrows, plotcols = self._prep_multi_plot(num_vars)
-    #     correlation_dict = self.get_correlations(datasets=datasets, variables=num_vars)
-    #     gs = gridspec.GridSpec(plotrows, plotcols)
-    #     fig, ax = plt.subplots(plotrows, plotcols, squeeze=False)
-
-    #     for v in range(plots):
-    #         axr, axc = next(iter(gs[v].rowspan)), next(iter(gs[v].colspan))
-    #         var = list(correlation_dict)[v]
-    #         ax[axr, axc].boxplot(
-    #             list(correlation_dict[var].values()),
-    #             labels=range(len(correlation_dict[var])),
-    #         )
-    #         ax[axr, axc].set_title(self._get_var_name_from_scalar(var))
-    #         ax[axr, axc].set_xlabel("Iteration")
-    #         ax[axr, axc].set_ylabel("Correlations")
-    #         ax[axr, axc].set_ylim([-1, 1])
-    #     plt.subplots_adjust(**adj_args)
+    def plot_mean_convergence(
+        self,
+        variables: Optional[List[str]] = None,
+    ):
+        """
+        Plots the mean and standard deviation of the imputed values at each iteration.
+        Each line tracks the mean imputed value of one dataset across iterations.
+        The error bars show the within-dataset standard deviation, averaged across datasets.
+
+        Parameters
+        ----------
+        variables: Optional[List[str]], default=None
+            The variables to plot. By default, all numeric imputed variables are plotted.
+        """
+
+        try:
+            from plotnine import (
+                aes,
+                element_text,
+                facet_wrap,
+                geom_errorbar,
+                geom_line,
+                geom_point,
+                ggplot,
+                ggtitle,
+                theme,
+                theme_538,
+                xlab,
+                ylab,
+            )
+        except ImportError:
+            raise ImportError("plotnine must be installed to plot mean convergence.")
+
+        num_vars = self.working_data.select_dtypes("number").columns.to_list()
+        imp_vars = self.imputed_variables
+        imp_num_vars = [v for v in num_vars if v in imp_vars]
+        if variables is None:
+            variables = imp_num_vars
+        else:
+            variables = [v for v in variables if v in imp_num_vars]
+
+        plot_dat = DataFrame()
+        for variable in variables:
+            dat = self.imputation_values[variable].melt(col_level="iteration")
+            dat["dataset"] = self.imputation_values[variable].melt(col_level="dataset")[
+                "dataset"
+            ]
+            dat = (
+                dat.groupby(["dataset", "iteration"])
+                .agg({"value": ["mean", "std"]})
+                .reset_index()
+            )
+            dat["middle"] = dat[("value", "mean")]
+            dat["upper"] = dat["middle"] + dat[("value", "std")]
+            dat["lower"] = dat["middle"] - dat[("value", "std")]
+            del dat["value"]
+            dat.columns = dat.columns.droplevel(1)
+            iter_dat = dat.groupby("iteration").agg(
+                {"lower": "mean", "middle": "mean", "upper": "mean"}
+            )
+            dat["lower"] = dat.iteration.map(iter_dat["lower"])
+            dat["stdavg"] = dat.iteration.map(iter_dat["middle"])
+            dat["upper"] = dat.iteration.map(iter_dat["upper"])
+            dat["variable"] = variable
+            plot_dat = concat([dat, plot_dat], axis=0)
+
+        fig = (
+            ggplot(plot_dat, aes(x="iteration", y="middle", group="dataset"))
+            + geom_line()
+            + geom_errorbar(
+                aes(x="iteration", ymin="lower", ymax="upper", group="dataset")
+            )
+            + geom_point(aes(x="iteration", y="stdavg"))
+            + facet_wrap("variable", scales="free")
+            + ggtitle("Mean Convergence Plot")
+            + xlab("")
+            + ylab("")
+            + theme(
+                plot_title=element_text(ha="left", size=20),
+            )
+            + theme_538()
+        )
+
+        return fig
diff --git a/tests/test_ImputationKernel.py b/tests/test_ImputationKernel.py
index ed24bb9..609fbe1 100644
--- a/tests/test_ImputationKernel.py
+++ b/tests/test_ImputationKernel.py
@@ -266,6 +266,7 @@ def make_and_test_kernel(**kwargs):
     # Test plotting
     kernel.plot_imputed_distributions()
     kernel.plot_feature_importance(dataset=0)
+    kernel.plot_mean_convergence()
 
     return kernel