Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

the faiss index method change: replace the IndexFlatIP with IndexIVFFlat #250

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion huixiangdou/service/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from langchain.text_splitter import (MarkdownHeaderTextSplitter,
MarkdownTextSplitter,
RecursiveCharacterTextSplitter)
from langchain.vectorstores.faiss import FAISS as Vectorstore
# from langchain.vectorstores.faiss import FAISS as Vectorstore
from .vector_store import Vectorstore
from langchain_core.documents import Document
from loguru import logger
from torch.cuda import empty_cache
Expand Down
7 changes: 6 additions & 1 deletion huixiangdou/service/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from BCEmbedding.tools.langchain import BCERerank
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.vectorstores.faiss import FAISS as Vectorstore
# from langchain.vectorstores.faiss import FAISS as Vectorstore
from .vector_store import Vectorstore
from langchain_community.vectorstores.utils import DistanceStrategy
from loguru import logger
from sklearn.metrics import precision_recall_curve
Expand Down Expand Up @@ -118,6 +119,9 @@ def query(self,
str: The best matching chunk, or None.
str: The best matching text, or None
"""

import time
start_time=time.time()
if question is None or len(question) < 1:
return None, None, []

Expand Down Expand Up @@ -183,6 +187,7 @@ def query(self,

context = context[0:context_max_length]
logger.debug('query:{} top1 file:{}'.format(question, references[0]))
logger.debug(f"query total time:{time.time()-start_time}")
return '\n'.join(chunks), context, [
os.path.basename(r) for r in references
]
Expand Down
192 changes: 192 additions & 0 deletions huixiangdou/service/vector_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import os
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sized,
Tuple,
Union,
)
import uuid
import numpy as np
import torch
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.docstore.base import AddableMixin, Docstore
from langchain_core.documents import Document
from langchain_community.vectorstores.utils import (
DistanceStrategy,
maximal_marginal_relevance,
)

def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
"""
Import faiss if available, otherwise raise error.
If FAISS_NO_AVX2 environment variable is set, it will be considered
to load FAISS with no AVX2 optimization.

Args:
no_avx2: Load FAISS strictly with no AVX2 optimization
so that the vectorstore is portable and compatible with other devices.
"""
if no_avx2 is None and "FAISS_NO_AVX2" in os.environ:
no_avx2 = bool(os.getenv("FAISS_NO_AVX2"))

try:
if no_avx2:
from faiss import swigfaiss as faiss
else:
import faiss
except ImportError:
raise ImportError(
"Could not import faiss python package. "
"Please install it with `pip install faiss-gpu` (for CUDA supported GPU) "
"or `pip install faiss-cpu` (depending on Python version)."
)
return faiss

def _len_check_if_sized(x: Any, y: Any, x_name: str, y_name: str) -> None:
if isinstance(x, Sized) and isinstance(y, Sized) and len(x) != len(y):
raise ValueError(
f"{x_name} and {y_name} expected to be equal length but "
f"len({x_name})={len(x)} and len({y_name})={len(y)}"
)
return

import faiss
def get_faiss_index(dimension,index_method):
# faiss=dependable_faiss_import()
'''
dimension: int, the dimension of a embedding
index_method: str,
reference https://github.com/facebookresearch/faiss/wiki/The-index-factory

return: one of index database instance
'''
nlist = 100
if DistanceStrategy.MAX_INNER_PRODUCT:
index=faiss.IndexFlatIP(dimension)
index=faiss.IndexIVFFlat(index,dimension, nlist, faiss.METRIC_INNER_PRODUCT)
else:
index=faiss.IndexFlatL2(dimension)
index=faiss.IndexIVFFlat(index,dimension, nlist, faiss.METRIC_L2)
return index


class Vectorstore(FAISS):
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> FAISS:
"""Construct FAISS wrapper from raw documents.

This is a user friendly interface that:
1. Embeds documents.
2. Creates an in memory docstore
3. Initializes the FAISS database

This is intended to be a quick way to get started.

Example:
.. code-block:: python

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
faiss = FAISS.from_texts(texts, embeddings)
"""
embeddings = embedding.embed_documents(texts)
return cls.__from(
texts,
embeddings,
embedding,
metadatas=metadatas,
ids=ids,
**kwargs,
)

@classmethod
def __from(
cls,
texts: Iterable[str],
embeddings: List[List[float]],
embedding: Embeddings,
metadatas: Optional[Iterable[dict]] = None,
ids: Optional[List[str]] = None,
normalize_L2: bool = False,
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
**kwargs: Any,
) -> FAISS:
faiss = dependable_faiss_import()

index = get_faiss_index(len(embeddings[0]),distance_strategy)
data=np.array(embeddings,dtype="float32")
index.train(data)

# print("len of the embedding",len(embeddings[0]))
docstore = kwargs.pop("docstore", InMemoryDocstore())
index_to_docstore_id = kwargs.pop("index_to_docstore_id", {})
vecstore = cls(
embedding,
index,
docstore,
index_to_docstore_id,
normalize_L2=normalize_L2,
distance_strategy=distance_strategy,
**kwargs,
)

vecstore.__add(texts, embeddings, metadatas=metadatas, ids=ids)

return vecstore

def __add(
self,
texts: Iterable[str],
embeddings: Iterable[List[float]],
metadatas: Optional[Iterable[dict]] = None,
ids: Optional[List[str]] = None,
) -> List[str]:
faiss = dependable_faiss_import()

if not isinstance(self.docstore, AddableMixin):
raise ValueError(
"If trying to add texts, the underlying docstore should support "
f"adding items, which {self.docstore} does not"
)

_len_check_if_sized(texts, metadatas, "texts", "metadatas")
_metadatas = metadatas or ({} for _ in texts)
documents = [
Document(page_content=t, metadata=m) for t, m in zip(texts, _metadatas)
]

_len_check_if_sized(documents, embeddings, "documents", "embeddings")
_len_check_if_sized(documents, ids, "documents", "ids")

if ids and len(ids) != len(set(ids)):
raise ValueError("Duplicate ids found in the ids list.")

# Add to the index.
vector = np.array(embeddings, dtype=np.float32)
if self._normalize_L2:
faiss.normalize_L2(vector)
self.index.add(vector)

# Add information to docstore and index.
ids = ids or [str(uuid.uuid4()) for _ in texts]
self.docstore.add({id_: doc for id_, doc in zip(ids, documents)})
starting_len = len(self.index_to_docstore_id)
index_to_id = {starting_len + j: id_ for j, id_ in enumerate(ids)}
self.index_to_docstore_id.update(index_to_id)
return ids
Loading