from abc import ABC, abstractmethod
from collections.abc import Iterator

from vsb.vsb_types import SearchRequest, DistanceMetric, RecordList, Record


class VectorWorkload(ABC):
    @abstractmethod
    def __init__(self, name: str, **kwargs):
        self._name = name

    @property
    def name(self) -> str:
        return self._name

    @staticmethod
    @abstractmethod
    def dimensions() -> int:
        """
        The dimensions of (dense) vectors for this workload.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def metric() -> DistanceMetric:
        """
        The distance metric of this workload.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def record_count() -> int:
        """
        The number of records in the initial workload after population, but
        before issuing any additional requests.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def request_count() -> int:
        """
        The number of requests in the Run phase of the test.
        """
        raise NotImplementedError

    @abstractmethod
    def get_sample_record(self) -> Record:
        """
        Return a sample record from the workload, to aid databases in sizing
        the batch size to use, etc.
        """
        raise NotImplementedError

    @abstractmethod
    def get_record_batch_iter(
        self, num_users: int, user_id: int, batch_size: int
    ) -> Iterator[tuple[str, RecordList]]:
        """
        For initial record ingest, returns a RecordBatchIterator over the
        records for the specified `user_id`, assuming a total of `num_users`
        users will be ingesting data - i.e. for the entire workload to be
        loaded there should be `num_users` calls to this method.
        Returns an Iterator which yields a tuple of
        (namespace, batch of records), or (None, None) if there are no more
        records to load.
        :param num_users: The number of clients the dataset ingest is
            distributed across.
        :param user_id: The ID of the user requesting the iterator.
        :param batch_size: The size of the batches to create.
        """
        raise NotImplementedError

    @abstractmethod
    def get_query_iter(
        self, num_users: int, user_id: int, batch_size: int
    ) -> Iterator[tuple[str, SearchRequest]]:
        """
        Returns an iterator over the sequence of queries for the given
        `user_id`, assuming a total of `num_users` users will be issuing
        queries - i.e. for the entire query set to be requested there should
        be `num_users` calls to this method.
        Returns an Iterator which yields a tuple of (tenant, Request).
        :param num_users: The number of clients the queries are distributed
            across.
        :param user_id: The ID of the user requesting the iterator.
        :param batch_size: The maximum batch size for upsert requests.
        """
        raise NotImplementedError

    def get_stats_prefix(self) -> str:
        """
        Returns the prefix to use for stats emitted by this workload.
        """
        return self.name + "."

    def recall_available(self) -> bool:
        """
        Returns True if the workload has recall data available for queries.
        """
        return True
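

# --- Illustrative sketch (not part of this module) ---
# A minimal concrete VectorWorkload backed by synthetic in-memory data,
# showing how the batch and query iterators can be partitioned across
# `num_users`. The class name, the fixed sizes, and the direct Record(...),
# RecordList(...) and SearchRequest(...) constructions below are assumptions
# for illustration only; real workloads read from datasets instead.
class ToyWorkload(VectorWorkload):
    DIMS = 4
    RECORDS = 100
    QUERIES = 10

    def __init__(self, name: str, **kwargs):
        super().__init__(name, **kwargs)

    @staticmethod
    def dimensions() -> int:
        return ToyWorkload.DIMS

    @staticmethod
    def metric() -> DistanceMetric:
        # DistanceMetric member name assumed; use whichever value applies.
        return DistanceMetric.Cosine

    @staticmethod
    def record_count() -> int:
        return ToyWorkload.RECORDS

    @staticmethod
    def request_count() -> int:
        return ToyWorkload.QUERIES

    def get_sample_record(self) -> Record:
        # Record field names (id, values) assumed for illustration.
        return Record(id="sample", values=[0.0] * self.DIMS)

    def get_record_batch_iter(
        self, num_users: int, user_id: int, batch_size: int
    ) -> Iterator[tuple[str, RecordList]]:
        # Each user ingests the record ids congruent to user_id mod num_users,
        # so num_users calls between them cover the whole workload.
        ids = [i for i in range(self.RECORDS) if i % num_users == user_id]
        for start in range(0, len(ids), batch_size):
            batch = [
                Record(id=str(i), values=[float(i)] * self.DIMS)
                for i in ids[start : start + batch_size]
            ]
            # Empty string used as the default namespace; RecordList(...)
            # construction from a list of Records is assumed here.
            yield "", RecordList(batch)

    def get_query_iter(
        self, num_users: int, user_id: int, batch_size: int
    ) -> Iterator[tuple[str, SearchRequest]]:
        # Queries are likewise split round-robin across users.
        for i in range(self.QUERIES):
            if i % num_users == user_id:
                yield "", SearchRequest(values=[float(i)] * self.DIMS, top_k=10)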
class VectorWorkloadSequence(ABC):
    @abstractmethod
    def __init__(self, name: str, **kwargs):
        self._name = name

    @property
    def name(self) -> str:
        return self._name

    @staticmethod
    @abstractmethod
    def workload_count() -> int:
        """
        The number of workloads in the sequence.
        """
        raise NotImplementedError

    def __getitem__(self, index: int) -> VectorWorkload:
        """
        Return the workload at the specified index.
        A default implementation is provided, assuming that the workloads are
        stored in a list named `workloads` on the class.
        """
        if not hasattr(self, "workloads"):
            raise NotImplementedError
        if index < 0 or index >= len(self.workloads):
            raise IndexError
        return self.workloads[index]

    def record_count_upto(self, index: int) -> int:
        """
        Return the total number of records in the sequence up to and including
        the specified index's workload.
        """
        if index >= self.workload_count():
            raise IndexError
        return sum(self[i].record_count() for i in range(index + 1))
class SingleVectorWorkloadSequence(VectorWorkloadSequence):
    def __init__(self, name: str, workload: VectorWorkload):
        super().__init__(name)
        self.workload = workload

    @staticmethod
    def workload_count() -> int:
        return 1

    def __getitem__(self, index: int) -> VectorWorkload:
        if index != 0:
            raise IndexError
        return self.workload
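

# --- Illustrative usage (not part of this module) ---
# A sketch of how a runner might walk a sequence: wrap a single workload,
# then drive ingest and queries through the iterator API as a single user.
# `ToyWorkload` refers to the hypothetical example class above; the printed
# messages are placeholders, not real vsb output.
if __name__ == "__main__":
    sequence = SingleVectorWorkloadSequence("toy", ToyWorkload("toy"))
    for index in range(sequence.workload_count()):
        workload = sequence[index]
        # Single user (user_id 0 of 1), batches of up to 32 records.
        for namespace, batch in workload.get_record_batch_iter(1, 0, 32):
            print(f"{workload.get_stats_prefix()}ingest batch into namespace '{namespace}'")
        for tenant, request in workload.get_query_iter(1, 0, 32):
            print(f"{workload.get_stats_prefix()}issue query against tenant '{tenant}'")
    print("Records loaded so far:", sequence.record_count_upto(sequence.workload_count() - 1))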