Skip to content

Commit ac24ac4

Browse files
committed
Docstring, more tests
1 parent bdca9e1 commit ac24ac4

File tree

3 files changed

+91
-64
lines changed

3 files changed

+91
-64
lines changed

pipeline_dp/dataset_histograms/computing_histograms.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Functions for computing dataset histograms in pipelines."""
15-
import bisect
1615
import operator
1716
from typing import Iterable, List, Tuple
1817

19-
import numpy as np
20-
2118
import pipeline_dp
2219
from pipeline_dp import pipeline_backend
2320
from pipeline_dp.dataset_histograms import histograms as hist
2421
from pipeline_dp.dataset_histograms import sum_histogram_computation
2522

23+
# Functions _compute_* computes histogram for counts. TODO: move them to
24+
# a separate file, similar to sum_histogram_computation.py.
25+
2626

2727
def _to_bin_lower_upper_logarithmic(value: int) -> Tuple[int, int]:
2828
"""Finds the lower and upper bounds of the histogram bin which contains

pipeline_dp/dataset_histograms/sum_histogram_computation.py

+75-61
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,20 @@
1313
# limitations under the License.
1414
"""Functions for computing linf_sum and sum_per_partition histograms."""
1515

16-
# TODO: theory on sum histograms
16+
# This files contains histograms which is useful for analysis of DP SUM
17+
# aggregation utility.
18+
# The general structure of these histogram is the following:
19+
# The input is collection of values X = (x0, ... x_n).
20+
# Computations is the following:
21+
# 1. Find min_x = min(X), max_x = max(X) of X
22+
# 2. Split the segment [min_x, max_x] in NUMBER_OF_BUCKETS_SUM_HISTOGRAM = 10000
23+
# equals size intervals [l_i, r_i), the last interval includes both endpoints.
24+
# 3. Each bin of the histogram correspoinds to interval [l_i, r_i), and contains
25+
# different statistics of numbers from X, which belogins to this intervals, like
26+
# count, sum, max etc.
27+
#
28+
# For generating bucket class LowerUpperGenerator is used, which takes
29+
# min, max, number of buckets and returns bucket for each number.
1730

1831
import copy
1932
import operator
@@ -28,15 +41,17 @@
2841
class LowerUpperGenerator:
2942
"""Generates lower-upper bounds for FrequencyBin
3043
31-
Attributes:
32-
left, right: bounds on interval on which we compute histogram.
33-
num_buckets: number of buckets on [left, right]. Buckets have the same length.
44+
Attributes:
45+
left, right: bounds on interval on which we compute histogram.
46+
num_buckets: number of buckets on [left, right]. Buckets have the same
47+
length.
3448
35-
i-th bucket corresponds to numbers from
49+
For general context see file docstring.
50+
i-th bucket corresponds to numbers from
3651
[left+i*bucket_len, right+(i+1)*bucket_len), where
3752
bucket_len = (right-left)/num_buckets.
38-
The last bucket includes right end-point.
39-
"""
53+
The last bucket includes right end-point.
54+
"""
4055

4156
def __init__(
4257
self,
@@ -51,8 +66,7 @@ def __init__(
5166
self.bucket_len = (right - left) / num_buckets
5267

5368
def get_bucket_index(self, x: float) -> int:
54-
assert self.left <= x <= self.right # only for debug
55-
if x == self.right: # last bucket includes both ends.
69+
if x >= self.right: # last bucket includes both ends.
5670
return self.num_buckets - 1
5771
if x <= self.left:
5872
return 0
@@ -79,17 +93,17 @@ def _compute_frequency_histogram_per_key(
7993
):
8094
"""Computes histogram of element frequencies in collection.
8195
82-
This is a helper function for computing sum histograms per key.
96+
This is a helper function for computing sum histograms per key.
8397
84-
Args:
98+
Args:
8599
col: collection of (key, value:float)
86100
backend: PipelineBackend to run operations on the collection.
87101
name: name which is assigned to the computed histogram.
88102
num_buckets: the number of buckets in the output histogram.
89103
90-
Returns:
104+
Returns:
91105
1 element collection which contains list [hist.Histogram], sorted by key.
92-
"""
106+
"""
93107
col = backend.to_multi_transformable_collection(col)
94108

95109
bucket_generators = _create_bucket_generators_per_key(
@@ -149,16 +163,16 @@ def sort_histograms_by_index(index_histogram):
149163

150164
def _create_bucket_generators_per_key(
151165
col, number_of_buckets: int, backend: pipeline_backend.PipelineBackend):
152-
"""Create bucket generators per key.
166+
"""Creates bucket generators per key.
153167
154-
Args:
155-
col: collection of (key, value)
156-
backend: PipelineBackend to run operations on the collection.
157-
num_buckets: the number of buckets in the output histogram.
168+
Args:
169+
col: collection of (key, value)
170+
backend: PipelineBackend to run operations on the collection.
171+
num_buckets: the number of buckets in the output histogram.
158172
159-
Returns:
160-
1 element collection with dictionary {key: LowerUpperGenerator}.
161-
"""
173+
Returns:
174+
1 element collection with dictionary {key: LowerUpperGenerator}.
175+
"""
162176
col = pipeline_functions.min_max_per_key(
163177
backend, col, "Min and max value per value column")
164178
# (index, (min, max))
@@ -178,20 +192,20 @@ def create_generators(index_min_max: List[Tuple[int, Tuple[float, float]]]):
178192

179193

180194
def _flat_values(col, backend: pipeline_backend.PipelineBackend):
181-
"""Unnest values in (key, value) collection.
195+
"""Unnests values in (key, value) collection.
182196
183-
Args:
184-
col: collection of (key, value) or (key, [value])
185-
backend: PipelineBackend to run operations on the collection.
197+
Args:
198+
col: collection of (key, value) or (key, [value])
199+
backend: PipelineBackend to run operations on the collection.
186200
187-
Transform each element:
201+
Transform each element:
188202
(key, value: float) -> ((0, key), value)
189203
(key, value: list[float]) -> [((0, key), value[0]), ((1, key), value[1])...]
190-
and then unnest them.
204+
and then unnest them.
191205
192-
Return:
193-
Collection of ((index, key), value).
194-
"""
206+
Return:
207+
Collection of ((index, key), value).
208+
"""
195209

196210
def flat_values(key_values):
197211
key, values = key_values
@@ -208,19 +222,19 @@ def _compute_linf_sum_contributions_histogram(
208222
col, backend: pipeline_backend.PipelineBackend):
209223
"""Computes histogram of per partition privacy id contributions.
210224
211-
This histogram contains: the number of (privacy id, partition_key)-pairs
212-
which have sum of values X_1, X_2, ..., X_n, where X_1 = min_sum,
213-
X_n = one before max sum and n is equal to
214-
NUMBER_OF_BUCKETS_SUM_HISTOGRAM.
225+
This histogram contains: the number of (privacy id, partition_key)-pairs
226+
which have sum of values X_1, X_2, ..., X_n, where X_1 = min_sum,
227+
X_n = one before max sum and n is equal to
228+
NUMBER_OF_BUCKETS_SUM_HISTOGRAM.
215229
216-
Args:
217-
col: collection with elements ((privacy_id, partition_key), value(s)).
218-
where value can be 1 float or tuple of floats (in case of many columns)
219-
backend: PipelineBackend to run operations on the collection.
230+
Args:
231+
col: collection with elements ((privacy_id, partition_key), value(s)).
232+
where value can be 1 float or tuple of floats (in case of many columns)
233+
backend: PipelineBackend to run operations on the collection.
220234
221-
Returns:
222-
1 element collection, which contains the computed hist.Histogram.
223-
"""
235+
Returns:
236+
1 element collection, which contains the computed hist.Histogram.
237+
"""
224238
col = _flat_values(col, backend)
225239
# ((index_value, (pid, pk)), value).
226240
col = backend.sum_per_key(
@@ -239,16 +253,16 @@ def _compute_partition_sum_histogram(col,
239253
backend: pipeline_backend.PipelineBackend):
240254
"""Computes histogram of sum per partition.
241255
242-
This histogram contains: the number of partition_keys which have sum of
243-
values X_1, X_2, ..., X_n, where X_1 = min_sum, X_n = one before max sum and
244-
n is equal to NUMBER_OF_BUCKETS_SUM_HISTOGRAM.
256+
This histogram contains: the number of partition_keys which have sum of
257+
values X_1, X_2, ..., X_n, where X_1 = min_sum, X_n = one before max sum and
258+
n is equal to NUMBER_OF_BUCKETS_SUM_HISTOGRAM.
245259
246-
Args:
260+
Args:
247261
col: collection with elements ((privacy_id, partition_key), value).
248262
backend: PipelineBackend to run operations on the collection.
249-
Returns:
250-
1 element collection, which contains the computed hist.Histogram.
251-
"""
263+
Returns:
264+
1 element collection, which contains the computed hist.Histogram.
265+
"""
252266

253267
col = backend.map_tuple(col, lambda pid_pk, value: (pid_pk[1], value),
254268
"Drop privacy id")
@@ -269,18 +283,18 @@ def _compute_linf_sum_contributions_histogram_on_preaggregated_data(
269283
col, backend: pipeline_backend.PipelineBackend):
270284
"""Computes histogram of per partition privacy id contributions.
271285
272-
This histogram contains: the number of (privacy id, partition_key)-pairs
273-
which have sum of values X_1, X_2, ..., X_n, where X_1 = min_sum,
274-
X_n = one before max sum and n is equal to
275-
NUMBER_OF_BUCKETS_SUM_HISTOGRAM.
286+
This histogram contains: the number of (privacy id, partition_key)-pairs
287+
which have sum of values X_1, X_2, ..., X_n, where X_1 = min_sum,
288+
X_n = one before max sum and n is equal to
289+
NUMBER_OF_BUCKETS_SUM_HISTOGRAM.
276290
277-
Args:
291+
Args:
278292
col: collection with a pre-aggregated dataset, each element is
279293
(partition_key, (count, sum, n_partitions, n_contributions)).
280294
backend: PipelineBackend to run operations on the collection.
281-
Returns:
295+
Returns:
282296
1 element collection, which contains the computed histograms.Histogram.
283-
"""
297+
"""
284298
col = backend.map_tuple(
285299
col,
286300
lambda _, x:
@@ -304,17 +318,17 @@ def _compute_partition_sum_histogram_on_preaggregated_data(
304318
col, backend: pipeline_backend.PipelineBackend):
305319
"""Computes histogram of counts per partition.
306320
307-
This histogram contains: the number of partition_keys which have sum of
308-
values X_1, X_2, ..., X_n, where X_1 = min_sum, X_n = one before max sum and
309-
n is equal to NUMBER_OF_BUCKETS_SUM_HISTOGRAM.
321+
This histogram contains: the number of partition_keys which have sum of
322+
values X_1, X_2, ..., X_n, where X_1 = min_sum, X_n = one before max sum and
323+
n is equal to NUMBER_OF_BUCKETS_SUM_HISTOGRAM.
310324
311-
Args:
325+
Args:
312326
col: collection with a pre-aggregated dataset, each element is
313327
(partition_key, (count, sum, n_partitions, n_contributions)).
314328
backend: PipelineBackend to run operations on the collection.
315-
Returns:
329+
Returns:
316330
1 element collection, which contains the computed histograms.Histogram.
317-
"""
331+
"""
318332
col = backend.map_values(
319333
col,
320334
lambda x: x[1], # x is (count, sum, n_partitions, n_contributions)

tests/dataset_histograms/sum_histogram_computation_test.py

+13
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,19 @@
2222
from analysis import pre_aggregation
2323

2424

25+
class LowerUpperGeneratorTest(parameterized.TestCase):
26+
27+
def test(self):
28+
g = sum_histogram_computation.LowerUpperGenerator(0, 10, num_buckets=20)
29+
self.assertEqual(g.bucket_len, 0.5)
30+
self.assertEqual(g.get_bucket_index(-1), 0)
31+
self.assertEqual(g.get_bucket_index(0), 0)
32+
self.assertEqual(g.get_bucket_index(0.5), 1)
33+
self.assertEqual(g.get_bucket_index(5.1), 10)
34+
self.assertEqual(g.get_bucket_index(10), 19)
35+
self.assertEqual(g.get_bucket_index(11), 19)
36+
37+
2538
class SumHistogramComputationTest(parameterized.TestCase):
2639

2740
@parameterized.product(

0 commit comments

Comments
 (0)