diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py index e23915b1..1351fa0b 100644 --- a/hiclass/HierarchicalClassifier.py +++ b/hiclass/HierarchicalClassifier.py @@ -161,7 +161,9 @@ def _pre_fit(self, X, y, sample_weight): ) else: self.X_ = np.array(X) - self.y_ = np.array(y) + self.y_ = check_array( + make_leveled(y), dtype=None, ensure_2d=False, allow_nd=True + ) if sample_weight is not None: self.sample_weight_ = _check_sample_weight(sample_weight, X) diff --git a/setup.py b/setup.py index 1c362c1e..4d80b4d8 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ "ray", "shap==0.44.1", "xarray==2023.1.0", + "bert-sklearn @ git+https://github.com/charles9n/bert-sklearn.git#egg=bert-sklearn", ], } diff --git a/tests/test_LocalClassifierPerParentNode.py b/tests/test_LocalClassifierPerParentNode.py index ad1a5469..268fc990 100644 --- a/tests/test_LocalClassifierPerParentNode.py +++ b/tests/test_LocalClassifierPerParentNode.py @@ -4,12 +4,14 @@ import networkx as nx import numpy as np import pytest -from numpy.testing import assert_array_equal, assert_array_almost_equal +from bert_sklearn import BertClassifier +from numpy.testing import assert_array_almost_equal, assert_array_equal from scipy.sparse import csr_matrix from sklearn.exceptions import NotFittedError from sklearn.linear_model import LogisticRegression from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.validation import check_is_fitted + from hiclass import LocalClassifierPerParentNode from hiclass._calibration.Calibrator import _Calibrator from hiclass.HierarchicalClassifier import make_leveled @@ -393,3 +395,37 @@ def test_fit_calibrate_predict_predict_proba_bert(): classifier.calibrate(x, y) classifier.predict(x) classifier.predict_proba(x) + + +# Note: bert only works with the local classifier per parent node +# It does not have the attribute classes_, which are necessary +# for the local classifiers per level and per node +def test_fit_bert(): + bert = BertClassifier() + clf = LocalClassifierPerParentNode( + local_classifier=bert, + bert=True, + ) + x = ["Batman", "rorschach"] + y = [ + ["Action", "The Dark Night"], + ["Action", "Watchmen"], + ] + clf.fit(x, y) + check_is_fitted(clf) + predictions = clf.predict(x) + assert_array_equal(y, predictions) + + +def test_bert_unleveled(): + clf = LocalClassifierPerParentNode( + local_classifier=BertClassifier(), + bert=True, + ) + x = ["Batman", "Jaws"] + y = [["Action", "The Dark Night"], ["Thriller"]] + ground_truth = [["Action", "The Dark Night"], ["Action", "The Dark Night"]] + clf.fit(x, y) + check_is_fitted(clf) + predictions = clf.predict(x) + assert_array_equal(ground_truth, predictions) diff --git a/tests/test_LocalClassifiers.py b/tests/test_LocalClassifiers.py index abd7bddf..065a1486 100644 --- a/tests/test_LocalClassifiers.py +++ b/tests/test_LocalClassifiers.py @@ -10,8 +10,8 @@ from sklearn.utils.validation import check_is_fitted from hiclass import ( - LocalClassifierPerNode, LocalClassifierPerLevel, + LocalClassifierPerNode, LocalClassifierPerParentNode, ) from hiclass.ConstantClassifier import ConstantClassifier @@ -75,21 +75,6 @@ def test_empty_levels(empty_levels, classifier): assert_array_equal(ground_truth, predictions) -@pytest.mark.parametrize("classifier", classifiers) -def test_fit_bert(classifier): - bert = ConstantClassifier() - clf = classifier( - local_classifier=bert, - bert=True, - ) - X = ["Text 1", "Text 2"] - y = ["a", "a"] - clf.fit(X, y) - check_is_fitted(clf) - predictions = clf.predict(X) - assert_array_equal(y, predictions) - - @pytest.mark.parametrize("classifier", classifiers) def test_knn(classifier): knn = KNeighborsClassifier(