diff --git a/vsb/workloads/dataset.py b/vsb/workloads/dataset.py
index fede11d..c9e9451 100644
--- a/vsb/workloads/dataset.py
+++ b/vsb/workloads/dataset.py
@@ -1,11 +1,14 @@
 from collections.abc import Iterator
+import gevent
+from gevent.lock import BoundedSemaphore
 import numpy
 import pyarrow.parquet
 from google.cloud.storage import Bucket, Client, transfer_manager
 import json
 import pandas
 import pathlib
 
+
 from pinecone.grpc import PineconeGRPC
 import pyarrow.dataset as ds
 from pyarrow.parquet import ParquetDataset, ParquetFile
@@ -204,23 +207,31 @@ def should_download(blob):
             " ✔ Dataset download complete",
             total=len(to_download),
         ) as download_task:
-            for blob in to_download:
-                logger.debug(
-                    f"Dataset file '{blob.name}' not found in cache - will be downloaded"
-                )
-                dest_path = self.cache / blob.name
-                dest_path.parent.mkdir(parents=True, exist_ok=True)
-                blob.download_to_file(
-                    ProgressIOWrapper(
-                        dest=dest_path,
-                        progress=vsb.progress,
-                        total=blob.size,
-                        scale=1024 * 1024,
-                        indent=2,
+            # Limit the max number of files concurrently downloaded.
+            semaphore = BoundedSemaphore(5)
+
+            def download_file(blob, semaphore):
+                with semaphore:
+                    logger.debug(
+                        f"Dataset file '{blob.name}' not found in cache - will be downloaded"
                     )
-                )
-                if vsb.progress:
-                    vsb.progress.update(download_task, advance=1)
+                    dest_path = self.cache / blob.name
+                    dest_path.parent.mkdir(parents=True, exist_ok=True)
+                    blob.download_to_file(
+                        ProgressIOWrapper(
+                            dest=dest_path,
+                            progress=vsb.progress,
+                            total=blob.size,
+                            scale=1024 * 1024,
+                            indent=2,
+                        )
+                    )
+                    if vsb.progress:
+                        vsb.progress.update(download_task, advance=1)
+
+            greenlets = [gevent.spawn(download_file, blob, semaphore) for blob in to_download]
+            gevent.joinall(greenlets)
+
         # Clear the progress bar now we're done.
         vsb.progress.stop()
         vsb.progress = None
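
The pattern this diff introduces (spawn one greenlet per file, cap in-flight work with a semaphore, then join) is worth seeing in isolation. Below is a minimal sketch of the same bounded-concurrency idiom; the `urls` list and the `fetch()` helper are hypothetical stand-ins for the blobs and `download_file()` above, not part of the actual change.

```python
# Minimal sketch of gevent bounded concurrency: a BoundedSemaphore caps how
# many greenlets run the guarded section at once; gevent.joinall() blocks
# until every greenlet has finished. urls and fetch() are hypothetical.
from gevent import monkey

monkey.patch_all()  # make blocking socket I/O yield to other greenlets

import urllib.request

import gevent
from gevent.lock import BoundedSemaphore

urls = [f"https://example.com/file-{i}" for i in range(20)]  # hypothetical
semaphore = BoundedSemaphore(5)  # at most 5 downloads in flight, as in the diff

def fetch(url):
    with semaphore:  # blocks until one of the 5 slots frees up
        return urllib.request.urlopen(url).read()

greenlets = [gevent.spawn(fetch, url) for url in urls]
gevent.joinall(greenlets)
results = [g.value for g in greenlets]  # each greenlet's return value
```

Without the semaphore, `gevent.spawn()` would start every download at once; the semaphore keeps at most five running while the rest wait. `BoundedSemaphore` (rather than plain `Semaphore`) additionally raises an error if released more times than acquired, which catches accounting bugs early.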