fix: don't convert sparse matrix formats #282

Open · wants to merge 3 commits into master
32 changes: 27 additions & 5 deletions scikeras/wrappers.py
@@ -5,12 +5,24 @@
 import warnings

 from collections import defaultdict
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Set, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Set,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)

 import numpy as np
 import tensorflow as tf

-from scipy.sparse import isspmatrix, lil_matrix
+from scipy.sparse import csr_matrix, isspmatrix, lil_matrix, spmatrix
 from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
 from sklearn.exceptions import NotFittedError
 from sklearn.metrics import accuracy_score as sklearn_accuracy_score
@@ -651,9 +663,19 @@ def _check_array_dtype(arr, force_numeric):
     )
 if X is not None:
     if isspmatrix(X):
-        # TensorFlow does not support several of SciPy's sparse formats
-        # use SciPy to reformat here so at least the cost is known
-        X = lil_matrix(X)  # no-copy reformat
+        # TensorFlow requires sparse matrices to be sorted in row-major order
+        # see https://www.tensorflow.org/api_docs/python/tf/sparse/SparseTensor
+        # It supports conversion of "lil", "dok" and "bsr" (empirically checked)
+        Xs = cast(spmatrix, X)
+        if Xs.getformat() == "csr":
+            Xs_csr = cast(csr_matrix, Xs)
+            Xs_csr.sort_indices()
+        elif Xs.getformat() not in ("dok", "lil", "bsr"):
+            raise ValueError(
+                "TensorFlow does not support the sparse matrix format"
Collaborator:

What error does TF/Keras raise if a matrix of this format is passed?

Owner Author (@adriangb), Jul 23, 2022:

tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[2] = [37,0] is out of order. Many sparse ops require sorted indices.
E         Use `tf.sparse.reorder` to create a correctly ordered copy.
Traceback
    @traceback_utils.filter_traceback
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose='auto',
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
      """Trains the model for a fixed number of epochs (iterations on a dataset).
    
      Args:
          x: Input data. It could be:
            - A Numpy array (or array-like), or a list of arrays
              (in case the model has multiple inputs).
            - A TensorFlow tensor, or a list of tensors
              (in case the model has multiple inputs).
            - A dict mapping input names to the corresponding array/tensors,
              if the model has named inputs.
            - A `tf.data` dataset. Should return a tuple
              of either `(inputs, targets)` or
              `(inputs, targets, sample_weights)`.
            - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
              or `(inputs, targets, sample_weights)`.
            - A `tf.keras.utils.experimental.DatasetCreator`, which wraps a
              callable that takes a single argument of type
              `tf.distribute.InputContext`, and returns a `tf.data.Dataset`.
              `DatasetCreator` should be used when users prefer to specify the
              per-replica batching and sharding logic for the `Dataset`.
              See `tf.keras.utils.experimental.DatasetCreator` doc for more
              information.
            A more detailed description of unpacking behavior for iterator types
            (Dataset, generator, Sequence) is given below. If using
            `tf.distribute.experimental.ParameterServerStrategy`, only
            `DatasetCreator` type is supported for `x`.
          y: Target data. Like the input data `x`,
            it could be either Numpy array(s) or TensorFlow tensor(s).
            It should be consistent with `x` (you cannot have Numpy inputs and
            tensor targets, or inversely). If `x` is a dataset, generator,
            or `keras.utils.Sequence` instance, `y` should
            not be specified (since targets will be obtained from `x`).
          batch_size: Integer or `None`.
              Number of samples per gradient update.
              If unspecified, `batch_size` will default to 32.
              Do not specify the `batch_size` if your data is in the
              form of datasets, generators, or `keras.utils.Sequence` instances
              (since they generate batches).
          epochs: Integer. Number of epochs to train the model.
              An epoch is an iteration over the entire `x` and `y`
              data provided
              (unless the `steps_per_epoch` flag is set to
              something other than None).
              Note that in conjunction with `initial_epoch`,
              `epochs` is to be understood as "final epoch".
              The model is not trained for a number of iterations
              given by `epochs`, but merely until the epoch
              of index `epochs` is reached.
          verbose: 'auto', 0, 1, or 2. Verbosity mode.
              0 = silent, 1 = progress bar, 2 = one line per epoch.
              'auto' defaults to 1 for most cases, but 2 when used with
              `ParameterServerStrategy`. Note that the progress bar is not
              particularly useful when logged to a file, so verbose=2 is
              recommended when not running interactively (eg, in a production
              environment).
          callbacks: List of `keras.callbacks.Callback` instances.
              List of callbacks to apply during training.
              See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
              and `tf.keras.callbacks.History` callbacks are created automatically
              and need not be passed into `model.fit`.
              `tf.keras.callbacks.ProgbarLogger` is created or not based on
              `verbose` argument to `model.fit`.
              Callbacks with batch-level calls are currently unsupported with
              `tf.distribute.experimental.ParameterServerStrategy`, and users are
              advised to implement epoch-level calls instead with an appropriate
              `steps_per_epoch` value.
          validation_split: Float between 0 and 1.
              Fraction of the training data to be used as validation data.
              The model will set apart this fraction of the training data,
              will not train on it, and will evaluate
              the loss and any model metrics
              on this data at the end of each epoch.
              The validation data is selected from the last samples
              in the `x` and `y` data provided, before shuffling. This argument is
              not supported when `x` is a dataset, generator or
              `keras.utils.Sequence` instance.
              If both `validation_data` and `validation_split` are provided,
              `validation_data` will override `validation_split`.
              `validation_split` is not yet supported with
              `tf.distribute.experimental.ParameterServerStrategy`.
          validation_data: Data on which to evaluate
              the loss and any model metrics at the end of each epoch.
              The model will not be trained on this data. Thus, note the fact
              that the validation loss of data provided using `validation_split`
              or `validation_data` is not affected by regularization layers like
              noise and dropout.
              `validation_data` will override `validation_split`.
              `validation_data` could be:
                - A tuple `(x_val, y_val)` of Numpy arrays or tensors.
                - A tuple `(x_val, y_val, val_sample_weights)` of NumPy arrays.
                - A `tf.data.Dataset`.
                - A Python generator or `keras.utils.Sequence` returning
                `(inputs, targets)` or `(inputs, targets, sample_weights)`.
              `validation_data` is not yet supported with
              `tf.distribute.experimental.ParameterServerStrategy`.
          shuffle: Boolean (whether to shuffle the training data
              before each epoch) or str (for 'batch'). This argument is ignored
              when `x` is a generator or an object of tf.data.Dataset.
              'batch' is a special option for dealing
              with the limitations of HDF5 data; it shuffles in batch-sized
              chunks. Has no effect when `steps_per_epoch` is not `None`.
          class_weight: Optional dictionary mapping class indices (integers)
              to a weight (float) value, used for weighting the loss function
              (during training only).
              This can be useful to tell the model to
              "pay more attention" to samples from
              an under-represented class.
          sample_weight: Optional Numpy array of weights for
              the training samples, used for weighting the loss function
              (during training only). You can either pass a flat (1D)
              Numpy array with the same length as the input samples
              (1:1 mapping between weights and samples),
              or in the case of temporal data,
              you can pass a 2D array with shape
              `(samples, sequence_length)`,
              to apply a different weight to every timestep of every sample. This
              argument is not supported when `x` is a dataset, generator, or
             `keras.utils.Sequence` instance, instead provide the sample_weights
              as the third element of `x`.
          initial_epoch: Integer.
              Epoch at which to start training
              (useful for resuming a previous training run).
          steps_per_epoch: Integer or `None`.
              Total number of steps (batches of samples)
              before declaring one epoch finished and starting the
              next epoch. When training with input tensors such as
              TensorFlow data tensors, the default `None` is equal to
              the number of samples in your dataset divided by
              the batch size, or 1 if that cannot be determined. If x is a
              `tf.data` dataset, and 'steps_per_epoch'
              is None, the epoch will run until the input dataset is exhausted.
              When passing an infinitely repeating dataset, you must specify the
              `steps_per_epoch` argument. If `steps_per_epoch=-1` the training
              will run indefinitely with an infinitely repeating dataset.
              This argument is not supported with array inputs.
              When using `tf.distribute.experimental.ParameterServerStrategy`:
                * `steps_per_epoch=None` is not supported.
          validation_steps: Only relevant if `validation_data` is provided and
              is a `tf.data` dataset. Total number of steps (batches of
              samples) to draw before stopping when performing validation
              at the end of every epoch. If 'validation_steps' is None, validation
              will run until the `validation_data` dataset is exhausted. In the
              case of an infinitely repeated dataset, it will run into an
              infinite loop. If 'validation_steps' is specified and only part of
              the dataset will be consumed, the evaluation will start from the
              beginning of the dataset at each epoch. This ensures that the same
              validation samples are used every time.
          validation_batch_size: Integer or `None`.
              Number of samples per validation batch.
              If unspecified, will default to `batch_size`.
              Do not specify the `validation_batch_size` if your data is in the
              form of datasets, generators, or `keras.utils.Sequence` instances
              (since they generate batches).
          validation_freq: Only relevant if validation data is provided. Integer
              or `collections.abc.Container` instance (e.g. list, tuple, etc.).
              If an integer, specifies how many training epochs to run before a
              new validation run is performed, e.g. `validation_freq=2` runs
              validation every 2 epochs. If a Container, specifies the epochs on
              which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
              validation at the end of the 1st, 2nd, and 10th epochs.
          max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
              input only. Maximum size for the generator queue.
              If unspecified, `max_queue_size` will default to 10.
          workers: Integer. Used for generator or `keras.utils.Sequence` input
              only. Maximum number of processes to spin up
              when using process-based threading. If unspecified, `workers`
              will default to 1.
          use_multiprocessing: Boolean. Used for generator or
              `keras.utils.Sequence` input only. If `True`, use process-based
              threading. If unspecified, `use_multiprocessing` will default to
              `False`. Note that because this implementation relies on
              multiprocessing, you should not pass non-picklable arguments to
              the generator as they can't be passed easily to children processes.
    
      Unpacking behavior for iterator-like inputs:
          A common pattern is to pass a tf.data.Dataset, generator, or
        tf.keras.utils.Sequence to the `x` argument of fit, which will in fact
        yield not only features (x) but optionally targets (y) and sample weights.
        Keras requires that the output of such iterator-likes be unambiguous. The
        iterator should return a tuple of length 1, 2, or 3, where the optional
        second and third elements will be used for y and sample_weight
        respectively. Any other type provided will be wrapped in a length one
        tuple, effectively treating everything as 'x'. When yielding dicts, they
        should still adhere to the top-level tuple structure.
        e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
        features, targets, and weights from the keys of a single dict.
          A notable unsupported data type is the namedtuple. The reason is that
        it behaves like both an ordered datatype (tuple) and a mapping
        datatype (dict). So given a namedtuple of the form:
            `namedtuple("example_tuple", ["y", "x"])`
        it is ambiguous whether to reverse the order of the elements when
        interpreting the value. Even worse is a tuple of the form:
            `namedtuple("other_tuple", ["x", "y", "z"])`
        where it is unclear if the tuple was intended to be unpacked into x, y,
        and sample_weight or passed through as a single element to `x`. As a
        result the data processing code will simply raise a ValueError if it
        encounters a namedtuple. (Along with instructions to remedy the issue.)
    
      Returns:
          A `History` object. Its `History.history` attribute is
          a record of training loss values and metrics values
          at successive epochs, as well as validation loss values
          and validation metrics values (if applicable).
    
      Raises:
          RuntimeError: 1. If the model was never compiled or,
          2. If `model.fit` is  wrapped in `tf.function`.
    
          ValueError: In case of mismatch between the provided input data
              and what the model expects or when the input data is empty.
      """
      base_layer.keras_api_gauge.get_cell('fit').set(True)
      # Legacy graph support is contained in `training_v1.Model`.
      version_utils.disallow_legacy_graph('Model', 'fit')
      self._assert_compile_was_called()
      self._check_call_args('fit')
      _disallow_inside_tf_function('fit')
    
      verbose = _get_verbosity(verbose, self.distribute_strategy)
    
      if validation_split and validation_data is None:
        # Create the validation data using the training data. Only supported for
        # `Tensor` and `NumPy` input.
        (x, y, sample_weight), validation_data = (
            data_adapter.train_validation_split(
                (x, y, sample_weight), validation_split=validation_split))
    
      if validation_data:
        val_x, val_y, val_sample_weight = (
            data_adapter.unpack_x_y_sample_weight(validation_data))
    
      if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
        self._cluster_coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
            self.distribute_strategy)
    
      with self.distribute_strategy.scope(), \
           training_utils.RespectCompiledTrainableState(self):
        # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
>       data_handler = data_adapter.get_data_handler(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            class_weight=class_weight,
            max_queue_size=max_queue_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            model=self,
            steps_per_execution=self._steps_per_execution)

.venv/lib/python3.10/site-packages/keras/engine/training.py:1358: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = ()
kwargs = {'batch_size': 1000, 'class_weight': None, 'epochs': 1, 'initial_epoch': 0, ...}

    def get_data_handler(*args, **kwargs):
      if getattr(kwargs["model"], "_cluster_coordinator", None):
        return _ClusterCoordinatorDataHandler(*args, **kwargs)
>     return DataHandler(*args, **kwargs)

.venv/lib/python3.10/site-packages/keras/engine/data_adapter.py:1401: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <keras.engine.data_adapter.DataHandler object at 0x7f3cb866c700>
x = <40x3 sparse matrix of type '<class 'numpy.float32'>'
	with 53 stored elements (18 diagonals) in DIAgonal format>
y = array([[2.],
       [2.],
       [3.],
       [2.],
       [1.],
       [2.],
       [0.],
       [1.],
       [2.],
 ...       [0.],
       [2.],
       [0.],
       [3.],
       [3.],
       [2.],
       [1.],
       [0.]], dtype=float32)
sample_weight = None, batch_size = 1000, steps_per_epoch = None
initial_epoch = 0, epochs = 1, shuffle = True, class_weight = None
max_queue_size = 10, workers = 1, use_multiprocessing = False
model = <keras.engine.functional.Functional object at 0x7f3cb8659cf0>
steps_per_execution = <tf.Variable 'Variable:0' shape=() dtype=int64, numpy=1>
distribute = True

    def __init__(self,
                 x,
                 y=None,
                 sample_weight=None,
                 batch_size=None,
                 steps_per_epoch=None,
                 initial_epoch=0,
                 epochs=1,
                 shuffle=False,
                 class_weight=None,
                 max_queue_size=10,
                 workers=1,
                 use_multiprocessing=False,
                 model=None,
                 steps_per_execution=None,
                 distribute=True):
      """Initializes a `DataHandler`.
    
      Arguments:
        x: See `Model.fit`.
        y: See `Model.fit`.
        sample_weight: See `Model.fit`.
        batch_size: See `Model.fit`.
        steps_per_epoch: See `Model.fit`.
        initial_epoch: See `Model.fit`.
        epochs: See `Model.fit`.
        shuffle: See `Model.fit`.
        class_weight: See `Model.fit`.
        max_queue_size: See `Model.fit`.
        workers: See `Model.fit`.
        use_multiprocessing: See `Model.fit`.
        model: The `Model` instance. Needed in order to correctly `build` the
          `Model` using generator-like inputs (see `GeneratorDataAdapter`).
        steps_per_execution: See `Model.compile`.
        distribute: Whether to distribute the `tf.dataset`.
          `PreprocessingLayer.adapt` does not support distributed datasets,
          `Model` should always set this to `True`.
      """
    
      self._initial_epoch = initial_epoch
      self._initial_step = 0
      self._epochs = epochs
      self._insufficient_data = False
      self._model = model
    
      # `steps_per_execution_value` is the cached initial value.
      # `steps_per_execution` is mutable and may be changed by the DataAdapter
      # to handle partial executions.
      if steps_per_execution is None:
        self._steps_per_execution = tf.Variable(1)
      else:
        self._steps_per_execution = steps_per_execution
    
      adapter_cls = select_data_adapter(x, y)
>     self._adapter = adapter_cls(
          x,
          y,
          batch_size=batch_size,
          steps=steps_per_epoch,
          epochs=epochs - initial_epoch,
          sample_weights=sample_weight,
          shuffle=shuffle,
          max_queue_size=max_queue_size,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          distribution_strategy=tf.distribute.get_strategy(),
          model=model)

.venv/lib/python3.10/site-packages/keras/engine/data_adapter.py:1151: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <keras.engine.data_adapter.CompositeTensorDataAdapter object at 0x7f3cb8659b10>
x = <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f3cd4112770>
y = <tf.Tensor: shape=(40, 1), dtype=float32, numpy=
array([[2.],
       [2.],
       [3.],
       [2.],
       [1.],
    ...      [0.],
       [2.],
       [0.],
       [3.],
       [3.],
       [2.],
       [1.],
       [0.]], dtype=float32)>
sample_weights = None, sample_weight_modes = None, batch_size = 1000
steps = None, shuffle = True
kwargs = {'distribution_strategy': <tensorflow.python.distribute.distribute_lib._DefaultDistributionStrategy object at 0x7f3cd8441390>, 'epochs': 1, 'max_queue_size': 10, 'model': <keras.engine.functional.Functional object at 0x7f3cb8659cf0>, ...}
_ = False
inputs = (<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f3cd4112770>, <tf.Tensor: shape=(40, 1), dtype=f...     [0.],
       [2.],
       [0.],
       [3.],
       [3.],
       [2.],
       [1.],
       [0.]], dtype=float32)>)

    def __init__(self,
                 x,
                 y=None,
                 sample_weights=None,
                 sample_weight_modes=None,
                 batch_size=None,
                 steps=None,
                 shuffle=False,
                 **kwargs):
      super(CompositeTensorDataAdapter, self).__init__(x, y, **kwargs)
      x, y, sample_weights = _process_tensorlike((x, y, sample_weights))
      sample_weight_modes = broadcast_sample_weight_modes(
          sample_weights, sample_weight_modes)
    
      # If sample_weights are not specified for an output use 1.0 as weights.
      (sample_weights, _, _) = training_utils.handle_partial_sample_weights(
          y, sample_weights, sample_weight_modes, check_all_flat=True)
    
      inputs = pack_x_y_sample_weight(x, y, sample_weights)
    
>     dataset = tf.data.Dataset.from_tensor_slices(inputs)

.venv/lib/python3.10/site-packages/keras/engine/data_adapter.py:587: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensors = (<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f3cd4112770>, <tf.Tensor: shape=(40, 1), dtype=f...     [0.],
       [2.],
       [0.],
       [3.],
       [3.],
       [2.],
       [1.],
       [0.]], dtype=float32)>)
name = None

    @staticmethod
    def from_tensor_slices(tensors, name=None):
      """Creates a `Dataset` whose elements are slices of the given tensors.
    
      The given tensors are sliced along their first dimension. This operation
      preserves the structure of the input tensors, removing the first dimension
      of each tensor and using it as the dataset dimension. All input tensors
      must have the same size in their first dimensions.
    
      >>> # Slicing a 1D tensor produces scalar tensor elements.
      >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
      >>> list(dataset.as_numpy_iterator())
      [1, 2, 3]
    
      >>> # Slicing a 2D tensor produces 1D tensor elements.
      >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
      >>> list(dataset.as_numpy_iterator())
      [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
    
      >>> # Slicing a tuple of 1D tensors produces tuple elements containing
      >>> # scalar tensors.
      >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
      >>> list(dataset.as_numpy_iterator())
      [(1, 3, 5), (2, 4, 6)]
    
      >>> # Dictionary structure is also preserved.
      >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
      >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
      ...                                       {'a': 2, 'b': 4}]
      True
    
      >>> # Two tensors can be combined into one Dataset object.
      >>> features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor
      >>> labels = tf.constant(['A', 'B', 'A']) # ==> 3x1 tensor
      >>> dataset = Dataset.from_tensor_slices((features, labels))
      >>> # Both the features and the labels tensors can be converted
      >>> # to a Dataset object separately and combined after.
      >>> features_dataset = Dataset.from_tensor_slices(features)
      >>> labels_dataset = Dataset.from_tensor_slices(labels)
      >>> dataset = Dataset.zip((features_dataset, labels_dataset))
      >>> # A batched feature and label set can be converted to a Dataset
      >>> # in similar fashion.
      >>> batched_features = tf.constant([[[1, 3], [2, 3]],
      ...                                 [[2, 1], [1, 2]],
      ...                                 [[3, 3], [3, 2]]], shape=(3, 2, 2))
      >>> batched_labels = tf.constant([['A', 'A'],
      ...                               ['B', 'B'],
      ...                               ['A', 'B']], shape=(3, 2, 1))
      >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels))
      >>> for element in dataset.as_numpy_iterator():
      ...   print(element)
      (array([[1, 3],
             [2, 3]], dtype=int32), array([[b'A'],
             [b'A']], dtype=object))
      (array([[2, 1],
             [1, 2]], dtype=int32), array([[b'B'],
             [b'B']], dtype=object))
      (array([[3, 3],
             [3, 2]], dtype=int32), array([[b'A'],
             [b'B']], dtype=object))
    
      Note that if `tensors` contains a NumPy array, and eager execution is not
      enabled, the values will be embedded in the graph as one or more
      `tf.constant` operations. For large datasets (> 1 GB), this can waste
      memory and run into byte limits of graph serialization. If `tensors`
      contains one or more large NumPy arrays, consider the alternative described
      in [this guide](
      https://tensorflow.org/guide/data#consuming_numpy_arrays).
    
      Args:
        tensors: A dataset element, whose components have the same first
          dimension. Supported values are documented
          [here](https://www.tensorflow.org/guide/data#dataset_structure).
        name: (Optional.) A name for the tf.data operation.
    
      Returns:
        Dataset: A `Dataset`.
      """
>     return TensorSliceDataset(tensors, name=name)

.venv/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py:809: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <[AttributeError("'TensorSliceDataset' object has no attribute '_structure'") raised in repr()] TensorSliceDataset object at 0x7f3cb865ad70>
element = (<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f3cb866dbd0>, <tf.Tensor: shape=(40, 1), dtype=f...     [0.],
       [2.],
       [0.],
       [3.],
       [3.],
       [2.],
       [1.],
       [0.]], dtype=float32)>)
is_files = False, name = None

    def __init__(self, element, is_files=False, name=None):
      """See `Dataset.from_tensor_slices()` for details."""
      element = structure.normalize_element(element)
      batched_spec = structure.type_spec_from_value(element)
>     self._tensors = structure.to_batched_tensor_list(batched_spec, element)

.venv/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py:4553: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

element_spec = (SparseTensorSpec(TensorShape([40, 3]), tf.float32), TensorSpec(shape=(40, 1), dtype=tf.float32, name=None))
element = (<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f3cb866dbd0>, <tf.Tensor: shape=(40, 1), dtype=f...     [0.],
       [2.],
       [0.],
       [3.],
       [3.],
       [2.],
       [1.],
       [0.]], dtype=float32)>)

    def to_batched_tensor_list(element_spec, element):
      """Returns a tensor list representation of the element.
    
      Args:
        element_spec: A nested structure of `tf.TypeSpec` objects representing to
          element type specification.
        element: The element to convert to tensor list representation.
    
      Returns:
        A tensor list representation of `element`.
    
      Raises:
        ValueError: If `element_spec` and `element` do not have the same number of
          elements or if the two structures are not nested in the same way or the
          rank of any of the tensors in the tensor list representation is 0.
        TypeError: If `element_spec` and `element` differ in the type of sequence
          in any of their substructures.
      """
    
      # pylint: disable=protected-access
      # pylint: disable=g-long-lambda
>     return _to_tensor_list_helper(
          lambda state, spec, component: state + spec._to_batched_tensor_list(
              component), element_spec, element)

.venv/lib/python3.10/site-packages/tensorflow/python/data/util/structure.py:363: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

encode_fn = <function to_batched_tensor_list.<locals>.<lambda> at 0x7f3cd4177880>
element_spec = (SparseTensorSpec(TensorShape([40, 3]), tf.float32), TensorSpec(shape=(40, 1), dtype=tf.float32, name=None))
element = (<tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f3cb866dbd0>, <tf.Tensor: shape=(40, 1), dtype=f...     [0.],
       [2.],
       [0.],
       [3.],
       [3.],
       [2.],
       [1.],
       [0.]], dtype=float32)>)

    def _to_tensor_list_helper(encode_fn, element_spec, element):
      """Returns a tensor list representation of the element.
    
      Args:
        encode_fn: Method that constructs a tensor list representation from the
          given element spec and element.
        element_spec: A nested structure of `tf.TypeSpec` objects representing to
          element type specification.
        element: The element to convert to tensor list representation.
    
      Returns:
        A tensor list representation of `element`.
    
      Raises:
        ValueError: If `element_spec` and `element` do not have the same number of
          elements or if the two structures are not nested in the same way.
        TypeError: If `element_spec` and `element` differ in the type of sequence
          in any of their substructures.
      """
    
      nest.assert_same_structure(element_spec, element)
    
      def reduce_fn(state, value):
        spec, component = value
        return encode_fn(state, spec, component)
    
>     return functools.reduce(
          reduce_fn, zip(nest.flatten(element_spec), nest.flatten(element)), [])

.venv/lib/python3.10/site-packages/tensorflow/python/data/util/structure.py:338: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

state = []
value = (SparseTensorSpec(TensorShape([40, 3]), tf.float32), <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f3cb866dbd0>)

    def reduce_fn(state, value):
      spec, component = value
>     return encode_fn(state, spec, component)

.venv/lib/python3.10/site-packages/tensorflow/python/data/util/structure.py:336: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

state = [], spec = SparseTensorSpec(TensorShape([40, 3]), tf.float32)
component = <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f3cb866dbd0>

>   lambda state, spec, component: state + spec._to_batched_tensor_list(
        component), element_spec, element)

.venv/lib/python3.10/site-packages/tensorflow/python/data/util/structure.py:364: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = SparseTensorSpec(TensorShape([40, 3]), tf.float32)
value = <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f3cb866dbd0>

    def _to_batched_tensor_list(self, value):
      dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
      if self._shape.merge_with(dense_shape).ndims == 0:
        raise ValueError(
            "Unbatching a sparse tensor is only supported for rank >= 1. "
            f"Obtained input: {value}.")
>     return [gen_sparse_ops.serialize_many_sparse(
          value.indices, value.values, value.dense_shape,
          out_type=dtypes.variant)]

.venv/lib/python3.10/site-packages/tensorflow/python/framework/sparse_tensor.py:368: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

sparse_indices = <tf.Tensor: shape=(21, 2), dtype=int64, numpy=
array([[38,  0],
       [39,  1],
       [37,  0],
       [38,  2],
   ...      [ 9,  0],
       [ 6,  1],
       [ 6,  2],
       [ 4,  1],
       [ 5,  2],
       [ 2,  1],
       [ 2,  2]])>
sparse_values = <tf.Tensor: shape=(21,), dtype=float32, numpy=
array([0.81379783, 0.8817354 , 0.84640867, 0.8811032 , 0.952749  ,
    ...806, 0.9446689 ,
       0.87001216, 0.9786183 , 0.92559665, 0.83261985, 0.891773  ,
       0.96366274], dtype=float32)>
sparse_shape = <tf.Tensor: shape=(2,), dtype=int64, numpy=array([40,  3])>
out_type = tf.variant, name = None

    def serialize_many_sparse(sparse_indices, sparse_values, sparse_shape, out_type=_dtypes.string, name=None):
      r"""Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object.
    
      The `SparseTensor` must have rank `R` greater than 1, and the first dimension
      is treated as the minibatch dimension.  Elements of the `SparseTensor`
      must be sorted in increasing order of this first dimension.  The serialized
      `SparseTensor` objects going into each row of `serialized_sparse` will have
      rank `R-1`.
    
      The minibatch size `N` is extracted from `sparse_shape[0]`.
    
      Args:
        sparse_indices: A `Tensor` of type `int64`.
          2-D.  The `indices` of the minibatch `SparseTensor`.
        sparse_values: A `Tensor`.
          1-D.  The `values` of the minibatch `SparseTensor`.
        sparse_shape: A `Tensor` of type `int64`.
          1-D.  The `shape` of the minibatch `SparseTensor`.
        out_type: An optional `tf.DType` from: `tf.string, tf.variant`. Defaults to `tf.string`.
          The `dtype` to use for serialization; the supported types are `string`
          (default) and `variant`.
        name: A name for the operation (optional).
    
      Returns:
        A `Tensor` of type `out_type`.
      """
      _ctx = _context._context or _context.context()
      tld = _ctx._thread_local_data
      if tld.is_eager:
        try:
          _result = pywrap_tfe.TFE_Py_FastPathExecute(
            _ctx, "SerializeManySparse", name, sparse_indices, sparse_values,
            sparse_shape, "out_type", out_type)
          return _result
        except _core._NotOkStatusException as e:
>         _ops.raise_from_not_ok_status(e, name)

.venv/lib/python3.10/site-packages/tensorflow/python/ops/gen_sparse_ops.py:496: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

e = _NotOkStatusException(), name = None

    def raise_from_not_ok_status(e, name):
      e.message += (" name: " + name if name is not None else "")
>     raise core._status_to_exception(e) from None  # pylint: disable=protected-access
E     tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[2] = [37,0] is out of order. Many sparse ops require sorted indices.
E         Use `tf.sparse.reorder` to create a correctly ordered copy.
E     
E      [Op:SerializeManySparse]

.venv/lib/python3.10/site-packages/tensorflow/python/framework/ops.py:7164: InvalidArgumentError
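For context, a minimal standalone sketch (not from the PR; names and values are illustrative) that reproduces this failure mode with a DIA matrix and applies the reordering fix TensorFlow suggests:

# Minimal sketch (not from the PR): reproduce the out-of-order indices error
# and apply the tf.sparse.reorder fix suggested in the message above.
import numpy as np
import tensorflow as tf
from scipy.sparse import random as sparse_random

# A DIA-format matrix; converting it to COO coordinates typically yields
# indices that are not in row-major order.
X_dia = sparse_random(40, 3, density=0.2, format="dia", dtype=np.float32)
coo = X_dia.tocoo()
st = tf.sparse.SparseTensor(
    indices=np.stack([coo.row, coo.col], axis=1).astype(np.int64),
    values=coo.data,
    dense_shape=coo.shape,
)

try:
    # Slicing a SparseTensor serializes it row by row and requires sorted indices.
    tf.data.Dataset.from_tensor_slices(st)
except tf.errors.InvalidArgumentError as e:
    print(e)  # "... is out of order. Many sparse ops require sorted indices."

# tf.sparse.reorder returns a copy with indices in canonical row-major order.
tf.data.Dataset.from_tensor_slices(tf.sparse.reorder(st))  # succeeds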

Owner Author:

If you can think of a better way to handle this compatibility, I am all ears

Collaborator:

Why not leave conversion to the correct format (spmatrix with sorted indices) to the user, and raise a ValueError if not correctly configured?

Why would the user want you to handle the conversion? Can you check if the index is already sorted to avoid some overhead? (or do you do that already? how much overhead is added? is "y" sorted too?).

(I haven't reviewed your code closely; 'scuse my ignorant questions)

Owner Author:

sort_indices() already checks if it is already sorted: https://github.com/scipy/scipy/blob/4cf21e753cf937d1c6c2d2a0e372fbc1dbbeea81/scipy/sparse/_compressed.py#L1163

I guess we could just not check here and let TensorFlow fail? I don't think Scikit-Learn even checks for this specific case (it just checks for formats)
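For illustration (not part of the PR, assuming a recent SciPy), the no-op behaviour referred to above:

# Sketch: sort_indices() consults the cached has_sorted_indices flag and
# skips the work when the indices are already in order.
import numpy as np
from scipy.sparse import csr_matrix

# CSR built with deliberately unsorted column indices in row 0.
data = np.array([1.0, 2.0], dtype=np.float32)
indices = np.array([1, 0])          # out of order within the row
indptr = np.array([0, 2, 2, 2])
X = csr_matrix((data, indices, indptr), shape=(3, 3))

print(X.has_sorted_indices)  # False
X.sort_indices()             # sorts in place this time
print(X.has_sorted_indices)  # True
X.sort_indices()             # no-op: the flag short-circuits the sort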

Collaborator:

I see two options:

  1. Raise errors before any fitting is done. This will require tests to ensure that SciKeras is staying current with TF's implementation (and that DOK/LIL/BSR raise errors in TF and SciKeras).
  2. Let TF raise errors with sparse matrices. Then maybe look at the error message and throw a warning/another error, and test to make sure that some sparse matrices work.
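A hedged sketch of what the tests mentioned under option (1) could look like (the build_regressor fixture is hypothetical, not from the SciKeras test suite):

# Hypothetical pytest sketch for option (1): pin down which SciPy formats are
# accepted and which raise ValueError, so a change in TF support surfaces in CI.
import numpy as np
import pytest
from scipy.sparse import random as sparse_random


@pytest.mark.parametrize("fmt", ["csr", "dok", "lil", "bsr"])
def test_supported_sparse_formats_fit(fmt, build_regressor):
    X = sparse_random(40, 3, density=0.2, format=fmt, dtype=np.float32)
    y = np.zeros(40, dtype=np.float32)
    build_regressor().fit(X, y)  # should not raise


@pytest.mark.parametrize("fmt", ["coo", "dia"])
def test_unsupported_sparse_formats_raise(fmt, build_regressor):
    X = sparse_random(40, 3, density=0.2, format=fmt, dtype=np.float32)
    y = np.zeros(40, dtype=np.float32)
    with pytest.raises(ValueError, match="sparse matrix format"):
        build_regressor().fit(X, y)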

Owner Author:

Right, and currently (1) is what would be happening.

There are two sorts of changes that I could see happening with (1):

  1. TensorFlow supports a new sparse input format. We won't be able to catch this with tests, but presumably if a user wants to use this they can ask us for the feature and it would be pretty straightforward to implement. I also don't think this is super likely to happen.
  2. TensorFlow drops support for a format. This we would catch with tests. Very unlikely to happen given TensorFlow's backward compatibility guarantees.

The main issues with option (2) are:

  1. The Scikit-Learn API explicitly calls for a ValueError (so technically we would be violating the Scikit-Learn API).
  2. Introspecting into error messages is a bit problematic, they can change at any time, etc.
  3. The errors that get raised are not very user friendly.

f" {Xs.getformat()}"
)
X = Xs

X = check_array(
X,
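Taken together, a rough usage sketch of the behaviour after this change (model construction below is illustrative only; the sparse-format handling is the point):

# Rough sketch of the post-PR behaviour: CSR (indices sorted in place) is
# accepted, while an unsupported format such as DIA fails fast with ValueError
# instead of TF's InvalidArgumentError deep inside tf.data.
import numpy as np
from scipy.sparse import random as sparse_random
from tensorflow import keras
from scikeras.wrappers import KerasRegressor


def get_model(meta):
    # meta["n_features_in_"] is filled in by SciKeras before building the model
    model = keras.Sequential(
        [
            keras.layers.Input(shape=(meta["n_features_in_"],), sparse=True),
            keras.layers.Dense(1),
        ]
    )
    model.compile(loss="mse")
    return model


est = KerasRegressor(model=get_model, epochs=1, verbose=0)
y = np.zeros(40, dtype=np.float32)

X_csr = sparse_random(40, 3, density=0.2, format="csr", dtype=np.float32)
est.fit(X_csr, y)  # accepted; indices are sorted in place if needed

X_dia = sparse_random(40, 3, density=0.2, format="dia", dtype=np.float32)
try:
    est.fit(X_dia, y)
except ValueError as err:
    print(err)  # "TensorFlow does not support the sparse matrix format dia"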