Validate inputs to Keras model #106

Open
stsievert opened this issue Oct 12, 2020 · 2 comments · May be fixed by #143

stsievert (Collaborator) commented Oct 12, 2020

Currently, the number of outputs of the Keras model is checked:

    if self.model_n_outputs_ != len(self.model_.outputs):
        raise RuntimeError(

It'd be nice if the number of inputs could be checked too (along with their shape and dtype). This should be possible according to the docs, which say a Keras Model has an .inputs attribute: https://keras.io/api/models/model/
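As a rough illustration of the input-count check, here is a minimal sketch that mirrors the existing output check. The function name check_n_inputs is hypothetical, and the example uses a SimpleNamespace stand-in rather than a real Keras model so that it runs without TensorFlow; the only thing it assumes about the model is the .inputs attribute mentioned above.

```python
from types import SimpleNamespace


def check_n_inputs(model, X_list):
    """Raise if the number of arrays passed differs from len(model.inputs).

    `model` is anything with an `.inputs` sequence (a built Keras Model
    has one); `X_list` is the list of arrays the user passed.
    """
    n_expected = len(model.inputs)
    if len(X_list) != n_expected:
        raise ValueError(
            f"Model expects {n_expected} input(s) but got {len(X_list)}."
        )


# Stand-in for a built two-input Keras model (assumption: illustration only).
fake_model = SimpleNamespace(inputs=[object(), object()])

check_n_inputs(fake_model, [[1.0], [2.0]])  # two arrays for two inputs: passes
```

Passing a single array to fake_model would raise a ValueError naming both counts, which is the kind of early, readable failure the issue is asking for.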

stsievert (Collaborator, Author) commented Oct 12, 2020

This would be an internal refactor; I don't think the user would notice any changes (at least not at first). The advantage for the user would be that the wrapper around the Keras API would be lighter; a lot of the code in #88's _validate_data could be removed. In #88, _validate_data does the following:

  1. Runs Scikit-learn's check_X_y.
  2. Casts to the appropriate dtype.
  3. Checks the X/y shape.
  4. Checks the number of features in X.

I think (3) and (4) could be moved to checking the Keras input (after the user input has been run through the target/feature transformers). I think (2) can be enhanced with closer integration with the Keras model; there might be better error messages with np.can_cast(X.dtype, self.model_.layers[0].dtype).
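To make the idea concrete, here is a minimal sketch of checking one array against one model input. It is not the implementation from #88: the function name check_array_against_input and the casting parameter are hypothetical, and it reads the dtype from the input tensor rather than from layers[0]. It assumes the input exposes .shape (with None for unconstrained axes) and .dtype (either a string or an object with a .name), and uses a SimpleNamespace stand-in so it runs without TensorFlow.

```python
from types import SimpleNamespace

import numpy as np


def check_array_against_input(X, keras_input, casting="safe"):
    """Compare an array's shape and dtype with what a model input expects."""
    # None in the expected shape means "any size" (typically the batch axis).
    expected_shape = tuple(keras_input.shape)
    if X.ndim != len(expected_shape):
        raise ValueError(
            f"Expected {len(expected_shape)} dimensions, got {X.ndim}."
        )
    for axis, (got, want) in enumerate(zip(X.shape, expected_shape)):
        if want is not None and got != want:
            raise ValueError(
                f"Axis {axis}: expected size {want}, got {got}."
            )

    # Accept either a plain string dtype or an object exposing `.name`.
    dtype = keras_input.dtype
    expected_dtype = np.dtype(getattr(dtype, "name", dtype))
    if not np.can_cast(X.dtype, expected_dtype, casting=casting):
        raise TypeError(
            f"Cannot cast X from {X.dtype} to {expected_dtype} "
            f"under the '{casting}' casting rule."
        )


# Stand-in for model.inputs[0] (assumption: illustration only).
fake_input = SimpleNamespace(shape=(None, 4), dtype="float32")

check_array_against_input(np.zeros((3, 4), dtype="float32"), fake_input)
```

One design point worth noting: np.can_cast defaults to the "safe" rule, which rejects float64 to float32. Since NumPy arrays are float64 by default, a real check would probably want casting="same_kind" or an explicit cast before the comparison.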

stsievert changed the title from "Check number of inputs to Keras model" to "Validate inputs to Keras model" on Oct 14, 2020
adriangb (Owner) commented:

I am just now seeing this last comment. This sounds interesting, I am +1 on anything that reduces code complexity.

That said, I think this would be a fundamental departure from how Scikit-Learn does these validations? _validate_data is meant to validate that data passed after the first fit/initialization matches what the estimator knows. I'm guessing sklearn does it like this mainly to avoid cryptic failures within estimators, but like you say we may be able to check the "source of truth" (i.e. the Keras model) to achieve the same effect. Got to think about it a bit or see it implemented.
