Skip to content

Commit 6604ca5

Browse files
PGVector Support for Custom Connection Object (microsoft#2566)
* Added fixes and tests for basic auth format * User can provide their own connection object. Added test for it. * Updated instructions on how to use. Fully tested all 3 authentication methods successfully. * Get password from gitlab secrets. * Hide passwords. * Update notebook/agentchat_pgvector_RetrieveChat.ipynb Co-authored-by: Li Jiang <[email protected]> * Hide passwords. * Added connection_string test. 3 tests total for auth. * Fixed quotes on db config params. No other changes found. * Ran notebook * Ran pre-commits and updated setup to include psycopg[binary] for windows and mac. * Corrected list extension. * Separate connection establishment function. Testing pending. * Fixed pgvectordb auth * Update agentchat_pgvector_RetrieveChat.ipynb Added autocommit=True in example * Rerun notebook --------- Co-authored-by: Li Jiang <[email protected]> Co-authored-by: Li Jiang <[email protected]>
1 parent 1298875 commit 6604ca5

File tree

5 files changed

+313
-398
lines changed

5 files changed

+313
-398
lines changed

.gitattributes

+88
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,91 @@
1+
# Source code
2+
*.bash text eol=lf
3+
*.bat text eol=crlf
4+
*.cmd text eol=crlf
5+
*.coffee text
6+
*.css text diff=css eol=lf
7+
*.htm text diff=html eol=lf
8+
*.html text diff=html eol=lf
9+
*.inc text
10+
*.ini text
11+
*.js text
12+
*.json text eol=lf
13+
*.jsx text
14+
*.less text
15+
*.ls text
16+
*.map text -diff
17+
*.od text
18+
*.onlydata text
19+
*.php text diff=php
20+
*.pl text
21+
*.ps1 text eol=crlf
22+
*.py text diff=python eol=lf
23+
*.rb text diff=ruby eol=lf
24+
*.sass text
25+
*.scm text
26+
*.scss text diff=css
27+
*.sh text eol=lf
28+
.husky/* text eol=lf
29+
*.sql text
30+
*.styl text
31+
*.tag text
32+
*.ts text
33+
*.tsx text
34+
*.xml text
35+
*.xhtml text diff=html
36+
37+
# Docker
38+
Dockerfile text eol=lf
39+
40+
# Documentation
41+
*.ipynb text
42+
*.markdown text diff=markdown eol=lf
43+
*.md text diff=markdown eol=lf
44+
*.mdwn text diff=markdown eol=lf
45+
*.mdown text diff=markdown eol=lf
46+
*.mkd text diff=markdown eol=lf
47+
*.mkdn text diff=markdown eol=lf
48+
*.mdtxt text eol=lf
49+
*.mdtext text eol=lf
50+
*.txt text eol=lf
51+
AUTHORS text eol=lf
52+
CHANGELOG text eol=lf
53+
CHANGES text eol=lf
54+
CONTRIBUTING text eol=lf
55+
COPYING text eol=lf
56+
copyright text eol=lf
57+
*COPYRIGHT* text eol=lf
58+
INSTALL text eol=lf
59+
license text eol=lf
60+
LICENSE text eol=lf
61+
NEWS text eol=lf
62+
readme text eol=lf
63+
*README* text eol=lf
64+
TODO text
65+
66+
# Configs
67+
*.cnf text eol=lf
68+
*.conf text eol=lf
69+
*.config text eol=lf
70+
.editorconfig text
71+
.env text eol=lf
72+
.gitattributes text eol=lf
73+
.gitconfig text eol=lf
74+
.htaccess text
75+
*.lock text -diff
76+
package.json text eol=lf
77+
package-lock.json text eol=lf -diff
78+
pnpm-lock.yaml text eol=lf -diff
79+
.prettierrc text
80+
yarn.lock text -diff
81+
*.toml text eol=lf
82+
*.yaml text eol=lf
83+
*.yml text eol=lf
84+
browserslist text
85+
Makefile text eol=lf
86+
makefile text eol=lf
87+
88+
# Images
189
*.png filter=lfs diff=lfs merge=lfs -text
290
*.jpg filter=lfs diff=lfs merge=lfs -text
391
*.jpeg filter=lfs diff=lfs merge=lfs -text

autogen/agentchat/contrib/vectordb/pgvectordb.py

+118-46
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import re
33
import urllib.parse
4-
from typing import Callable, List
4+
from typing import Callable, List, Optional, Union
55

66
import numpy as np
77
from sentence_transformers import SentenceTransformer
@@ -231,7 +231,14 @@ def table_exists(self, table_name: str) -> bool:
231231
exists = cursor.fetchone()[0]
232232
return exists
233233

234-
def get(self, ids=None, include=None, where=None, limit=None, offset=None) -> List[Document]:
234+
def get(
235+
self,
236+
ids: Optional[str] = None,
237+
include: Optional[str] = None,
238+
where: Optional[str] = None,
239+
limit: Optional[Union[int, str]] = None,
240+
offset: Optional[Union[int, str]] = None,
241+
) -> List[Document]:
235242
"""
236243
Retrieve documents from the collection.
237244
@@ -272,7 +279,6 @@ def get(self, ids=None, include=None, where=None, limit=None, offset=None) -> Li
272279

273280
# Construct the full query
274281
query = f"{select_clause} {from_clause} {where_clause} {limit_clause} {offset_clause}"
275-
276282
retrieved_documents = []
277283
try:
278284
# Execute the query with the appropriate values
@@ -380,11 +386,11 @@ def inner_product_distance(arr1: List[float], arr2: List[float]) -> float:
380386
def query(
381387
self,
382388
query_texts: List[str],
383-
collection_name: str = None,
384-
n_results: int = 10,
385-
distance_type: str = "euclidean",
386-
distance_threshold: float = -1,
387-
include_embedding: bool = False,
389+
collection_name: Optional[str] = None,
390+
n_results: Optional[int] = 10,
391+
distance_type: Optional[str] = "euclidean",
392+
distance_threshold: Optional[float] = -1,
393+
include_embedding: Optional[bool] = False,
388394
) -> QueryResults:
389395
"""
390396
Query documents in the collection.
@@ -450,7 +456,7 @@ def query(
450456
return results
451457

452458
@staticmethod
453-
def convert_string_to_array(array_string) -> List[float]:
459+
def convert_string_to_array(array_string: str) -> List[float]:
454460
"""
455461
Convert a string representation of an array to a list of floats.
456462
@@ -467,7 +473,7 @@ def convert_string_to_array(array_string) -> List[float]:
467473
array = [float(num) for num in array_string.split()]
468474
return array
469475

470-
def modify(self, metadata, collection_name: str = None) -> None:
476+
def modify(self, metadata, collection_name: Optional[str] = None) -> None:
471477
"""
472478
Modify metadata for the collection.
473479
@@ -486,7 +492,7 @@ def modify(self, metadata, collection_name: str = None) -> None:
486492
)
487493
cursor.close()
488494

489-
def delete(self, ids: List[ItemID], collection_name: str = None) -> None:
495+
def delete(self, ids: List[ItemID], collection_name: Optional[str] = None) -> None:
490496
"""
491497
Delete documents from the collection.
492498
@@ -504,7 +510,7 @@ def delete(self, ids: List[ItemID], collection_name: str = None) -> None:
504510
cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids)
505511
cursor.close()
506512

507-
def delete_collection(self, collection_name: str = None) -> None:
513+
def delete_collection(self, collection_name: Optional[str] = None) -> None:
508514
"""
509515
Delete the entire collection.
510516
@@ -520,7 +526,7 @@ def delete_collection(self, collection_name: str = None) -> None:
520526
cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
521527
cursor.close()
522528

523-
def create_collection(self, collection_name: str = None) -> None:
529+
def create_collection(self, collection_name: Optional[str] = None) -> None:
524530
"""
525531
Create a new collection.
526532
@@ -557,23 +563,27 @@ class PGVectorDB(VectorDB):
557563
def __init__(
558564
self,
559565
*,
560-
connection_string: str = None,
561-
host: str = None,
562-
port: int = None,
563-
dbname: str = None,
564-
username: str = None,
565-
password: str = None,
566-
connect_timeout: int = 10,
566+
conn: Optional[psycopg.Connection] = None,
567+
connection_string: Optional[str] = None,
568+
host: Optional[str] = None,
569+
port: Optional[Union[int, str]] = None,
570+
dbname: Optional[str] = None,
571+
username: Optional[str] = None,
572+
password: Optional[str] = None,
573+
connect_timeout: Optional[int] = 10,
567574
embedding_function: Callable = None,
568-
metadata: dict = None,
569-
model_name: str = "all-MiniLM-L6-v2",
575+
metadata: Optional[dict] = None,
576+
model_name: Optional[str] = "all-MiniLM-L6-v2",
570577
) -> None:
571578
"""
572579
Initialize the vector database.
573580
574581
Note: one of conn, connection_string, or host must be specified (port, dbname, username and password are optional additions to host)
575582
576583
Args:
584+
conn: psycopg.Connection | A custom connection object to connect to the database.
585+
A connection object may include additional key/values:
586+
https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
577587
connection_string: "postgresql://username:password@hostname:port/database" | The PGVector connection string. Default is None.
578588
host: str | The host to connect to. Default is None.
579589
port: int | The port to connect to. Default is None.
@@ -593,46 +603,108 @@ def __init__(
593603
Returns:
594604
None
595605
"""
606+
self.client = self.establish_connection(
607+
conn=conn,
608+
connection_string=connection_string,
609+
host=host,
610+
port=port,
611+
dbname=dbname,
612+
username=username,
613+
password=password,
614+
connect_timeout=connect_timeout,
615+
)
616+
self.model_name = model_name
596617
try:
597-
if connection_string:
618+
self.embedding_function = (
619+
SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
620+
)
621+
except Exception as e:
622+
logger.error(
623+
f"Validate the model name entered: {self.model_name} "
624+
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
625+
)
626+
raise e
627+
self.metadata = metadata
628+
register_vector(self.client)
629+
self.active_collection = None
630+
631+
def establish_connection(
632+
self,
633+
conn: Optional[psycopg.Connection] = None,
634+
connection_string: Optional[str] = None,
635+
host: Optional[str] = None,
636+
port: Optional[Union[int, str]] = None,
637+
dbname: Optional[str] = None,
638+
username: Optional[str] = None,
639+
password: Optional[str] = None,
640+
connect_timeout: Optional[int] = 10,
641+
) -> psycopg.Connection:
642+
"""
643+
Establishes a connection to a PostgreSQL database using psycopg.
644+
645+
Args:
646+
conn: An existing psycopg connection object. If provided, this connection will be used.
647+
connection_string: A string containing the connection information. If provided, a new connection will be established using this string.
648+
host: The hostname of the PostgreSQL server. Used if connection_string is not provided.
649+
port: The port number to connect to at the server host. Used if connection_string is not provided.
650+
dbname: The database name. Used if connection_string is not provided.
651+
username: The username to connect as. Used if connection_string is not provided.
652+
password: The user's password. Used if connection_string is not provided.
653+
connect_timeout: Maximum wait for connection, in seconds. The default is 10 seconds.
654+
655+
Returns:
656+
A psycopg.Connection object representing the established connection.
657+
658+
Raises:
659+
PermissionError: If no credentials are supplied.
660+
psycopg.Error: If an error occurs while trying to connect to the database.
661+
"""
662+
try:
663+
if conn:
664+
self.client = conn
665+
elif connection_string:
598666
parsed_connection = urllib.parse.urlparse(connection_string)
599667
encoded_username = urllib.parse.quote(parsed_connection.username, safe="")
600668
encoded_password = urllib.parse.quote(parsed_connection.password, safe="")
669+
encoded_password = f":{encoded_password}@"
601670
encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="")
671+
encoded_port = f":{parsed_connection.port}"
602672
encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="")
603673
connection_string_encoded = (
604-
f"{parsed_connection.scheme}://{encoded_username}:{encoded_password}"
605-
f"@{encoded_host}:{parsed_connection.port}/{encoded_database}"
674+
f"{parsed_connection.scheme}://{encoded_username}{encoded_password}"
675+
f"{encoded_host}{encoded_port}/{encoded_database}"
606676
)
607677
self.client = psycopg.connect(conninfo=connection_string_encoded, autocommit=True)
608-
elif host and port and dbname:
678+
elif host:
679+
connection_string = ""
680+
if host:
681+
encoded_host = urllib.parse.quote(host, safe="")
682+
connection_string += f"host={encoded_host} "
683+
if port:
684+
connection_string += f"port={port} "
685+
if dbname:
686+
encoded_database = urllib.parse.quote(dbname, safe="")
687+
connection_string += f"dbname={encoded_database} "
688+
if username:
689+
encoded_username = urllib.parse.quote(username, safe="")
690+
connection_string += f"user={encoded_username} "
691+
if password:
692+
encoded_password = urllib.parse.quote(password, safe="")
693+
connection_string += f"password={encoded_password} "
694+
609695
self.client = psycopg.connect(
610-
host=host,
611-
port=port,
612-
dbname=dbname,
613-
username=username,
614-
password=password,
696+
conninfo=connection_string,
615697
connect_timeout=connect_timeout,
616698
autocommit=True,
617699
)
700+
else:
701+
logger.error("Credentials were not supplied...")
702+
raise PermissionError
703+
self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
618704
except psycopg.Error as e:
619705
logger.error("Error connecting to the database: ", e)
620706
raise e
621-
self.model_name = model_name
622-
try:
623-
self.embedding_function = (
624-
SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
625-
)
626-
except Exception as e:
627-
logger.error(
628-
f"Validate the model name entered: {self.model_name} "
629-
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
630-
)
631-
raise e
632-
self.metadata = metadata
633-
self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
634-
register_vector(self.client)
635-
self.active_collection = None
707+
return self.client
636708

637709
def create_collection(
638710
self, collection_name: str, overwrite: bool = False, get_or_create: bool = True

0 commit comments

Comments
 (0)