Skip to content

Commit 25ecb0d

Browse files
Use cuvs for knn when available (#103)
1 parent e009be2 commit 25ecb0d

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

crossfit/backend/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
# flake8: noqa
16+
import logging
1617

1718
from crossfit.backend.dask.dataframe import *
1819
from crossfit.backend.numpy.sparse import *
@@ -23,17 +24,20 @@
2324
from crossfit.backend.cudf.array import *
2425
from crossfit.backend.cudf.dataframe import *
2526
except ImportError:
27+
logging.warning("Import Error for cudf backend in Crossfit. Skipping it.")
2628
pass
2729

2830
try:
2931
from crossfit.backend.cupy.array import *
3032
from crossfit.backend.cupy.sparse import *
3133
except ImportError:
34+
logging.warning("Import Error for cupy backend in Crossfit. Skipping it.")
3235
pass
3336

3437
try:
3538
from crossfit.backend.torch.array import *
3639
except ImportError:
40+
logging.warning("Import Error for Torch backend in Crossfit. Skipping it.")
3741
pass
3842

3943
# from crossfit.backend.tf.array import *

crossfit/op/vector_search.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,26 @@
1616

1717
import cudf
1818
import cupy as cp
19+
import cuvs
1920
import dask.dataframe as dd
21+
import pylibraft
2022
from cuml.dask.neighbors import NearestNeighbors
2123
from dask import delayed
2224
from dask_cudf import from_delayed
23-
from pylibraft.neighbors.brute_force import knn
25+
from packaging.version import parse as parse_version
2426

2527
from crossfit.backend.cudf.series import create_list_series_from_1d_or_2d_ar
2628
from crossfit.backend.dask.cluster import global_dask_client
2729
from crossfit.dataset.base import EmbeddingDatataset
2830
from crossfit.op.base import Op
2931

32+
if (parse_version(pylibraft.__version__).base_version >= "24.12") or (
33+
parse_version(cuvs.__version__).base_version >= "24.12"
34+
):
35+
from cuvs.neighbors.brute_force import search
36+
else:
37+
from pylibraft.neighbors.brute_force import knn as search
38+
3039

3140
class VectorSearchOp(Op):
3241
@overload
@@ -171,7 +180,7 @@ def __init__(
171180
self.normalize = normalize
172181

173182
def search_tensors(self, queries, corpus):
174-
results, indices = knn(dataset=corpus, queries=queries, k=self.k, metric=self.metric)
183+
results, indices = search(dataset=corpus, queries=queries, k=self.k, metric=self.metric)
175184

176185
return cp.asarray(results), cp.asarray(indices)
177186

0 commit comments

Comments
 (0)