# pipeline.py
import os
import re
from typing import List
from pprint import pprint
import weaviate
from dotenv import load_dotenv
from git import Repo
from langchain.docstore.document import Document
from langchain.document_loaders import GitLoader
from langchain.text_splitter import MarkdownTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQAWithSourcesChain
from langchain import OpenAI
from langchain.vectorstores import Weaviate
from pydantic import BaseModel
# Load environment variables (OpenAI/Weaviate/embedding config) from a local .env file.
load_dotenv()
# Directory where cloned example repositories are stored (see embed_git_repo).
EXAMPLE_DATA_DIR = os.path.join(os.path.dirname(__file__), "example_data")
class Entry(BaseModel):
    """A single piece of content to embed, with attached metadata."""
    content: str  # raw markdown text; chunked by the splitter before embedding
    meta: dict  # metadata attached to every chunk produced from this entry
class Entries(BaseModel):
    """A batch of Entry items to embed together."""
    entries: List[Entry]  # processed one-by-one in embed_markdown_document
class MaxPipeline:
    """Embedding/retrieval pipeline backed by a Weaviate vector store.

    Splits markdown content into chunks, embeds them (OpenAI or HuggingFace,
    selected via the EMBEDDING_METHOD env var) and stores them in the
    "Posthog_docs" Weaviate index; answers questions over the stored
    documents with a RetrievalQAWithSourcesChain.
    """

    def __init__(self, openai_token: str):
        """Configure the embedding backend and connect to Weaviate.

        Reads EMBEDDING_METHOD, WEAVIATE_URL and WEAVIATE_API_KEY from the
        environment.

        Args:
            openai_token: OpenAI API token, stored for later use.

        Raises:
            ValueError: if EMBEDDING_METHOD is neither "openai" nor
                "huggingface". (The original silently left
                ``self.embeddings`` unset, failing later with an
                AttributeError far from the misconfiguration.)
        """
        self.openai_token = openai_token
        embed_setting = os.getenv("EMBEDDING_METHOD", "openai")
        if embed_setting == "openai":
            print("Using OpenAI embeddings")
            self.embeddings = OpenAIEmbeddings()
        elif embed_setting == "huggingface":
            print("Using HuggingFace embeddings")
            self.embeddings = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")
        else:
            # Fail fast on a bad setting instead of crashing later.
            raise ValueError(
                f"Unknown EMBEDDING_METHOD {embed_setting!r}; "
                'expected "openai" or "huggingface"'
            )
        self.splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=0)
        weaviate_auth_config = weaviate.AuthApiKey(
            api_key=os.getenv("WEAVIATE_API_KEY")
        )
        weaviate_client = weaviate.Client(
            url=os.getenv("WEAVIATE_URL"), auth_client_secret=weaviate_auth_config
        )
        self.document_store = Weaviate(
            client=weaviate_client,
            index_name="Posthog_docs",
            by_text=False,
            text_key="page_content",
            embedding=self.embeddings,
            attributes=["source"],
        )
        # Maximal-marginal-relevance search trades pure similarity for diversity.
        self.retriever = self.document_store.as_retriever(search_type="mmr")

    def embed_markdown_document(self, documents: Entries):
        """Split each entry's markdown content and embed the resulting chunks.

        Args:
            documents: entries whose ``content`` is chunked by the markdown
                splitter; each entry's ``meta`` is attached to all its chunks.
        """
        for entry in documents.entries:
            texts = self.splitter.split_text(entry.content)
            # Fresh name: the original rebound the ``documents`` parameter
            # inside this loop, shadowing the argument being iterated.
            chunks = [
                Document(page_content=doc, metadata=entry.meta) for doc in texts if doc
            ]
            self.embed_documents(chunks)

    def embed_documents(self, documents: List[Document]):
        """Add already-split documents to the Weaviate store."""
        self.document_store.add_documents(documents)

    def retrieve_context(self, query: str):
        """Return documents relevant to ``query`` via the MMR retriever."""
        return self.retriever.get_relevant_documents(query)

    def chat(self, query: str):
        """Answer ``query`` over the stored documents, citing sources.

        Returns:
            The chain's output mapping (answer text plus source references).
        """
        chain = RetrievalQAWithSourcesChain.from_chain_type(
            OpenAI(temperature=0), chain_type="stuff", retriever=self.retriever
        )
        results = chain(
            {"question": query},
            return_only_outputs=True,
        )
        return results

    def embed_git_repo(self, gh_repo):
        """Clone (or update) a GitHub repo and embed its markdown files.

        Args:
            gh_repo: "owner/name" slug of a GitHub repository.
        """
        repo_url = f"https://github.com/{gh_repo}.git"
        repo_dir = gh_repo.split("/")[-1]
        path = os.path.join(EXAMPLE_DATA_DIR, repo_dir)
        if not os.path.exists(path):
            print("Repo not found, cloning...")
            repo = Repo.clone_from(
                repo_url,
                to_path=path,
            )
        else:
            print("Repo already exists, pulling latest changes...")
            repo = Repo(path)
            repo.git.pull()
        branch = repo.head.reference
        loader = GitLoader(
            repo_path=path,
            branch=branch,
            file_filter=lambda file_path: file_path.endswith((".md", ".mdx")),
        )
        data = loader.load()
        for page in data:
            print(f"Adding {page.metadata['source']}")
            # Link to the branch actually loaded from (the original
            # hard-coded "master" and appended a stray trailing space).
            page.metadata["source"] = (
                f"https://github.com/{gh_repo}/blob/{branch}/{page.metadata['source']}"
            )
            docs = [
                Document(page_content=chunk, metadata=page.metadata)
                for chunk in self.splitter.split_text(page.page_content)
            ]
            self.document_store.add_documents(docs)
        print("Done")
        return