RFC: Composable input/output pipeline #234

Open
wants to merge 19 commits into
base: master
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@
!/docs
!/scikeras
!/tests
!/rfcs
!/pyproject.toml
!/.gitignore
!/.pre-commit-config.yaml
391 changes: 391 additions & 0 deletions rfcs/1-input-pipeline.md
@@ -0,0 +1,391 @@
## Background

One of the primary functions of the Scikit-Learn wrappers for Keras is fielding conversion from Scikit-Learn's data types
to TensorFlow's data types. An example of these conversions is integer-encoding or one-hot encoding categorical targets.

Originally, these conversions were hardcoded into the wrappers in an ad-hoc manner (see here).
SciKeras introduced the concept of `TargetTransformer` and `InputTransformer`,
two Scikit-Learn style transformers that formalize this data conversion framework and allow users to insert their own custom pipeline.

Currently, user-customization requires subclassing the wrappers,
and composability is only provided via meta-transformers (i.e. a Scikit-Learn pipeline) of transformers.

Separately, SciKeras also implements data validations that mirror what most Scikit-Learn estimators implement,
for example to assert that `X` and `y` are of the same length or that `y` is purely numeric for regressors.
SciKeras also validates and inspects the model, for example to make sure that the output shape matches the
target's shape.
These checks are helpful for simple models, but may be too restrictive for more complex scenarios, like
multi-input/output models.

This RFC proposes a unified interface for composable data transformations and validations.
The goal is to provide a pipeline of default transformations and validations that cover the simple use cases,
while allowing users to easily remove checks or add transformation steps for more advanced use cases.

Some of the specific functional requirements are:
1. Able to implement the default (i.e. current) data transformations and validations. This includes:
- Integer encoding targets for classifiers.
- One-hot encoding targets for classifiers using the categorical crossentropy loss.
- Converting class probability predictions into class predictions for classifiers.
2. Able to implement user-defined transformations, including:
- Splitting the input and/or target into multiple inputs/outputs.
- Reshaping 2D inputs into 3D.
3. Able to operate on array-like data (lists, Numpy arrays, Pandas DataFrames, etc.) as well as `tf.data.Dataset`s.
4. Composable and modifiable without subclassing.
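
To make requirement 2 concrete, here is a minimal NumPy-only sketch of the two user-defined transformations named above. The function names `reshape_2d_to_3d` and `split_target` are hypothetical and not part of SciKeras; they only illustrate the kind of step a user might want to plug in.

```python
import numpy as np


def reshape_2d_to_3d(X: np.ndarray, timesteps: int) -> np.ndarray:
    """Reshape (n_samples, timesteps * n_features) into (n_samples, timesteps, n_features)."""
    n_samples, n_cols = X.shape
    if n_cols % timesteps != 0:
        raise ValueError(f"{n_cols} columns cannot be split into {timesteps} timesteps")
    return X.reshape(n_samples, timesteps, n_cols // timesteps)


def split_target(y: np.ndarray) -> list:
    """Split a 2D target into one 1D array per column, e.g. for a multi-output model."""
    return [y[:, i] for i in range(y.shape[1])]


X = np.arange(12).reshape(2, 6)
y = np.array([[0, 1.5], [1, 2.5]])
X3d = reshape_2d_to_3d(X, timesteps=3)  # shape (2, 3, 2)
y_parts = split_target(y)               # two arrays of length 2
```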

## Proposal

This proposal consists of 2 pipelines:
1. A pipeline for preparing array-like data for a `tf.data.Dataset`. This would include, for example, integer-encoding object-dtype arrays (a `tf.Tensor` can't hold Python objects).
2. A pipeline for applying transformations to `tf.data.Dataset`s. For example, one-hot encoding targets for classifiers using the categorical crossentropy loss.

If the data comes in as a `tf.data.Dataset`, the first pipeline is skipped. If not, it is run and the output is converted to a `Dataset`.
The second pipeline is then run in all cases.
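
The dispatch rule described above can be sketched without TensorFlow. `FakeDataset` is a stand-in for `tf.data.Dataset`, and `run_pipelines` is a hypothetical helper; both exist only for this illustration.

```python
from typing import Any, Callable, List, Sequence


class FakeDataset:
    """Stand-in for tf.data.Dataset so the sketch runs without TensorFlow."""

    def __init__(self, elements: Any) -> None:
        self.elements = elements


def run_pipelines(
    data: Any,
    array_steps: Sequence[Callable[[Any], Any]],
    dataset_steps: Sequence[Callable[[FakeDataset], FakeDataset]],
) -> FakeDataset:
    if not isinstance(data, FakeDataset):
        # Array-like input: run the first pipeline, then convert to a dataset.
        for step in array_steps:
            data = step(data)
        data = FakeDataset(data)
    # The second pipeline runs in all cases.
    for step in dataset_steps:
        data = step(data)
    return data


calls: List[str] = []


def array_step(x):
    calls.append("array")
    return x


def dataset_step(ds):
    calls.append("dataset")
    return ds


run_pipelines([1, 2, 3], [array_step], [dataset_step])
from_arrays = list(calls)
calls.clear()
run_pipelines(FakeDataset([1, 2, 3]), [array_step], [dataset_step])
from_dataset = list(calls)
```

With array-like input both pipelines run in order; with dataset input only the second one does.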

These pipelines will consist of chained transformers implementing a Scikit-Learn-like interface, but without being restricted to the exact Scikit-Learn API.
Collaborator

but without being restricted to the exact Scikit-Learn API.

By "exact Scikit-Learn API," you mean "the exact names of Scikit-learn transformers," right? And the proposed "Scikit-learn-like interface" is public, correct?

If the answers to both of those questions are affirmative, 👎 to that interface. How can this RFC be reworked to conform with the "exact Scikit-Learn API"?

Owner Author (@adriangb, Jun 21, 2021)

By "exact Scikit-Learn API" I mean roughly:

  • having a fit method with the signature fit(X, y=None) -> self
  • having a transform method with the signature transform(X) -> X'
  • an inverse_transform method with the signature inverse_transform(X') -> X.
    And no other methods (that are part of the API).
    Also inverse_transform(transform(X)) == X is not a requirement of the API, but it certainly is the spirit of it (and how most sklearn transformers are implemented, when it makes sense to do so).
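
For reference, a toy transformer with exactly these three methods might look like the sketch below. `MinMaxScaler1D` is a made-up class for this example (it is not sklearn's `MinMaxScaler`), and it works on plain lists to stay dependency-free.

```python
from typing import List, Optional, Sequence


class MinMaxScaler1D:
    """Toy transformer exposing fit / transform / inverse_transform."""

    def fit(self, X: Sequence[float], y: Optional[Sequence[float]] = None) -> "MinMaxScaler1D":
        self.min_ = min(X)
        self.range_ = (max(X) - self.min_) or 1.0  # avoid division by zero
        return self

    def transform(self, X: Sequence[float]) -> List[float]:
        return [(x - self.min_) / self.range_ for x in X]

    def inverse_transform(self, Xt: Sequence[float]) -> List[float]:
        return [x * self.range_ + self.min_ for x in Xt]


X = [2.0, 4.0, 6.0]
scaler = MinMaxScaler1D().fit(X)
Xt = scaler.transform(X)
X_back = scaler.inverse_transform(Xt)  # round trip recovers X in this case
```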

And the proposed "Scikit-learn-like interface" is public, correct?

Yes, it would be public and consists of ArrayTransformer and DatasetTransformer.
These can be base classes, protocols/interfaces or just duck-typing (implementation detail).
The idea is that users could take our default data validation/transformations and mix/match with their own for use cases like multi-output models or multi-output class reweighting.

This proposed interface violates the sklearn interface in several ways:

  • by passing X, y and sample_weight in ArrayTransformer
  • by passing a tf.data.Dataset object in DatasetTransformer
  • because these really aren't inverse transformations: the forward/input transformation transforms X & y, the output transformation transforms y' (y predictions) and/or y itself. So inverse_transform(transform(X)) != X.
  • by allowing side effects to be applied to the model (this is an addition to the API, it doesn't strictly break the implementation, but it does break the spirit).

How can this RFC be reworked to conform with the "exact Scikit-Learn API"?

I don't think it can be easily reworked: there are too many limitations in the sklearn transformer API (and the authors have said as much themselves).

The good news is that the Scikit-Learn API is a strict subset of this API, so we could provide wrappers to convert the Scikit-Learn transformer API into this one (eg. by specifying what X is to your transformer and dispatching the methods).
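
A rough sketch of such a wrapper, assuming the `transform_input(X, y, sample_weight, *, initialize)` signature proposed in this RFC. `SklearnStyleAdapter` and `Doubler` are hypothetical names used only here; the `target` argument is one possible way of "specifying what X is" to the wrapped transformer.

```python
from typing import Any


class SklearnStyleAdapter:
    """Hypothetical wrapper exposing an sklearn-style transformer through transform_input."""

    def __init__(self, transformer: Any, target: str = "X") -> None:
        if target not in ("X", "y"):
            raise ValueError("target must be 'X' or 'y'")
        self.transformer = transformer
        self.target = target

    def set_model(self, model: Any) -> None:
        self.model = model

    def transform_input(self, X, y=None, sample_weight=None, *, initialize: bool = True):
        data = X if self.target == "X" else y
        if data is not None:
            if initialize:
                self.transformer.fit(data)
            data = self.transformer.transform(data)
        if self.target == "X":
            return data, y, sample_weight
        return X, data, sample_weight


class Doubler:
    """Toy sklearn-style transformer used only for this example."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return [2 * x for x in X]


adapter = SklearnStyleAdapter(Doubler(), target="y")
X_out, y_out, w_out = adapter.transform_input([1, 2], [10, 20], None)
```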

But let me ask: why is it important to adhere to the Scikit-Learn API? I know there are good reasons, I just want to understand which you are thinking of.
My thoughts are that if we can't interoperate directly (i.e. use make_pipeline and other facilities of sklearn), there is not much value in being semi-compatible. I also think there aren't that many things (beyond pipelines) that would be useful here. But this may be short-sighted, I'm open to other opinions.

Collaborator

This library has one job: to bring the Scikit-learn API to Keras. That means there needs to be a really good reason to break the Scikit-learn API.

That reason needs to consider alternative implementations, including the one that exactly follows the Scikit-learn API. I imagine some questions that need to be answered include the following:

  • What effect does the implementation have on developers?
  • What effect does the implementation have on the user?
  • What specific issues make the implementation inappropriate?


```python
from typing import Any, Dict, Optional, Tuple, Sequence, Union, TYPE_CHECKING

import numpy as np
import tensorflow as tf

from scikeras.wrappers import BaseWrapper


Data = Tuple[np.ndarray, Union[np.ndarray, None], Union[np.ndarray, None]]
ArrayLike = Union[Sequence, np.ndarray]


class NotInitializedError(Exception):
    ...


class ArrayTransformer:

    def set_model(self, model: "BaseWrapper") -> None:
Collaborator

👎

Why not follow the Scikit-Learn API?

class ArrayTransformer(BaseEstimator):
    def __init__(self, model=None):
        self.model = model

    ...

a = ArrayTransformer(model=model) # option 0

a = ArrayTransformer().model = model # option 1

a = ArrayTransformer()
a.set_params(model=model) # option 2

Yes, TF uses set_model. Why should their boilerplate be followed?

Owner Author

This would require users to give us a class to initialize (this object is being passed to the constructor that creates model).

I guess we could use parameter routing to allow setting any other parameters. Does this look any better?

class MyTransf:

    def __init__(self, model, other):
        self.other = other
    
    ...

est = BaseWrapper(..., pipeline_param=[MyTransf], pipeline_param__0__other=True)
# or
est = BaseWrapper(..., pipeline_param={"tfname": MyTransf}, pipeline_param__tfname__other=True)  # relies on dict ordering to know what order to run pipeline in
# or
est = BaseWrapper(..., pipeline_param=[("tfname", MyTransf)], pipeline_param__tfname__other=True)  # same as an sklearn pipeline using a tuple to set the name

An alternative would be to require users to bind any other parameters using functools.partial or something.

        self.model = model

    def transform_input(self, X: ArrayLike, y: Optional[ArrayLike], sample_weight: Optional[ArrayLike], *, initialize: bool = True) -> Data:
        return X, y, sample_weight

    def transform_output(self, y_pred_proba: np.ndarray, y: Union[np.ndarray, None]) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
        return y_pred_proba, None

    def get_meta(self) -> Dict[str, Any]:
        return {}


class DatasetTransformer:

    def set_model(self, model: "BaseWrapper") -> None:
        self.model = model

    def transform_input(self, data: tf.data.Dataset, *, initialize: bool = True) -> tf.data.Dataset:
        return data

    def transform_output(self, y_pred_proba: np.ndarray, y: Union[np.ndarray, None]) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
Collaborator

Why not rename these functions inverse_transform?

Owner Author

Because their signature and usage do not match those of inverse_transform. They wouldn't work in an sklearn Pipeline, nor would they be chainable with sklearn estimators.

        return y_pred_proba, None

    def get_meta(self) -> Dict[str, Any]:
        return {}

```

SciKeras will initialize the pipeline of transformers by calling `set_model` with a reference to the current estimator.
This is exactly how Keras handles callbacks ([`set_model`](https://github.com/tensorflow/tensorflow/blob/a4dfb8d1a71385bd6d122e4f27f86dcebb96712d/tensorflow/python/keras/callbacks.py#L302), [in-place modification of `model`](https://github.com/tensorflow/tensorflow/blob/a4dfb8d1a71385bd6d122e4f27f86dcebb96712d/tensorflow/python/keras/callbacks.py#L1153) in History).
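
The callback-style binding can be illustrated with a small, dependency-free sketch. `Estimator` stands in for the SciKeras wrapper and `RecordNFeatures` is a hypothetical transformer; neither is existing SciKeras code.

```python
from typing import Any, List


class Estimator:
    """Stand-in for the SciKeras wrapper; only what the sketch needs."""

    def __init__(self, transformers: List[Any]) -> None:
        self.transformers = transformers

    def initialize(self) -> "Estimator":
        for t in self.transformers:
            t.set_model(self)  # each transformer gets a reference to the estimator
        return self


class RecordNFeatures:
    """Hypothetical transformer that writes metadata back onto the estimator."""

    def set_model(self, model: Any) -> None:
        self.model = model

    def transform_input(self, X):
        self.model.n_features_in_ = len(X[0])  # in-place side effect on the model
        return X


step = RecordNFeatures()
est = Estimator([step]).initialize()
step.transform_input([[1, 2, 3], [4, 5, 6]])
```

After the call, `est.n_features_in_` has been set by the transformer, which is the kind of side effect the review discussion refers to.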

SciKeras will then iterate through them, similar to a Scikit-Learn pipeline:

```python
import itertools
from typing import Optional, Sequence, Tuple, Type, Union

from numpy import ndarray
from numpy.typing import ArrayLike
import tensorflow as tf


Input = Union[tf.data.Dataset, ArrayLike]
Transformed = Tuple[Input, Optional[ArrayLike], Optional[ArrayLike]]


class BaseWrapper:

    def __init__(
        self,
        array_transformers: Sequence[Type[ArrayTransformer]] = tuple(),  # a tuple with default transformers
        dataset_transformers: Sequence[Type[DatasetTransformer]] = tuple(),
    ) -> None:
        self.array_transformers = array_transformers
        self.dataset_transformers = dataset_transformers

    def _transform_input(self, X: Input, y: Optional[ArrayLike], sample_weight: Optional[ArrayLike], *, initialize: bool) -> Transformed:
        if initialize:
            # Assumption: the constructor receives transformer classes, which are
            # instantiated here and then bound to this estimator.
            self.array_transformers_ = [cls() for cls in self.array_transformers]
            self.dataset_transformers_ = [cls() for cls in self.dataset_transformers]
            # loop variable is `t`, not `tf`, to avoid shadowing the tensorflow module
            for t in itertools.chain(self.array_transformers_, self.dataset_transformers_):
                t.set_model(self)
        if isinstance(X, tf.data.Dataset):
            self._numpy_input = False
            data = X
        else:
            self._numpy_input = True
            for t in self.array_transformers_:
                X, y, sample_weight = t.transform_input(X, y, sample_weight, initialize=initialize)
            data = tf.data.Dataset.from_tensors((X, y, sample_weight))
        for t in self.dataset_transformers_:
            data = t.transform_input(data, initialize=initialize)
        if self._numpy_input:
            # keep as numpy arrays to allow validation_split and such to work
            X, y, sample_weight = next(iter(data))
            return X, y, sample_weight
        else:
            return data, None, None

    def initialize(self, X: Input, y: Optional[ArrayLike], sample_weight: Optional[ArrayLike]) -> "BaseWrapper":
        self._transform_input(X, y, sample_weight, initialize=True)
        return self

    def fit(self, X: Input, y: Optional[ArrayLike], sample_weight: Optional[ArrayLike]) -> "BaseWrapper":
        X, y, sample_weight = self._transform_input(X, y, sample_weight, initialize=True)
        ...
        return self

    def partial_fit(self, X: Input, y: Optional[ArrayLike], sample_weight: Optional[ArrayLike]) -> "BaseWrapper":
        initialize = not hasattr(self, "model_")  # or other check
        X, y, sample_weight = self._transform_input(X, y, sample_weight, initialize=initialize)
        ...
        return self

    def predict(self, X: Input) -> ndarray:
        # no targets or sample weights are available at prediction time
        data, _, _ = self._transform_input(X, None, None, initialize=False)
        y_pred_proba = self.model_.predict(data)
        y = None
        for t in itertools.chain(
            reversed(self.dataset_transformers_),
            reversed(self.array_transformers_),
        ):
            y_pred_proba, y = t.transform_output(y_pred_proba, y)
        # return decoded predictions if a transformer produced them, else the raw output
        return y if y is not None else y_pred_proba
```

Classifiers (as well as LTR or other learning problems where `y` is not the raw prediction probabilities) can use this interface
to convert probabilities to class predictions, or to modify the probabilities themselves. Two small examples:

```python
class BinaryPredictionReshaper(DatasetTransformer):

    def transform_input(self, data: tf.data.Dataset, *, initialize: bool) -> tf.data.Dataset:
        if initialize:
            self._is_binary = ...  # call type_of_target or other check
        return data

    def transform_output(self, y_pred_proba: np.ndarray, y: Union[np.ndarray, None]) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
        shp = y_pred_proba.shape
        # parenthesized so that the binary check guards both shape cases
        if self._is_binary and (len(shp) == 1 or (len(shp) == 2 and shp[1] == 1)):
            # single sigmoid output, reshape to a 2D array of predictions, which is what sklearn expects
            y_pred_proba = np.column_stack([1 - y_pred_proba, y_pred_proba])
        return y_pred_proba, y


class ClassifierPredictionDecoder(DatasetTransformer):

    def transform_output(self, y_pred_proba: np.ndarray, y: Union[np.ndarray, None]) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
        if y is None:
            y = np.argmax(y_pred_proba, axis=1)
        return y_pred_proba, y
```
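
As a quick check of how two such output steps compose, the sketch below re-implements them as plain functions (hypothetical names `reshape_binary` and `decode_classes`) so that it runs with NumPy alone. It assumes a binary classifier with a single sigmoid output unit.

```python
import numpy as np


def reshape_binary(y_pred_proba, y):
    # (n,) or (n, 1) sigmoid output -> (n, 2) class probabilities
    if y_pred_proba.ndim == 1 or y_pred_proba.shape[1] == 1:
        p = y_pred_proba.reshape(-1)
        y_pred_proba = np.column_stack([1 - p, p])
    return y_pred_proba, y


def decode_classes(y_pred_proba, y):
    # pick the most probable class unless an earlier step already decoded y
    if y is None:
        y = np.argmax(y_pred_proba, axis=1)
    return y_pred_proba, y


raw = np.array([[0.25], [0.75]])  # what a single sigmoid unit would output
proba, y = raw, None
for step in (reshape_binary, decode_classes):
    proba, y = step(proba, y)
```

The order matters here: decoding with `argmax` before reshaping would always yield class 0 for an `(n, 1)` array.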


## Example Implementations

### One-hot encode targets

This moves one-hot encoding out of `ClassifierLabelEncoder`.
This means the transformation can be applied to any input, including `tf.data.Dataset`s.
Performance should also be better because TensorFlow lazily applies and optimizes `map` operations on `Dataset`.

```python
def _is_ohe_dataset(data: tf.data.Dataset) -> bool:
    target_shape = data.element_spec[1].shape
    if len(target_shape) != 2 or target_shape[1] == 1:
        return False  # needs to be 2D with >=2 columns to be one-hot encoded
    y = next(iter(data))[1]
    return bool(tf.math.reduce_all(tf.math.reduce_sum(y, axis=1) == 1, axis=0).numpy())  # all rows add up to 1


class ClassifierOneHotEncoder(DatasetTransformer):
    """One-hot encode the target if the loss function is categorical crossentropy."""

    def transform_input(self, data: tf.data.Dataset, *, initialize: bool) -> tf.data.Dataset:
        if initialize:
            loss = getattr(self.model, "loss", None)
            # is_categorical_crossentropy is a helper assumed to be defined elsewhere
            loss_requires_ohe = False if loss is None else is_categorical_crossentropy(loss)
            self._needs_ohe = loss_requires_ohe and not _is_ohe_dataset(data)
            if self._needs_ohe:
                user_supplied_classes = getattr(self.model, "classes_", None)
                # tf.unique returns (values, indices); the unique values are element 0
                self.classes_ = user_supplied_classes if user_supplied_classes is not None else tf.unique(next(iter(data))[1])[0]
        if self._needs_ohe:
            # assumes y is already integer-encoded as 0..n_classes-1
            data = data.map(lambda X, y, sample_weight: (X, tf.one_hot(y, depth=len(self.classes_)), sample_weight))
        return data

    def transform_output(self, y_pred_proba: np.ndarray, y: Union[np.ndarray, None]) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
        if y is None and self._needs_ohe:
            y = np.argmax(y_pred_proba, axis=1)
        return y_pred_proba, y
```
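
For readers less familiar with the encoding itself, the round trip that this transformer relies on can be shown with NumPy alone. This is only an illustration of integer encoding, one-hot encoding and `argmax` decoding; it is not the `tf.data` implementation proposed above.

```python
import numpy as np

y = np.array(["cat", "dog", "cat", "bird"])

# integer-encode: classes are the sorted unique labels, y_int indexes into them
classes, y_int = np.unique(y, return_inverse=True)

# one-hot encode: pick rows of the identity matrix
y_ohe = np.eye(len(classes))[y_int]

# decode: argmax recovers the integer code, which indexes back into classes
y_back = classes[np.argmax(y_ohe, axis=1)]
```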

Then we add this to the default list of transformers for classifiers:

```python
class KerasClassifier:

    def __init__(
        self,
        dataset_transformers=(ClassifierOneHotEncoder,),
        **kwargs,  # remaining parameters elided
    ):
        ...
```

### Validate array-like data

This mirrors the current implementation of `BaseWrapper._validate_data`.

Moving that check to this interface would:
1. Only apply these checks to array-like inputs.
2. Move the implementation from a hardcoded private method to be stand-alone (making it easier to test, etc.).
3. Make usage of these checks both composable and optional.

This implementation could also be split up:
1. Transform X & y into arrays.
2. Check shapes, styles, etc. as tf.data.Dataset

```python
import numbers

from sklearn.utils import check_array


def _check_array_dtype(arr: ArrayLike, force_numeric: bool):
    if not isinstance(arr, np.ndarray):
        return _check_array_dtype(np.asarray(arr), force_numeric=force_numeric)
    elif (
        arr.dtype.kind in ("O", "U", "S") or not force_numeric
    ):
        return None  # check_array won't do any casting with dtype=None
    else:
        # default to TFs backend float type
        # instead of float64 (sklearn's default)
        return tf.keras.backend.floatx()


class ValidateFeaturesArray(ArrayTransformer):

    def transform_input(self, X: ArrayLike, y: Optional[ArrayLike], sample_weight: Optional[ArrayLike], *, initialize: bool) -> Data:
        X = check_array(
            X,
            allow_nd=True,
            ensure_2d=True,
            dtype=_check_array_dtype(X, force_numeric=True)
        )
        n_features_in_ = X.shape[1]
        if initialize:
            self.n_features_in_ = n_features_in_
        else:
            if self.n_features_in_ != n_features_in_:
                raise ValueError(
                    f"Expected X to have {self.n_features_in_} features, but got {n_features_in_} features"
                )
        return X, y, sample_weight

    def get_meta(self) -> Dict[str, Any]:
        return {"n_features_in_": self.n_features_in_}


class ValidateClassifierTargetArray(ArrayTransformer):

    def transform_input(self, X: ArrayLike, y: Optional[ArrayLike], sample_weight: Optional[ArrayLike], *, initialize: bool) -> Data:
        if y is not None:
            y = check_array(
                y,
                ensure_2d=False,
                allow_nd=False,
                dtype=_check_array_dtype(y, force_numeric=False),
            )
            classes_ = np.unique(y)
            if initialize:
                self.classes_ = classes_
            else:
                # `!=` on arrays is element-wise, so compare with array_equal
                if not np.array_equal(self.classes_, classes_):
                    raise ValueError(
                        f"Expected y to have {self.classes_} classes, but got {classes_} classes"
                    )
        return X, y, sample_weight

    def get_meta(self) -> Dict[str, Any]:
        return {"classes_": self.classes_}


class ValidateRegressorTargetArray(ArrayTransformer):

    def transform_input(self, X: ArrayLike, y: Optional[ArrayLike], sample_weight: Optional[ArrayLike], *, initialize: bool) -> Data:
        if y is not None:
            y = check_array(
                y,
                ensure_2d=False,
                allow_nd=False,
                dtype=_check_array_dtype(y, force_numeric=True),
            )
        return X, y, sample_weight


class ValidateSampleWeight(ArrayTransformer):

    def transform_input(self, X: ArrayLike, y: Optional[ArrayLike], sample_weight: Optional[ArrayLike], *, initialize: bool) -> Data:
        if isinstance(sample_weight, numbers.Number):
            sample_weight = np.full(shape=(len(X),), fill_value=sample_weight)
        if sample_weight is not None:
            sample_weight = check_array(
                sample_weight,
                accept_sparse=False,
                ensure_2d=False,
                dtype=tf.keras.backend.floatx(),
                copy=False,
            )
            if sample_weight.ndim != 1:
                raise ValueError("Sample weights must be 1D array or scalar")
            if np.all(sample_weight == 0):
                raise ValueError(
                    "No training samples had any weight; only zeros were passed in sample_weight."
                    " That means there's nothing to train on by definition, so training can not be completed."
                )
        return X, y, sample_weight
```

Now we can add these to the default array transformers:

```python
class KerasClassifier:
    def __init__(
        self,
        array_transformers=(ValidateFeaturesArray, ValidateClassifierTargetArray, ValidateSampleWeight),
    ):
        ...


class KerasRegressor:
    def __init__(
        self,
        array_transformers=(ValidateFeaturesArray, ValidateRegressorTargetArray, ValidateSampleWeight),
    ):
        ...
```
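
Because the defaults are plain sequences, a user could drop a check or append a step without subclassing. The sketch below uses empty stand-in classes with the same names as the validators above; `SplitTargets` and `DEFAULT_CLASSIFIER_TRANSFORMERS` are hypothetical names, and whether the defaults would be exposed under such a constant is an assumption.

```python
from typing import Tuple, Type


# Stand-ins for the transformer classes sketched above; the bodies are elided.
class ValidateFeaturesArray:
    pass


class ValidateClassifierTargetArray:
    pass


class ValidateSampleWeight:
    pass


class SplitTargets:
    """Hypothetical user-defined transformer for a multi-output model."""


DEFAULT_CLASSIFIER_TRANSFORMERS: Tuple[Type, ...] = (
    ValidateFeaturesArray,
    ValidateClassifierTargetArray,
    ValidateSampleWeight,
)

# Drop the target check (too strict for multi-output targets) and append a custom step.
custom_transformers = tuple(
    t for t in DEFAULT_CLASSIFIER_TRANSFORMERS if t is not ValidateClassifierTargetArray
) + (SplitTargets,)
```

The resulting tuple would then be passed as `array_transformers=custom_transformers` when constructing the estimator.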

## Issues this can potentially resolve
Collaborator

These issues seem to be around modular validation/transforming. This RFC proposes one solution. What are other solutions? Why does this RFC represent the best solution?

Owner Author

What are other solutions?

This is, of course, a great question. I think general solutions would be variations of the same idea, perhaps less structured (eg. hardcoding that tf.data.Dataset inputs should skip BaseWrapper._validate_data). I'd have to think a bit to see if I can come up with any other structured approaches that might also solve the issues.

Collaborator

Why not have two Scikit-learn transformers, one for validation and one to change the data?

Owner Author

We would need:

  1. Interface for validation using array-like data
  2. Interface for data changing using array-like data
  3. Interface for validation using Dataset
  4. Interface for data changing using Dataset

I think this is too many interfaces. The only difference between a "validation" and a "transformation" is that the transformation needs to return the data, but the validation does not. So by having the validation return the data we can collapse those two concepts into one. Validation transformers simply inspect the data but do not modify it.
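
A small sketch of that collapse, using plain functions instead of the proposed classes to keep it short. The names `validate_same_length`, `negate_target` and `run` are made up for this example.

```python
def validate_same_length(X, y, sample_weight):
    """A 'validation' step: inspects the data and returns it unchanged."""
    if y is not None and len(X) != len(y):
        raise ValueError(f"X has {len(X)} samples but y has {len(y)}")
    return X, y, sample_weight


def negate_target(X, y, sample_weight):
    """A 'transformation' step: returns modified data."""
    return X, None if y is None else [-v for v in y], sample_weight


def run(steps, X, y=None, sample_weight=None):
    # validations and transformations share one signature, so one loop handles both
    for step in steps:
        X, y, sample_weight = step(X, y, sample_weight)
    return X, y, sample_weight


X_out, y_out, _ = run([validate_same_length, negate_target], [[1], [2]], [3, 4])

try:
    run([validate_same_length], [[1], [2]], [3])
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```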

Collaborator

What use case/issue motivates having separate classes for Dataset/array-like inputs? Why not collapse it into one class?

Why is 4 interfaces "too many"? Why is the current framework with target_encoder_ and feature_encoder_ not sufficient?

I'm asking questions to get answers encoded in the RFC (an RFC is a really good idea to encode these decisions).


- #167
- #106 / #143
- #209
- #148 (by allowing users to implement it)
- #111

## Outstanding questions

Some outstanding issues:

1. Validations that require a model to be built. For example, checking the model's output shape (#106, #143).
2. Transformations involving not just the data but other parameters passed to Keras' `fit`/`predict` (#167).