Skip to content

Commit

Permalink
Add zero_division parameter #minor (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
mirand863 authored Nov 27, 2024
1 parent c029aa6 commit a344014
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ sphinx_code_tabs==0.5.3
sphinx-gallery==0.10.1
matplotlib==3.5.2
pandas==1.4.2
ray==1.13.0
ray
numpy
git+https://github.com/charles9n/bert-sklearn.git@master
shap==0.44.1
Expand Down
101 changes: 90 additions & 11 deletions hiclass/metrics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Helper functions to compute hierarchical evaluation metrics."""

from typing import Union, List
import warnings
from typing import List, Union

import numpy as np
from sklearn.utils import check_array
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import log_loss as sk_log_loss
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import check_array

from hiclass.HierarchicalClassifier import make_leveled
from hiclass import HierarchicalClassifier
from hiclass.HierarchicalClassifier import make_leveled


def _validate_input(y_true, y_pred):
Expand Down Expand Up @@ -208,7 +211,12 @@ def _recall_macro(y_true: np.ndarray, y_pred: np.ndarray):
return _compute_macro(y_true, y_pred, _recall_micro)


def f1(y_true: np.ndarray, y_pred: np.ndarray, average: str = "micro"):
def f1(
y_true: np.ndarray,
y_pred: np.ndarray,
average: str = "micro",
zero_division: str = "warn",
):
r"""
Compute hierarchical f-score.
Expand All @@ -223,33 +231,104 @@ def f1(y_true: np.ndarray, y_pred: np.ndarray, average: str = "micro"):
- `micro`: The f-score is computed by summing over all individual instances, :math:`\displaystyle{hF = \frac{2 \times hP \times hR}{hP + hR}}`, where :math:`hP` is the hierarchical precision and :math:`hR` is the hierarchical recall.
- `macro`: The f-score is computed for each instance and then averaged, :math:`\displaystyle{hF = \frac{\sum_{i=1}^{n}hF_{i}}{n}}`, where :math:`\alpha_i` is the set consisting of the most specific classes predicted for test example :math:`i` and all their ancestor classes, while :math:`\beta_i` is the set containing the true most specific classes of test example :math:`i` and all their ancestors.
zero_division: {"warn", 0.0, 1.0, np.nan}, default="warn"
Sets the value to return when there is a zero division, i.e., when all
predictions and labels are negative.
Notes:
- If set to "warn", this acts like 0, but a warning is also raised.
- If set to `np.nan`, such values will be excluded from the average.
Returns
-------
f1 : float
Weighted average of the precision and recall
Notes
-----
When ``precision + recall == 0`` (i.e. classes
are completely different from both ``y_true`` and ``y_pred``), f-score is
undefined. In such cases, by default f-score will be set to 0.0, and
``UndefinedMetricWarning`` will be raised. This behavior can be modified by
setting the ``zero_division`` parameter.
References
----------
.. [1] `A survey of hierarchical classification across different application domains
<https://link.springer.com/article/10.1007/S10618-010-0175-9>`_.
Examples
--------
>>> import numpy as np
>>> from hiclass.metrics import f1
>>> y_true = [[0, 1, 2], [3, 4, 5]]
>>> y_pred = [[0, 1, 2], [6, 7, 8]]
>>> f1(y_true, y_pred, average='micro')
0.5
>>> f1(y_true, y_pred, average='macro')
0.5
>>> # zero division
>>> y_true = [[0, 1], [2, 3]]
>>> y_pred = [[4, 5], [6, 7]]
>>> f1(y_true, y_pred)
F-score is ill-defined and being set to 0.0. Use `zero_division` parameter to control this behavior.
0.0
>>> f1(y_true, y_pred, zero_division=1.0)
1.0
>>> f1(y_true, y_pred, zero_division=np.nan)
nan
>>> # multilabel hierarchical classification
>>> y_true = [[["a", "b", "c"]], [["d", "e", "f"]], [["g", "h", "i"]]]
>>> y_pred = [[["a", "b", "c"]], [["d", "e", "f"]], [["g", "h", "i"]]]
>>> f1(y_true, y_pred)
1.0
"""
y_true, y_pred = _validate_input(y_true, y_pred)
functions = {
"micro": _f_score_micro,
"macro": _f_score_macro,
}
return functions[average](y_true, y_pred)
return functions[average](y_true, y_pred, zero_division)


def _f_score_micro(y_true: np.ndarray, y_pred: np.ndarray):
def _f_score_micro(y_true: np.ndarray, y_pred: np.ndarray, zero_division):
prec = precision(y_true, y_pred)
rec = recall(y_true, y_pred)
return 2 * prec * rec / (prec + rec)
if prec + rec == 0:
if zero_division == "warn":
msg = (
"F-score is ill-defined and being set to 0.0. "
"Use `zero_division` parameter to control this behavior."
)
warnings.warn(msg, UndefinedMetricWarning, stacklevel=2)
return np.float64(0.0)
elif zero_division in [0, 1]:
return np.float64(zero_division)
else:
return np.nan
else:
return np.float64(2 * prec * rec / (prec + rec))


def _f_score_macro(y_true: np.ndarray, y_pred: np.ndarray):
return _compute_macro(y_true, y_pred, _f_score_micro)
def _f_score_macro(y_true: np.ndarray, y_pred: np.ndarray, zero_division):
return _compute_macro(y_true, y_pred, _f_score_micro, zero_division)


def _compute_macro(y_true: np.ndarray, y_pred: np.ndarray, _micro_function):
def _compute_macro(
y_true: np.ndarray, y_pred: np.ndarray, _micro_function, zero_division=None
):
overall_sum = 0
for ground_truth, prediction in zip(y_true, y_pred):
sample_score = _micro_function(np.array([ground_truth]), np.array([prediction]))
if zero_division:
sample_score = _micro_function(
np.array([ground_truth]), np.array([prediction]), zero_division
)
else:
sample_score = _micro_function(
np.array([ground_truth]), np.array([prediction])
)
overall_sum = overall_sum + sample_score
return overall_sum / len(y_true)

Expand Down
99 changes: 99 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,24 +264,55 @@ def test_f1_micro_1d_list():
assert 0.5 == f1(y_true, y_pred, "micro")


def test_f1_micro_1d_list_zero_division():
y_true = [1, 2, 3, 4]
y_pred = [5, 6, 7, 8]
assert 0.0 == f1(y_true, y_pred, "micro")
assert 1.0 == f1(y_true, y_pred, "micro", 1.0)
assert np.isnan(f1(y_true, y_pred, "micro", np.nan))


def test_f1_micro_2d_list():
y_true = [[1, 2, 3, 4], [1, 2, 5, 6]]
y_pred = [[1, 2, 5, 6], [1, 2, 3, 4]]
assert 0.5 == f1(y_true, y_pred, "micro")


def test_f1_micro_2d_list_zero_division():
y_true = [[1, 2, 3, 4], [5, 6, 7, 8]]
y_pred = [[5, 6, 7, 8], [1, 2, 3, 4]]
assert 0.0 == f1(y_true, y_pred, "micro")
assert 1.0 == f1(y_true, y_pred, "micro", 1.0)


def test_f1_micro_1d_np_array():
y_true = np.array([1, 2, 3, 4])
y_pred = np.array([1, 2, 5, 6])
assert 0.5 == f1(y_true, y_pred, "micro")


def test_f1_micro_1d_np_array_zero_division():
y_true = np.array([1, 2, 3, 4])
y_pred = np.array([5, 6, 7, 8])
assert 0.0 == f1(y_true, y_pred, "micro")
assert 1.0 == f1(y_true, y_pred, "micro", 1.0)
assert np.isnan(f1(y_true, y_pred, "micro", np.nan))


def test_f1_micro_2d_np_array():
y_true = np.array([[1, 2, 3, 4], [1, 2, 5, 6]])
y_pred = np.array([[1, 2, 5, 6], [1, 2, 3, 4]])
assert 0.5 == f1(y_true, y_pred, "micro")


def test_f1_micro_2d_np_array_zero_division():
y_true = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
y_pred = np.array([[5, 6, 7, 8], [1, 2, 3, 4]])
assert 0.0 == f1(y_true, y_pred, "micro")
assert 1.0 == f1(y_true, y_pred, "micro", 1.0)
assert np.isnan(f1(y_true, y_pred, "micro", np.nan))


def test_f1_micro_3d_np_array():
y_true = np.array(
[
Expand All @@ -299,30 +330,80 @@ def test_f1_micro_3d_np_array():
assert 1 == f1(y_true, y_true, "micro")


def test_f1_micro_3d_np_array_zero_division():
y_true = np.array(
[
[["a", "b"], ["c", "d"]],
[["e", "f"], ["g", "h"]],
]
)
y_pred = np.array(
[
[["i", "j"], ["k", "l"]],
[["m", "n"], ["o", "p"]],
]
)
assert 0.0 == f1(y_true, y_pred, "micro")
assert 1.0 == f1(y_true, y_pred, "micro", 1.0)
assert np.isnan(f1(y_true, y_pred, "micro", np.nan))


def test_f1_macro_1d_list():
y_true = [1, 2, 3, 4]
y_pred = [1, 2, 3, 4]
assert 1 == f1(y_true, y_pred, "macro")


def test_f1_macro_1d_list_zero_division():
y_true = [1, 2, 3, 4]
y_pred = [5, 6, 7, 8]
assert 0.0 == f1(y_true, y_pred, "macro")
assert 1.0 == f1(y_true, y_pred, "macro", 1.0)
assert np.isnan(f1(y_true, y_pred, "macro", np.nan))


def test_f1_macro_2d_list():
y_true = [[1, 2, 3, 4], [1, 2, 5, 6]]
y_pred = [[1, 5, 6], [1, 2, 3]]
assert 0.4285714 == approx(f1(y_true, y_pred, "macro"))


def test_f1_macro_2d_list_zero_division():
y_true = [[1, 2, 3, 4], [5, 6, 7, 8]]
y_pred = [[5, 6, 7, 8], [1, 2, 3, 4]]
assert 0.0 == f1(y_true, y_pred, "macro")
assert 1.0 == f1(y_true, y_pred, "macro", 1.0)
assert np.isnan(f1(y_true, y_pred, "macro", np.nan))


def test_f1_macro_1d_np_array():
y_true = np.array([1, 2, 3, 4])
y_pred = np.array([1, 2, 3, 4])
assert 1 == f1(y_true, y_pred, "macro")


def test_f1_macro_1d_np_array_zero_division():
y_true = np.array([1, 2, 3, 4])
y_pred = np.array([5, 6, 7, 8])
assert 0.0 == f1(y_true, y_pred, "macro")
assert 1.0 == f1(y_true, y_pred, "macro", 1.0)
assert np.isnan(f1(y_true, y_pred, "macro", np.nan))


def test_f1_macro_2d_np_array():
y_true = np.array([[1, 2, 3, 4], [1, 2, 5, 6]])
y_pred = np.array([[1, 5, 6], [1, 2, 3]])
assert 0.4285714 == approx(f1(y_true, y_pred, "macro"))


def test_f1_macro_2d_np_array_zero_division():
y_true = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
y_pred = np.array([[5, 6, 7, 8], [1, 2, 3, 4]])
assert 0.0 == f1(y_true, y_pred, "macro")
assert 1.0 == f1(y_true, y_pred, "macro", 1.0)
assert np.isnan(f1(y_true, y_pred, "macro", np.nan))


def test_f1_macro_3d_np_array():
y_true = np.array(
[
Expand All @@ -340,6 +421,24 @@ def test_f1_macro_3d_np_array():
assert 1 == f1(y_true, y_true, "macro")


def test_f1_macro_3d_np_array_zero_division():
y_true = np.array(
[
[["a", "b"], ["c", "d"]],
[["e", "f"], ["g", "h"]],
]
)
y_pred = np.array(
[
[["i", "j"], ["k", "l"]],
[["m", "n"], ["o", "p"]],
]
)
assert 0.0 == f1(y_true, y_pred, "macro")
assert 1.0 == f1(y_true, y_pred, "macro", 1.0)
assert np.isnan(f1(y_true, y_pred, "macro", np.nan))


def test_empty_levels_2d_list_1():
y_true = [["2", "3"], ["1"], ["4", "5", "6"]]
y_pred = [["1"], ["2", "3"], ["4", "5", "6"]]
Expand Down

0 comments on commit a344014

Please sign in to comment.