Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use cuvs for knn when available #103

Merged
merged 6 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crossfit/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,20 @@
from crossfit.backend.cudf.array import *
from crossfit.backend.cudf.dataframe import *
except ImportError:
logging.warning("Import Error for cudf backend in Crossfit. Skipping it.")
pass

try:
from crossfit.backend.cupy.array import *
from crossfit.backend.cupy.sparse import *
except ImportError:
logging.warning("Import Error for cupy backend in Crossfit. Skipping it.")
pass

try:
from crossfit.backend.torch.array import *
except ImportError:
logging.warning("Import Error for Torch backend in Crossfit. Skipping it.")
pass

# from crossfit.backend.tf.array import *
Expand Down
14 changes: 12 additions & 2 deletions crossfit/op/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,27 @@

import cudf
import cupy as cp
import cuvs
import dask.dataframe as dd
import pylibraft
from cuml.dask.neighbors import NearestNeighbors
from dask import delayed
from dask_cudf import from_delayed
from pylibraft.neighbors.brute_force import knn
from packaging.version import parse as parse_version

from crossfit.backend.cudf.series import create_list_series_from_1d_or_2d_ar
from crossfit.backend.dask.cluster import global_dask_client
from crossfit.dataset.base import EmbeddingDatataset
from crossfit.op.base import Op

# Pick the brute-force kNN implementation once, at import time. When either
# pylibraft or cuvs reports release 24.12 or newer, `search` comes from cuvs;
# otherwise pylibraft's `knn` is aliased to the same name so the rest of the
# module calls a single function.
#
# Versions are compared as integer release tuples, not as strings: packaging
# normalizes "24.08.00" to "24.8.0", and the string "24.8.0" sorts after
# "24.12", which would wrongly select the cuvs branch on older installs.
# `release` drops pre-release/dev suffixes, so nightlies of 24.12 still match.
_CUVS_MIN_RELEASE = (24, 12)

if (
    parse_version(pylibraft.__version__).release >= _CUVS_MIN_RELEASE
    or parse_version(cuvs.__version__).release >= _CUVS_MIN_RELEASE
):
    from cuvs.neighbors.brute_force import search
else:
    from pylibraft.neighbors.brute_force import knn as search


class VectorSearchOp(Op):
@overload
Expand Down Expand Up @@ -171,7 +181,7 @@ def __init__(
self.normalize = normalize

def search_tensors(self, queries, corpus):
    """Find the nearest corpus vectors for every query vector.

    Calls the module-level brute-force `search` function with `self.k`
    neighbours and `self.metric`.

    Args:
        queries: Query vectors passed to the backend as `queries`.
        corpus: Vectors to search over, passed to the backend as `dataset`.

    Returns:
        A `(results, indices)` tuple, each converted with `cp.asarray`.
    """
    # Run the search exactly once; a second call with the same arguments
    # would only have its result overwritten.
    # NOTE(review): cuvs's brute_force.search appears to take a prebuilt index
    # rather than `dataset=`/`metric=` keywords -- confirm this call works on
    # the cuvs branch as well as with pylibraft's `knn`.
    results, indices = search(dataset=corpus, queries=queries, k=self.k, metric=self.metric)

    return cp.asarray(results), cp.asarray(indices)

Expand Down
Loading