From 6f5294e89a89cfa7472dca5fab539e7e51507d45 Mon Sep 17 00:00:00 2001
From: ashishpatel16
Date: Fri, 29 Mar 2024 13:59:44 +0100
Subject: [PATCH] updated docs

---
 Pipfile                              |  2 --
 README.md                            |  2 +-
 docs/examples/plot_lcpn_explainer.py | 39 +++++++++++++++-------------
 docs/source/algorithms/explainer.rst |  5 ----
 hiclass/Explainer.py                 |  7 +++--
 5 files changed, 27 insertions(+), 28 deletions(-)

diff --git a/Pipfile b/Pipfile
index e7e174cb..4a08c49b 100644
--- a/Pipfile
+++ b/Pipfile
@@ -7,8 +7,6 @@ name = "pypi"
 networkx = "*"
 numpy = "*"
 scikit-learn = "*"
-shap = "*"
-xarray = "*"
 
 [dev-packages]
 pytest = "*"
diff --git a/README.md b/README.md
index 7fd8bd28..c835ba91 100644
--- a/README.md
+++ b/README.md
@@ -202,7 +202,7 @@ predictions = pipeline.predict(X_test)
 ```
 
 ## Explaining Hierarchical Classifiers
-Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).
+Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).
 
 ## Step-by-step walk-through
 
diff --git a/docs/examples/plot_lcpn_explainer.py b/docs/examples/plot_lcpn_explainer.py
index e140b751..f7903495 100644
--- a/docs/examples/plot_lcpn_explainer.py
+++ b/docs/examples/plot_lcpn_explainer.py
@@ -6,29 +6,16 @@
 A minimalist example showing how to use HiClass Explainer to obtain SHAP values of LCPN model.
 A detailed summary of the Explainer class has been given at Algorithms Overview Section for :ref:`Hierarchical Explainability`.
+SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here `_.
""" import numpy as np from sklearn.ensemble import RandomForestClassifier from hiclass import LocalClassifierPerNode, Explainer +from hiclass.datasets import load_platypus +import shap -# Define data -X_train = np.array( - [ - [40.7, 1.0, 1.0, 2.0, 5.0, 2.0, 1.0, 5.0, 34.3], - [39.2, 0.0, 2.0, 4.0, 1.0, 3.0, 1.0, 2.0, 34.1], - [40.6, 0.0, 3.0, 1.0, 4.0, 5.0, 0.0, 6.0, 27.7], - [36.5, 0.0, 3.0, 1.0, 2.0, 2.0, 0.0, 2.0, 39.9], - ] -) -X_test = np.array([[35.5, 0.0, 1.0, 1.0, 3.0, 3.0, 0.0, 2.0, 37.5]]) -Y_train = np.array( - [ - ["Gastrointestinal", "Norovirus", ""], - ["Respiratory", "Covid", ""], - ["Allergy", "External", "Bee Allergy"], - ["Respiratory", "Cold", ""], - ] -) +# Load train and test splits +X_train, X_test, Y_train, Y_test = load_platypus() # Use random forest classifiers for every node rfc = RandomForestClassifier() @@ -41,3 +28,19 @@ explainer = Explainer(classifier, data=X_train, mode="tree") explanations = explainer.explain(X_test) print(explanations) + +# Filter samples which only predicted "Respiratory" at first level +respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory" + +# Specify additional filters to obtain only level 0 +shap_filter = {"level": 0, "class": "Respiratory_1", "sample": respiratory_idx} + +# Use .sel() method to apply the filter and obtain filtered results +shap_val_respiratory = explanations.sel(shap_filter) + +# Plot feature importance on test set +shap.plots.violin( + shap_val_respiratory.shap_values, + feature_names=X_train.columns.values, + plot_size=(13, 8), +) diff --git a/docs/source/algorithms/explainer.rst b/docs/source/algorithms/explainer.rst index ef15775b..b87ced9c 100644 --- a/docs/source/algorithms/explainer.rst +++ b/docs/source/algorithms/explainer.rst @@ -18,7 +18,6 @@ Integrating explainability methods into Hierarchical classifiers can yield promi ++++++++++++++++++++++++++ Dataset overview ++++++++++++++++++++++++++ - For the remainder of this section, we will utilize a synthetically generated dataset representing platypus diseases. This tabular dataset is created to visualize and test the essence of explainability using SHAP on hierarchical models. The diagram below illustrates the hierarchical structure of the dataset. With nine symptoms as features—fever, diarrhea, stomach pain, skin rash, cough, sniffles, shortness of breath, headache, and body size—the objective is to predict the disease based on these feature values. .. figure:: ../algorithms/platypus_diseases_hierarchy.svg @@ -30,7 +29,6 @@ For the remainder of this section, we will utilize a synthetically generated dat ++++++++++++++++++++++++++ Background ++++++++++++++++++++++++++ - This section introduces two main concepts: hierarchical classification and SHAP values. Hierarchical classification leverages the hierarchical structure of data, breaking down the classification task into manageable sub-tasks using models organized in a tree or DAG structure. SHAP values, adapted from game theory, show the impact of features on model predictions, thus aiding model interpretation. The SHAP library offers practical implementation of these methods, supporting various machine learning algorithms for explanation generation. @@ -42,7 +40,6 @@ To demonstrate how SHAP values provide insights into model prediction, consider test_sample = np.array([[35.5, 0. , 1. , 1. , 3. , 3. , 0. , 2. , 37.5]]) sample_target = np.array([['Respiratory', 'Cold', '']]) - We can calculate SHAP values using the SHAP python package and visualize them. 
 SHAP values tell us how much each symptom "contributes" to the model's decision about which disease a platypus might have. The following diagram illustrates how SHAP values can be visualized using the :literal:`shap.force_plot`.
 
 .. figure:: ../algorithms/shap_explanation.png
@@ -69,7 +66,6 @@ The Explainer class takes a fitted HiClass model, training data, and some named
 
    explainer = Explainer(fitted_hiclass_model, data=training_data)
 
-
 The Explainer returns an :literal:`xarray.Dataset` object which allows users to intuitively access, filter, slice, and plot SHAP values. This Explanation object can also be used interactively within the Jupyter notebook environment. The Explanation object along with its respective attributes are depicted in the following UML diagram.
 
 .. figure:: ../algorithms/hiclass-uml.png
@@ -133,4 +129,3 @@ To achieve this, we can use xarray's :literal:`.sel()` method:
    x = explanations.sel(mask).shap_values
 
 More advanced usage and capabilities can be found at the `Xarray.Dataset `_ documentation.
-
diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py
index 7c82a06e..90fc29e7 100644
--- a/hiclass/Explainer.py
+++ b/hiclass/Explainer.py
@@ -251,8 +251,11 @@ def _calculate_shap_values(self, X):
         datasets = []
         level = 0
         for node in traversed_nodes:
-            # Skip if classifier is not found, can happen in case of imbalanced hierarchies
-            if "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]:
+            # Skip if node is empty or classifier is not found, can happen in case of imbalanced hierarchies
+            if (
+                node == ""
+                or "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]
+            ):
                 continue
             local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
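As a rough usage sketch (not part of the patch itself), the snippet below shows one way the filtered `xarray.Dataset` produced in the updated `docs/examples/plot_lcpn_explainer.py` could be reduced to a per-feature importance ranking. It reuses the objects defined in that example (`explanations`, `shap_filter`, `X_train`) and assumes the dimension name `"sample"` and the `shap_values` variable used above; the ordering of the remaining feature dimension is assumed to match `X_train.columns`.

```python
import numpy as np

# Select the same slice used for the violin plot above
# (level 0, class "Respiratory_1", samples predicted as "Respiratory").
shap_slice = explanations.sel(shap_filter).shap_values

# Average the absolute SHAP values over the "sample" dimension to get one
# importance score per feature for this local classifier.
mean_abs_shap = np.abs(shap_slice).mean(dim="sample")

# Rank features by importance; assumes the remaining dimension follows the
# column order of the training data.
ranking = sorted(
    zip(X_train.columns.values, np.asarray(mean_abs_shap)),
    key=lambda pair: pair[1],
    reverse=True,
)
for feature, score in ranking:
    print(f"{feature}: {float(score):.3f}")
```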