Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TST: compare to sklearn.neural_network #155

Merged
merged 10 commits into from
Jan 16, 2021
2 changes: 1 addition & 1 deletion scikeras/utils/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def inverse_transform(self, y: np.ndarray) -> np.ndarray:
Keras Model predictions cast to the dtype and shape of the input
targets.
"""
if self._y_dtype == np.float64 and y.dtype == np.float32:
if y.dtype.name == "float32":
y = y.astype(np.float64, copy=False)
y = y.reshape(-1, *self._y_shape[1:])
return y
Expand Down
2 changes: 1 addition & 1 deletion scikeras/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def _check_model_compatibility(self, y: np.ndarray) -> None:
This is in place to avoid cryptic TF errors.
"""
# check if this is a multi-output model
if hasattr(self, "n_outputs_expected_"):
if getattr(self, "n_outputs_expected_", None):
# n_outputs_expected_ is generated by data transformers
# we recognize the attribute but do not force it to be
# generated
Expand Down
36 changes: 28 additions & 8 deletions tests/multi_output_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,67 @@
import numpy as np

from sklearn.utils.multiclass import type_of_target
from tensorflow.keras.backend import floatx as tf_floatx

from scikeras.utils.transformers import ClassifierLabelEncoder
from scikeras.wrappers import KerasClassifier


class MultiLabelTransformer(ClassifierLabelEncoder):
def __init__(
self, split: bool = True,
):
super().__init__()
self.split = split

def fit(self, y: np.ndarray) -> "MultiLabelTransformer":
self._target_type = type_of_target(y)
if self._target_type != "multilabel-indicator":
if self._target_type not in ("multilabel-indicator", "multiclass-multioutput"):
return super().fit(y)
# y = array([1, 1, 1, 0], [0, 0, 1, 1])
# each col will be processed as multiple binary classifications
self.n_outputs_ = self.n_outputs_expected_ = y.shape[1]
self.n_outputs_ = y.shape[1]
self.n_outputs_expected_ = None if not self.split else self.n_outputs_
self._y_dtype = y.dtype
self.classes_ = [np.array([0, 1])] * y.shape[1]
self.n_classes_ = [2] * y.shape[1]
return self

def transform(self, y: np.ndarray) -> List[np.ndarray]:
if self._target_type != "multilabel-indicator":
if self._target_type not in ("multilabel-indicator", "multiclass-multioutput"):
return super().transform(y)
return np.split(y, y.shape[1], axis=1)
y = y.astype(tf_floatx())
if self.split:
return np.split(y, y.shape[1], axis=1)
return y

def inverse_transform(
self, y: List[np.ndarray], return_proba: bool = False
) -> np.ndarray:
if self._target_type != "multilabel-indicator":
if self._target_type not in ("multilabel-indicator", "multiclass-multioutput"):
return super().inverse_transform(y, return_proba=return_proba)
if not return_proba:
if not return_proba and self.split:
y = [np.argmax(y_, axis=1).astype(self._y_dtype, copy=False) for y_ in y]
return np.squeeze(np.column_stack(y))
y = np.squeeze(np.column_stack(y))
if self._target_type == "multilabel-indicator":
# RandomForestClassifier and sklearn's MultiOutputClassifier always return int64
# for multilabel-indicator
y = y.astype(int)
return y


class MultiOutputClassifier(KerasClassifier):
"""Extend KerasClassifier with the ability to process
"multilabel-indicator" by mapping to multiple Keras outputs.
"""

def __init__(self, model=None, split: bool = True, **kwargs):
super().__init__(model=model, **kwargs)
self.split = split

@property
def target_encoder(self) -> MultiLabelTransformer:
return MultiLabelTransformer()
return MultiLabelTransformer(split=self.split)

def score(self, X, y):
"""Taken from sklearn.multiouput.MultiOutputClassifier
Expand Down
Loading