Skip to content

Commit

Permalink
updated Create Index module to include metric type and engine type
Browse files Browse the repository at this point in the history
  • Loading branch information
vkrishna1084 committed Feb 27, 2025
1 parent b7c7dcf commit 0af5014
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions vsb/databases/opensearch/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
self.skip_populate = config["skip_populate"]
self.overwrite = config["overwrite"]
self.dimensions = dimensions
self.metric = metric

# Create the OpenSearch client
awsauth = AWS4Auth(
Expand All @@ -107,13 +108,11 @@ def __init__(
connection_class=RequestsHttpConnection,
)

# Create the index
# Generate index name if not specified
if self.index_name is None:
# None specified, default to "vsb-<workload>"
self.index_name = f"vsb-{name}"

if not self.client.indices.exists(self.index_name):
self.create_index()

def create_index(self):
# Create the index
Expand All @@ -125,7 +124,7 @@ def create_index(self):
"type": "text",
"fields": {"keyword": {"type": "keyword"}},
},
"v_content": {"type": "knn_vector", "dimension": self.dimensions},
"v_content": {"type": "knn_vector", "dimension": self.dimensions, "method": {"name": "hnsw", "space_type": OpenSearchDB._get_distance_func(self.metric), "engine": OpenSearchDB._get_engine_func(self.metric)}},
}
},
}
Expand Down Expand Up @@ -175,6 +174,8 @@ def get_namespace(self, namespace: str) -> Namespace:
def initialize_population(self):
if self.skip_populate:
return
if not self.client.indices.exists(self.index_name):
self.create_index()
if not self.created_index and not self.overwrite:
msg = (
f"OpenSearchDB: Index '{self.index_name}' already exists - cowardly "
Expand Down Expand Up @@ -220,3 +221,25 @@ def skip_refinalize(self):

def get_record_count(self) -> int:
return self.client.count(index=self.index_name)["count"]

@staticmethod
def _get_distance_func(metric: DistanceMetric) -> str:
match metric:
case DistanceMetric.Cosine:
return "cosinesim"
case DistanceMetric.Euclidean:
return "l2"
case DistanceMetric.DotProduct:
return "innerproduct"
raise ValueError("Invalid metric:{}".format(metric))

@staticmethod
def _get_engine_func(metric: DistanceMetric) -> str:
match metric:
case DistanceMetric.Cosine:
return "nmslib"
case DistanceMetric.Euclidean:
return "faiss"
case DistanceMetric.DotProduct:
return "faiss"
raise ValueError("Invalid metric:{}".format(metric))

0 comments on commit 0af5014

Please sign in to comment.