Add cache_dir config, default to /tmp/VSB
Currently, downloaded datasets are stored in the current working
directory.

This can result in multiple copies of the same dataset being
downloaded; it also slows down startup, as we end up re-downloading
even when the files already exist somewhere locally.

Change so we use a fixed location, defaulting to /tmp/VSB/cache.
daverigby committed Apr 30, 2024
1 parent 96005ac commit f419179
Showing 4 changed files with 21 additions and 12 deletions.
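The behaviour described in the commit message can be sketched as follows. This is an illustrative, self-contained Python sketch with hypothetical helper names; the real download logic lives in the Dataset class (vsb/workloads/dataset.py, which this commit does not touch):

from pathlib import Path

def download_dataset(name: str, dest: Path) -> None:
    # Stand-in for the real download (handled by the Dataset class);
    # creates a marker file so this sketch runs end to end.
    (dest / f"{name}.parquet").touch()

def fetch_dataset(name: str, cache_dir: str = "/tmp/VSB/cache") -> Path:
    # With a fixed cache location, a download happens only on a cache
    # miss; a second run finds the files and skips the download.
    path = Path(cache_dir) / name
    if not path.exists():
        path.mkdir(parents=True)
        download_dataset(name, path)
    return path

print(fetch_dataset("mnist"))  # first call downloads; repeat calls are no-ops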
vsb.py: 9 additions & 1 deletion
@@ -41,6 +41,7 @@ class Phase(Enum):
 
     """Represents a single user (aka client) performing requests against
     a particular Backend."""
+
     def __init__(self, environment):
         super().__init__(environment)
         self.database = environment.database
@@ -134,6 +135,13 @@ def main():
         default=1,
         help="Number of clients concurrently accessing the database",
     )
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default="/tmp/VSB/cache",
+        help="Directory to store downloaded datasets",
+    )
+
     # TODO: These shouldn't be hardcoded - they should be based on the
     # specified database - e.g. pgvector knows nothing about API keys.
     # Preferably some form of generic way of specifying - e.g.
@@ -151,7 +159,7 @@ def main():
     opts = env.options
     db_config = {"api_key": opts.api_key, "index_name": opts.index_name}
     env.database = Database(env.options.database).build(db_config)
-    env.workload = Workload(env.options.workload).build()
+    env.workload = Workload(env.options.workload).get_class()(opts.cache_dir)
 
     runner = env.create_local_runner()
 
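As a standalone illustration of the new flag (only the option itself is reproduced; the rest of VSB's argument parsing is omitted), argparse falls back to the /tmp/VSB/cache default when --cache_dir is not given:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cache_dir",
    type=str,
    default="/tmp/VSB/cache",
    help="Directory to store downloaded datasets",
)

print(parser.parse_args([]).cache_dir)                            # /tmp/VSB/cache
print(parser.parse_args(["--cache_dir", "/data/vsb"]).cache_dir)  # /data/vsb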
vsb/workloads/__init__.py: 4 additions & 4 deletions
@@ -10,14 +10,14 @@ class Workload(Enum):
     Mnist = "mnist"
     MnistTest = "mnist-test"
 
-    def build(self) -> VectorWorkload:
-        """Construct an instance of Benchmark based on the value of the enum"""
+    def get_class(self) -> type[VectorWorkload]:
+        """Return the VectorWorkload class to use, based on the value of the enum"""
         match self:
             case Workload.Mnist:
                 from .mnist.mnist import Mnist
 
-                return Mnist()
+                return Mnist
             case Workload.MnistTest:
                 from .mnist.mnist import MnistTest
 
-                return MnistTest()
+                return MnistTest
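The enum member now resolves to a class rather than constructing an instance, so the caller supplies the constructor arguments. A sketch of the new call pattern (note that instantiation immediately loads the dataset, so a real invocation will download on a cache miss):

from vsb.workloads import Workload

cls = Workload("mnist-test").get_class()  # returns the MnistTest class itself
workload = cls("/tmp/VSB/cache")          # caller decides the cache directory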
vsb/workloads/mnist/mnist.py: 4 additions & 5 deletions
@@ -1,10 +1,9 @@
-from ..dataset import Dataset
 from ..parquet_workload.parquet_workload import ParquetWorkload
 
 
 class Mnist(ParquetWorkload):
-    def __init__(self):
-        super().__init__("mnist")
+    def __init__(self, cache_dir: str):
+        super().__init__("mnist", cache_dir=cache_dir)
 
     def name(self) -> str:
         return "mnist"
@@ -14,8 +13,8 @@ class MnistTest(ParquetWorkload):
     """Reduced, "test" variant of mnist; with 1% of the full dataset (600
     passages and 100 queries)."""
 
-    def __init__(self):
-        super().__init__("mnist", 600, 100)
+    def __init__(self, cache_dir: str):
+        super().__init__("mnist", cache_dir=cache_dir, limit=600, query_limit=100)
 
     def name(self) -> str:
         return "mnist-test"
vsb/workloads/parquet_workload/parquet_workload.py: 4 additions & 2 deletions
@@ -12,8 +12,10 @@ class ParquetWorkload(VectorWorkload):
     from a second set of parquet files.
     """
 
-    def __init__(self, dataset_name: str, limit: int = 0, query_limit: int = 0):
-        self.dataset = Dataset(dataset_name, limit=limit)
+    def __init__(
+        self, dataset_name: str, cache_dir: str, limit: int = 0, query_limit: int = 0
+    ):
+        self.dataset = Dataset(dataset_name, cache_dir=cache_dir, limit=limit)
         self.dataset.load_documents()
         # TODO: At parquet level should probably just iterate across entire row
         # groups, if the DB wants to split further they can chose to.
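Under this scheme a new workload is just a small subclass plus an enum entry. A hypothetical sketch ("glove" is an illustrative dataset name, not one VSB necessarily ships):

from vsb.workloads.parquet_workload.parquet_workload import ParquetWorkload

class Glove(ParquetWorkload):
    def __init__(self, cache_dir: str):
        # Parquet files are fetched into, and re-used from, cache_dir.
        super().__init__("glove", cache_dir=cache_dir)

    def name(self) -> str:
        return "glove"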
