
Commit 6f5294e

updated docs

ashishpatel16 committed Mar 29, 2024
1 parent 70ba78d commit 6f5294e
Showing 5 changed files with 27 additions and 28 deletions.
2 changes: 0 additions & 2 deletions Pipfile
@@ -7,8 +7,6 @@ name = "pypi"
networkx = "*"
numpy = "*"
scikit-learn = "*"
-shap = "*"
-xarray = "*"

[dev-packages]
pytest = "*"
2 changes: 1 addition & 1 deletion README.md
@@ -202,7 +202,7 @@ predictions = pipeline.predict(X_test)
```

## Explaining Hierarchical Classifiers
-Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).
+Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](https://colab.research.google.com/drive/1wqSl1t_Qn2f62WNZQ48mdB0mNeu1XSF1?usp=sharing), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).

## Step-by-step walk-through

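To make the new README paragraph concrete, here is a minimal sketch of the workflow it describes, assembled from the example shipped in this commit; the `LocalClassifierPerParentNode` variant is an assumption, as the commit itself only exercises `LocalClassifierPerNode`:

```python
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerParentNode, Explainer
from hiclass.datasets import load_platypus

# Load the platypus diseases dataset bundled with HiClass
X_train, X_test, Y_train, Y_test = load_platypus()

# Fit a hierarchical model with one random forest per parent node
classifier = LocalClassifierPerParentNode(local_classifier=RandomForestClassifier())
classifier.fit(X_train, Y_train)

# SHAP values for every node come back as a single xarray.Dataset
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test)
```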
39 changes: 21 additions & 18 deletions docs/examples/plot_lcpn_explainer.py
@@ -6,29 +6,16 @@
A minimalist example showing how to use HiClass Explainer to obtain SHAP values of an LCPN model.
A detailed summary of the Explainer class is given in the Algorithms Overview section under :ref:`Hierarchical Explainability`.
SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here <https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv>`_.
"""
-import numpy as np
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerNode, Explainer
+from hiclass.datasets import load_platypus
+import shap

-# Define data
-X_train = np.array(
-    [
-        [40.7, 1.0, 1.0, 2.0, 5.0, 2.0, 1.0, 5.0, 34.3],
-        [39.2, 0.0, 2.0, 4.0, 1.0, 3.0, 1.0, 2.0, 34.1],
-        [40.6, 0.0, 3.0, 1.0, 4.0, 5.0, 0.0, 6.0, 27.7],
-        [36.5, 0.0, 3.0, 1.0, 2.0, 2.0, 0.0, 2.0, 39.9],
-    ]
-)
-X_test = np.array([[35.5, 0.0, 1.0, 1.0, 3.0, 3.0, 0.0, 2.0, 37.5]])
-Y_train = np.array(
-    [
-        ["Gastrointestinal", "Norovirus", ""],
-        ["Respiratory", "Covid", ""],
-        ["Allergy", "External", "Bee Allergy"],
-        ["Respiratory", "Cold", ""],
-    ]
-)
+# Load train and test splits
+X_train, X_test, Y_train, Y_test = load_platypus()

# Use random forest classifiers for every node
rfc = RandomForestClassifier()
@@ -41,3 +28,19 @@
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test)
print(explanations)

+# Filter samples whose first-level prediction is "Respiratory"
+respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory"
+
+# Specify additional filters to obtain only level 0
+shap_filter = {"level": 0, "class": "Respiratory_1", "sample": respiratory_idx}
+
+# Use the .sel() method to apply the filter and obtain the filtered results
+shap_val_respiratory = explanations.sel(shap_filter)
+
+# Plot feature importance on the test set
+shap.plots.violin(
+    shap_val_respiratory.shap_values,
+    feature_names=X_train.columns.values,
+    plot_size=(13, 8),
+)
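Beyond the violin plot, the returned object supports the usual xarray reductions, so one could, for example, rank features by mean absolute SHAP value; a sketch, relying only on the `sample` dimension name already used in the filter above:

```python
import numpy as np

# Rank features by mean absolute SHAP value over the selected samples
mean_abs = np.abs(shap_val_respiratory.shap_values).mean(dim="sample")
for name, score in sorted(
    zip(X_train.columns.values, np.asarray(mean_abs)), key=lambda t: -t[1]
):
    print(f"{name}: {score:.4f}")
```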
5 changes: 0 additions & 5 deletions docs/source/algorithms/explainer.rst
@@ -18,7 +18,6 @@ Integrating explainability methods into Hierarchical classifiers can yield promi
++++++++++++++++++++++++++
Dataset overview
++++++++++++++++++++++++++

For the remainder of this section, we will utilize a synthetically generated dataset representing platypus diseases. This tabular dataset is created to visualize and test the essence of explainability using SHAP on hierarchical models. The diagram below illustrates the hierarchical structure of the dataset. With nine symptoms as features—fever, diarrhea, stomach pain, skin rash, cough, sniffles, shortness of breath, headache, and body size—the objective is to predict the disease based on these feature values.

.. figure:: ../algorithms/platypus_diseases_hierarchy.svg
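A short sketch of inspecting those features; that `load_platypus` returns a pandas DataFrame for the features is an assumption, suggested by the `X_train.columns.values` call elsewhere in this commit:

```python
from hiclass.datasets import load_platypus

X_train, X_test, Y_train, Y_test = load_platypus()
print(X_train.columns.values)  # the nine symptom features listed above
print(Y_train[:3])             # one hierarchical label path per row
```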
@@ -30,7 +29,6 @@ For the remainder of this section, we will utilize a synthetically generated dat
++++++++++++++++++++++++++
Background
++++++++++++++++++++++++++

This section introduces two main concepts: hierarchical classification and SHAP values. Hierarchical classification leverages the hierarchical structure of data, breaking down the classification task into manageable sub-tasks using models organized in a tree or DAG structure.

SHAP values, adapted from game theory, show the impact of features on model predictions, thus aiding model interpretation. The SHAP library offers practical implementation of these methods, supporting various machine learning algorithms for explanation generation.
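For reference, a minimal sketch of plain SHAP usage on a single flat classifier, to contrast with the hierarchical case below; `Y_train[:, 0]` assumes NumPy-style labels and is illustrative only:

```python
import shap
from sklearn.ensemble import RandomForestClassifier

# Flat (non-hierarchical) baseline: explain one classifier trained on
# the first level of the labels only
flat_model = RandomForestClassifier().fit(X_train, Y_train[:, 0])
flat_explainer = shap.TreeExplainer(flat_model)
flat_shap_values = flat_explainer.shap_values(X_test)
```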
@@ -42,7 +40,6 @@ To demonstrate how SHAP values provide insights into model prediction, consider
test_sample = np.array([[35.5, 0. , 1. , 1. , 3. , 3. , 0. , 2. , 37.5]])
sample_target = np.array([['Respiratory', 'Cold', '']])
We can calculate SHAP values using the SHAP Python package and visualize them. SHAP values tell us how much each symptom "contributes" to the model's decision about which disease a platypus might have. The following diagram illustrates how SHAP values can be visualized using :literal:`shap.force_plot`.

.. figure:: ../algorithms/shap_explanation.png
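The force plot above could be reproduced roughly as follows; this is a sketch, the indexing of `expected_value` and the SHAP values varies across shap versions, and `class_idx` is an illustrative name for the class of interest:

```python
# Visualize per-feature contributions for the single test sample
shap.force_plot(
    flat_explainer.expected_value[class_idx],
    flat_shap_values[class_idx][0],
    features=test_sample[0],
    matplotlib=True,
)
```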
@@ -69,7 +66,6 @@ The Explainer class takes a fitted HiClass model, training data, and some named
explainer = Explainer(fitted_hiclass_model, data=training_data)
The Explainer returns an :literal:`xarray.Dataset` object which allows users to intuitively access, filter, slice, and plot SHAP values. This Explanation object can also be used interactively within the Jupyter notebook environment. The Explanation object along with its respective attributes are depicted in the following UML diagram.

.. figure:: ../algorithms/hiclass-uml.png
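A quick way to see that structure interactively; a sketch, reusing the `explainer` from the snippet above:

```python
explanations = explainer.explain(X_test)
print(explanations)                   # dims, coordinates, and data variables
print(explanations.shap_values.dims)  # dimension names used for filtering below
```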
@@ -133,4 +129,3 @@ To achieve this, we can use xarray's :literal:`.sel()` method:
x = explanations.sel(mask).shap_values
More advanced usage and capabilities can be found at the `Xarray.Dataset <https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html>`_ documentation.
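For instance, a sketch that aggregates the selection, relying on the `sample` dimension used in the mask above:

```python
# Average contribution of each feature over the selected samples
x = explanations.sel(mask).shap_values
print(x.mean(dim="sample"))  # "sample" is one of the filterable dimensions
```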

7 changes: 5 additions & 2 deletions hiclass/Explainer.py
@@ -251,8 +251,11 @@ def _calculate_shap_values(self, X):
        datasets = []
        level = 0
        for node in traversed_nodes:
-            # Skip if classifier is not found, can happen in case of imbalanced hierarchies
-            if "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]:
+            # Skip if node is empty or classifier is not found, can happen in case of imbalanced hierarchies
+            if (
+                node == ""
+                or "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]
+            ):
                continue

            local_classifier = self.hierarchical_model.hierarchy_.nodes[node][
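For context, empty nodes arise when label paths have unequal depth, which HiClass pads with empty strings, as in the `Y_train` removed from the example earlier in this commit; a sketch:

```python
# Paths of unequal depth are padded with "", so the hierarchy graph can
# contain an "" node that never receives a fitted classifier.
Y_train = [
    ["Respiratory", "Covid", ""],            # depth 2, padded to depth 3
    ["Allergy", "External", "Bee Allergy"],  # depth 3
]
```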
