1
1
import os
2
2
import re
3
3
import urllib .parse
4
- from typing import Callable , List
4
+ from typing import Callable , List , Optional , Union
5
5
6
6
import numpy as np
7
7
from sentence_transformers import SentenceTransformer
@@ -231,7 +231,14 @@ def table_exists(self, table_name: str) -> bool:
231
231
exists = cursor .fetchone ()[0 ]
232
232
return exists
233
233
234
- def get (self , ids = None , include = None , where = None , limit = None , offset = None ) -> List [Document ]:
234
+ def get (
235
+ self ,
236
+ ids : Optional [str ] = None ,
237
+ include : Optional [str ] = None ,
238
+ where : Optional [str ] = None ,
239
+ limit : Optional [Union [int , str ]] = None ,
240
+ offset : Optional [Union [int , str ]] = None ,
241
+ ) -> List [Document ]:
235
242
"""
236
243
Retrieve documents from the collection.
237
244
@@ -272,7 +279,6 @@ def get(self, ids=None, include=None, where=None, limit=None, offset=None) -> Li
272
279
273
280
# Construct the full query
274
281
query = f"{ select_clause } { from_clause } { where_clause } { limit_clause } { offset_clause } "
275
-
276
282
retrieved_documents = []
277
283
try :
278
284
# Execute the query with the appropriate values
@@ -380,11 +386,11 @@ def inner_product_distance(arr1: List[float], arr2: List[float]) -> float:
380
386
def query (
381
387
self ,
382
388
query_texts : List [str ],
383
- collection_name : str = None ,
384
- n_results : int = 10 ,
385
- distance_type : str = "euclidean" ,
386
- distance_threshold : float = - 1 ,
387
- include_embedding : bool = False ,
389
+ collection_name : Optional [ str ] = None ,
390
+ n_results : Optional [ int ] = 10 ,
391
+ distance_type : Optional [ str ] = "euclidean" ,
392
+ distance_threshold : Optional [ float ] = - 1 ,
393
+ include_embedding : Optional [ bool ] = False ,
388
394
) -> QueryResults :
389
395
"""
390
396
Query documents in the collection.
@@ -450,7 +456,7 @@ def query(
450
456
return results
451
457
452
458
@staticmethod
453
- def convert_string_to_array (array_string ) -> List [float ]:
459
+ def convert_string_to_array (array_string : str ) -> List [float ]:
454
460
"""
455
461
Convert a string representation of an array to a list of floats.
456
462
@@ -467,7 +473,7 @@ def convert_string_to_array(array_string) -> List[float]:
467
473
array = [float (num ) for num in array_string .split ()]
468
474
return array
469
475
470
- def modify (self , metadata , collection_name : str = None ) -> None :
476
+ def modify (self , metadata , collection_name : Optional [ str ] = None ) -> None :
471
477
"""
472
478
Modify metadata for the collection.
473
479
@@ -486,7 +492,7 @@ def modify(self, metadata, collection_name: str = None) -> None:
486
492
)
487
493
cursor .close ()
488
494
489
- def delete (self , ids : List [ItemID ], collection_name : str = None ) -> None :
495
+ def delete (self , ids : List [ItemID ], collection_name : Optional [ str ] = None ) -> None :
490
496
"""
491
497
Delete documents from the collection.
492
498
@@ -504,7 +510,7 @@ def delete(self, ids: List[ItemID], collection_name: str = None) -> None:
504
510
cursor .execute (f"DELETE FROM { self .name } WHERE id IN ({ id_placeholders } );" , ids )
505
511
cursor .close ()
506
512
507
- def delete_collection (self , collection_name : str = None ) -> None :
513
+ def delete_collection (self , collection_name : Optional [ str ] = None ) -> None :
508
514
"""
509
515
Delete the entire collection.
510
516
@@ -520,7 +526,7 @@ def delete_collection(self, collection_name: str = None) -> None:
520
526
cursor .execute (f"DROP TABLE IF EXISTS { self .name } " )
521
527
cursor .close ()
522
528
523
- def create_collection (self , collection_name : str = None ) -> None :
529
+ def create_collection (self , collection_name : Optional [ str ] = None ) -> None :
524
530
"""
525
531
Create a new collection.
526
532
@@ -557,23 +563,27 @@ class PGVectorDB(VectorDB):
557
563
def __init__ (
558
564
self ,
559
565
* ,
560
- connection_string : str = None ,
561
- host : str = None ,
562
- port : int = None ,
563
- dbname : str = None ,
564
- username : str = None ,
565
- password : str = None ,
566
- connect_timeout : int = 10 ,
566
+ conn : Optional [psycopg .Connection ] = None ,
567
+ connection_string : Optional [str ] = None ,
568
+ host : Optional [str ] = None ,
569
+ port : Optional [Union [int , str ]] = None ,
570
+ dbname : Optional [str ] = None ,
571
+ username : Optional [str ] = None ,
572
+ password : Optional [str ] = None ,
573
+ connect_timeout : Optional [int ] = 10 ,
567
574
embedding_function : Callable = None ,
568
- metadata : dict = None ,
569
- model_name : str = "all-MiniLM-L6-v2" ,
575
+ metadata : Optional [ dict ] = None ,
576
+ model_name : Optional [ str ] = "all-MiniLM-L6-v2" ,
570
577
) -> None :
571
578
"""
572
579
Initialize the vector database.
573
580
574
581
Note: connection_string or host + port + dbname must be specified
575
582
576
583
Args:
584
+ conn: psycopg.Connection | A customer connection object to connect to the database.
585
+ A connection object may include additional key/values:
586
+ https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
577
587
connection_string: "postgresql://username:password@hostname:port/database" | The PGVector connection string. Default is None.
578
588
host: str | The host to connect to. Default is None.
579
589
port: int | The port to connect to. Default is None.
@@ -593,46 +603,108 @@ def __init__(
593
603
Returns:
594
604
None
595
605
"""
606
+ self .client = self .establish_connection (
607
+ conn = conn ,
608
+ connection_string = connection_string ,
609
+ host = host ,
610
+ port = port ,
611
+ dbname = dbname ,
612
+ username = username ,
613
+ password = password ,
614
+ connect_timeout = connect_timeout ,
615
+ )
616
+ self .model_name = model_name
596
617
try :
597
- if connection_string :
618
+ self .embedding_function = (
619
+ SentenceTransformer (self .model_name ) if embedding_function is None else embedding_function
620
+ )
621
+ except Exception as e :
622
+ logger .error (
623
+ f"Validate the model name entered: { self .model_name } "
624
+ f"from https://huggingface.co/models?library=sentence-transformers\n Error: { e } "
625
+ )
626
+ raise e
627
+ self .metadata = metadata
628
+ register_vector (self .client )
629
+ self .active_collection = None
630
+
631
+ def establish_connection (
632
+ self ,
633
+ conn : Optional [psycopg .Connection ] = None ,
634
+ connection_string : Optional [str ] = None ,
635
+ host : Optional [str ] = None ,
636
+ port : Optional [Union [int , str ]] = None ,
637
+ dbname : Optional [str ] = None ,
638
+ username : Optional [str ] = None ,
639
+ password : Optional [str ] = None ,
640
+ connect_timeout : Optional [int ] = 10 ,
641
+ ) -> psycopg .Connection :
642
+ """
643
+ Establishes a connection to a PostgreSQL database using psycopg.
644
+
645
+ Args:
646
+ conn: An existing psycopg connection object. If provided, this connection will be used.
647
+ connection_string: A string containing the connection information. If provided, a new connection will be established using this string.
648
+ host: The hostname of the PostgreSQL server. Used if connection_string is not provided.
649
+ port: The port number to connect to at the server host. Used if connection_string is not provided.
650
+ dbname: The database name. Used if connection_string is not provided.
651
+ username: The username to connect as. Used if connection_string is not provided.
652
+ password: The user's password. Used if connection_string is not provided.
653
+ connect_timeout: Maximum wait for connection, in seconds. The default is 10 seconds.
654
+
655
+ Returns:
656
+ A psycopg.Connection object representing the established connection.
657
+
658
+ Raises:
659
+ PermissionError if no credentials are supplied
660
+ psycopg.Error: If an error occurs while trying to connect to the database.
661
+ """
662
+ try :
663
+ if conn :
664
+ self .client = conn
665
+ elif connection_string :
598
666
parsed_connection = urllib .parse .urlparse (connection_string )
599
667
encoded_username = urllib .parse .quote (parsed_connection .username , safe = "" )
600
668
encoded_password = urllib .parse .quote (parsed_connection .password , safe = "" )
669
+ encoded_password = f":{ encoded_password } @"
601
670
encoded_host = urllib .parse .quote (parsed_connection .hostname , safe = "" )
671
+ encoded_port = f":{ parsed_connection .port } "
602
672
encoded_database = urllib .parse .quote (parsed_connection .path [1 :], safe = "" )
603
673
connection_string_encoded = (
604
- f"{ parsed_connection .scheme } ://{ encoded_username } : { encoded_password } "
605
- f"@ { encoded_host } : { parsed_connection . port } /{ encoded_database } "
674
+ f"{ parsed_connection .scheme } ://{ encoded_username } { encoded_password } "
675
+ f"{ encoded_host } { encoded_port } /{ encoded_database } "
606
676
)
607
677
self .client = psycopg .connect (conninfo = connection_string_encoded , autocommit = True )
608
- elif host and port and dbname :
678
+ elif host :
679
+ connection_string = ""
680
+ if host :
681
+ encoded_host = urllib .parse .quote (host , safe = "" )
682
+ connection_string += f"host={ encoded_host } "
683
+ if port :
684
+ connection_string += f"port={ port } "
685
+ if dbname :
686
+ encoded_database = urllib .parse .quote (dbname , safe = "" )
687
+ connection_string += f"dbname={ encoded_database } "
688
+ if username :
689
+ encoded_username = urllib .parse .quote (username , safe = "" )
690
+ connection_string += f"user={ encoded_username } "
691
+ if password :
692
+ encoded_password = urllib .parse .quote (password , safe = "" )
693
+ connection_string += f"password={ encoded_password } "
694
+
609
695
self .client = psycopg .connect (
610
- host = host ,
611
- port = port ,
612
- dbname = dbname ,
613
- username = username ,
614
- password = password ,
696
+ conninfo = connection_string ,
615
697
connect_timeout = connect_timeout ,
616
698
autocommit = True ,
617
699
)
700
+ else :
701
+ logger .error ("Credentials were not supplied..." )
702
+ raise PermissionError
703
+ self .client .execute ("CREATE EXTENSION IF NOT EXISTS vector" )
618
704
except psycopg .Error as e :
619
705
logger .error ("Error connecting to the database: " , e )
620
706
raise e
621
- self .model_name = model_name
622
- try :
623
- self .embedding_function = (
624
- SentenceTransformer (self .model_name ) if embedding_function is None else embedding_function
625
- )
626
- except Exception as e :
627
- logger .error (
628
- f"Validate the model name entered: { self .model_name } "
629
- f"from https://huggingface.co/models?library=sentence-transformers\n Error: { e } "
630
- )
631
- raise e
632
- self .metadata = metadata
633
- self .client .execute ("CREATE EXTENSION IF NOT EXISTS vector" )
634
- register_vector (self .client )
635
- self .active_collection = None
707
+ return self .client
636
708
637
709
def create_collection (
638
710
self , collection_name : str , overwrite : bool = False , get_or_create : bool = True
0 commit comments