
Add linting via Black to GH workflows
Reformat existing files to match the expected format.
daverigby committed Apr 29, 2024
1 parent a35fa6b commit 3b1ab0b
Showing 8 changed files with 149 additions and 33 deletions.
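
For context, the reformatting in this commit is simply Black's default style applied to the existing sources. The short sketch below is an illustration only, not part of the commit; it assumes Black 24.x is installed (it is declared as a dev dependency in pyproject.toml below) and uses Black's library API to reproduce two of the changes visible in vsb/workloads/dataset.py further down.

import black

# Two lines from vsb/workloads/dataset.py, as they looked before this commit.
# Black only parses the snippet, so the undefined names (df, i, ...) are fine.
before = (
    "batch = df.iloc[i: i + batch_size]\n"
    "converted = sub_frame.astype(dtype={'values': object})\n"
)

# format_str() applies the same default rules the new CI job enforces:
# spaces around ':' when a slice bound is an expression, double quotes, etc.
print(black.format_str(before, mode=black.Mode()))
# Output:
#   batch = df.iloc[i : i + batch_size]
#   converted = sub_frame.astype(dtype={"values": object})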
12 changes: 12 additions & 0 deletions .github/workflows/black.yml
@@ -0,0 +1,12 @@
name: Lint (via Black)

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: psf/black@stable
        with:
          src: "./vsb"
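
The psf/black action used above runs Black in check mode against ./vsb by default, so the job fails if any file under that directory would need reformatting. A minimal local equivalent, shown as an illustrative sketch rather than something added by this commit, is:

import subprocess
import sys

# Run the same check the workflow performs; the exit status is non-zero
# when one or more files would be reformatted by Black.
result = subprocess.run([sys.executable, "-m", "black", "--check", "--diff", "./vsb"])
if result.returncode != 0:
    print("Formatting differences found; run python -m black ./vsb to fix them.")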
68 changes: 67 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ pydantic = "^2.7.1"


[tool.poetry.group.dev.dependencies]
black = "^24.4.2"
pytest = "^8.0.2"
flake8 = "^7.0.0"
pre-commit = "^3.6.2"
5 changes: 4 additions & 1 deletion vsb/databases/__init__.py
@@ -5,15 +5,18 @@
class Database(Enum):
    """Set of supported database backends, the value is the string used to
    specify via --database="""

    Pinecone = "pinecone"
    PGVector = "pgvector"

    def build(self, config: dict) -> DB:
        """Construct an instance of """
        """Construct an instance of DB based on the database enum value"""
        match self:
            case Database.Pinecone:
                from .pinecone.pinecone import PineconeDB

                return PineconeDB(config)
            case Database.PGVector:
                from .pgvector.pgvector import PGVectorDB

                return PGVectorDB()
1 change: 0 additions & 1 deletion vsb/vsb_types.py
@@ -3,7 +3,6 @@
Vector = list[float]


Record = dict[str, ]
class Record(BaseModel):
    id: str
    values: Vector
2 changes: 2 additions & 0 deletions vsb/workloads/__init__.py
@@ -15,7 +15,9 @@ def build(self) -> VectorWorkload:
        match self:
            case Workload.Mnist:
                from .mnist.mnist import Mnist

                return Mnist()
            case Workload.MnistTest:
                from .mnist.mnist import MnistTest

                return MnistTest()
92 changes: 62 additions & 30 deletions vsb/workloads/dataset.py
@@ -19,12 +19,13 @@ class Dataset:
Google Cloud Storage bucket and are downloaded on-demand on first access,
then cached on the local machine.
"""

gcs_bucket = "pinecone-datasets-dev"

@staticmethod
def split_dataframe(df: pandas.DataFrame, batch_size) -> Iterator[pandas.DataFrame]:
for i in range(0, len(df), batch_size):
batch = df.iloc[i: i + batch_size]
batch = df.iloc[i : i + batch_size]
yield batch

@staticmethod
@@ -83,17 +84,18 @@ def load_documents(self, skip_download: bool = False):
# to use for documents into a pandas dataframe.
self.documents = self._load_parquet_dataset("passages", limit=self.limit)

def setup_queries(self,
load_queries: bool = True,
doc_sample_fraction: float = 1.0):
def setup_queries(
self, load_queries: bool = True, doc_sample_fraction: float = 1.0
):

# If there is an explicit 'queries' dataset, then load that and use
# for querying, otherwise use documents directly.
if load_queries:
self.queries = self._load_parquet_dataset("queries")
if not self.queries.empty:
logging.info(
f"Using {len(self.queries)} query vectors loaded from dataset 'queries' table")
f"Using {len(self.queries)} query vectors loaded from dataset 'queries' table"
)
else:
# Queries expect a different schema than documents.
# Documents looks like:
@@ -105,7 +107,9 @@ def setup_queries(self,
# 'vector' field of queries is currently used).
if self.documents.empty:
self.load_documents()
assert not self.documents.empty, "Cannot sample 'documents' to use for queries as it is empty"
assert (
not self.documents.empty
), "Cannot sample 'documents' to use for queries as it is empty"
self.queries = self.documents[["values"]].copy()
self.queries.rename(columns={"values": "vector"}, inplace=True)

@@ -115,9 +119,12 @@ def setup_queries(self,
self.queries = self.queries.sample(frac=doc_sample_fraction, random_state=1)
logging.info(
f"Using {doc_sample_fraction * 100}% of documents' dataset "
f"for query data ({len(self.queries)} sampled)")
f"for query data ({len(self.queries)} sampled)"
)

def upsert_into_index(self, index_host, api_key, skip_if_count_identical: bool = False):
def upsert_into_index(
self, index_host, api_key, skip_if_count_identical: bool = False
):
"""
Upsert the datasets' documents into the specified index.
:param index_host: Pinecone index to upsert into (must already exist)
@@ -127,28 +134,36 @@ def upsert_into_index(self, index_host, api_key, skip_if_count_identical: bool =
pinecone = PineconeGRPC(api_key)
index = pinecone.Index(host=index_host)
if skip_if_count_identical:
if index.describe_index_stats()['total_vector_count'] == len(self.documents):
if index.describe_index_stats()["total_vector_count"] == len(
self.documents
):
logging.info(
f"Skipping upsert as index already has same number of documents as dataset ({len(self.documents)}")
f"Skipping upsert as index already has same number of documents as dataset ({len(self.documents)}"
)
return

upserted_count = self._upsert_from_dataframe(index)
if upserted_count != len(self.documents):
logging.warning(
f"Not all records upserted successfully. Dataset count:{len(self.documents)},"
f" upserted count:{upserted_count}")
f" upserted count:{upserted_count}"
)

def prune_documents(self):
"""
Discard the contents of self.documents once it is no longer required
(it can consume a significant amount of memory).
"""
del self.documents
logging.debug(f"After pruning, 'queries' memory usage:{self.queries.memory_usage()}")
logging.debug(
f"After pruning, 'queries' memory usage:{self.queries.memory_usage()}"
)

def _download_dataset_files(self):
self.cache.mkdir(parents=True, exist_ok=True)
logging.debug(f"Checking for existence of dataset '{self.name}' in dataset cache '{self.cache}'")
logging.debug(
f"Checking for existence of dataset '{self.name}' in dataset cache '{self.cache}'"
)
client = Client.create_anonymous_client()
bucket: Bucket = client.bucket(Dataset.gcs_bucket)
blobs = [b for b in bucket.list_blobs(prefix=self.name + "/")]
@@ -173,19 +188,23 @@ def should_download(blob):
to_download = [b for b in filter(lambda b: should_download(b), blobs)]
if not to_download:
return
pbar = tqdm(desc="Downloading datset",
total=sum([b.size for b in to_download]),
unit="Bytes",
unit_scale=True)
pbar = tqdm(
desc="Downloading datset",
total=sum([b.size for b in to_download]),
unit="Bytes",
unit_scale=True,
)
for blob in to_download:
logging.debug(f"Dataset file '{blob.name}' not found in cache - will be downloaded")
logging.debug(
f"Dataset file '{blob.name}' not found in cache - will be downloaded"
)
dest_path = self.cache / blob.name
dest_path.parent.mkdir(parents=True, exist_ok=True)
blob.download_to_filename(self.cache / blob.name)
pbar.update(blob.size)

def _load_parquet_dataset(self, kind, limit=0):
parquet_files = [f for f in (self.cache / self.name).glob(kind + '/*.parquet')]
parquet_files = [f for f in (self.cache / self.name).glob(kind + "/*.parquet")]
if not len(parquet_files):
return pandas.DataFrame()

@@ -203,7 +222,9 @@ def _load_parquet_dataset(self, kind, limit=0):
fields = set(dataset.schema.names)
missing = fields.difference(required)
if len(missing) > 0:
raise ValueError(f"Missing required fields ({missing}) for passages from dataset '{self.name}'")
raise ValueError(
f"Missing required fields ({missing}) for passages from dataset '{self.name}'"
)
# Also load in supported optional fields.
optional = set(["sparse_values", "metadata"])
columns = list(required.union((fields.intersection(optional))))
@@ -224,21 +245,29 @@ def _load_parquet_dataset(self, kind, limit=0):
required = set(["top_k", "blob"])
missing = required.difference(fields)
if len(missing) > 0:
raise ValueError(f"Missing required fields ({missing}) for queries from dataset '{self.name}'")
raise ValueError(
f"Missing required fields ({missing}) for queries from dataset '{self.name}'"
)
value_field = set(["values", "vector"]).intersection(fields)
match len(value_field):
case 0:
raise ValueError(f"Missing required search vector field ('values' or 'vector') queries from dataset '{self.name}'")
raise ValueError(
f"Missing required search vector field ('values' or 'vector') queries from dataset '{self.name}'"
)
case 2:
raise ValueError(f"Multiple search vector fields ('values' and 'vector') present in queries from dataset '{self.name}'")
raise ValueError(
f"Multiple search vector fields ('values' and 'vector') present in queries from dataset '{self.name}'"
)
case 1:
required = required | value_field
# Also load in supported optional fields.
optional = set(["id", "sparse_vector", "filter"])
columns = list(required.union((fields.intersection(optional))))
metadata_column = "filter"
else:
raise ValueError(f"Unsupported kind '{kind}' - must be one of (documents, queries)")
raise ValueError(
f"Unsupported kind '{kind}' - must be one of (documents, queries)"
)
# Note: We to specify pandas.ArrowDtype as the types mapper to use pyarrow datatypes in the
# resulting DataFrame. This is significant as (for reasons unknown) it allows subsequent
# samples() of the DataFrame to be "disconnected" from the original underlying pyarrow data,
@@ -251,9 +280,10 @@ def _load_parquet_dataset(self, kind, limit=0):
# And drop any columns which all values are missing - e.g. not all
# datasets have sparse_values, but the parquet file may still have
# the (empty) column present.
df.dropna(axis='columns', how="all", inplace=True)
df.dropna(axis="columns", how="all", inplace=True)

if metadata_column in df:

def cleanup_null_values(metadata):
# Null metadata values are not supported, remove any key
# will a null value.
@@ -271,7 +301,9 @@ def convert_metadata_to_dict(metadata) -> dict:
return metadata
if isinstance(metadata, str):
return json.loads(metadata)
raise TypeError(f"metadata must be a string or dict (found {type(metadata)})")
raise TypeError(
f"metadata must be a string or dict (found {type(metadata)})"
)

def prepare_metadata(metadata):
return cleanup_null_values(convert_metadata_to_dict(metadata))
@@ -310,10 +342,10 @@ def _upsert_from_dataframe(self, index):
# However, converting the entire sub-frame's column back to a Python object before calling
# upsert_from_dataframe() is significantly faster, such that the overall upsert throughput
# (including the actual server-side work) is around 2x greater if we pre-convert.
converted = sub_frame.astype(dtype={'values': object})
resp = index.upsert_from_dataframe(converted,
batch_size=200,
show_progress=False)
converted = sub_frame.astype(dtype={"values": object})
resp = index.upsert_from_dataframe(
converted, batch_size=200, show_progress=False
)
upserted_count += resp.upserted_count
pbar.update(len(sub_frame))
return upserted_count
1 change: 1 addition & 0 deletions vsb/workloads/mnist/mnist.py
@@ -13,6 +13,7 @@ def name(self) -> str:
class MnistTest(ParquetWorkload):
    """Reduced, "test" variant of mnist; with 1% of the full dataset (600
    passages and 100 queries)."""

    def __init__(self):
        super().__init__("mnist", 600, 100)

