@@ -28,15 +28,65 @@ class Distribution(abc.ABC):
28
28
If a float, the distribution is transformed to its corresponding
29
29
log distribution with the given base (e.g., Normal -> Log10Normal).
30
30
If ``False``, no transformation is applied.
31
+ :param trunc: The truncation points (lower, upper) of the distribution
32
+ or ``None`` if the distribution is not truncated.
31
33
"""
32
34
33
- def __init__ (self , log : bool | float = False ):
35
+ def __init__ (
36
+ self , * , log : bool | float = False , trunc : tuple [float , float ] = None
37
+ ):
34
38
if log is True :
35
39
log = np .exp (1 )
40
+
41
+ if trunc == (- np .inf , np .inf ):
42
+ trunc = None
43
+
44
+ if trunc is not None and trunc [0 ] > trunc [1 ]:
45
+ raise ValueError (
46
+ "The lower truncation limit must be smaller "
47
+ "than the upper truncation limit."
48
+ )
49
+
36
50
self ._logbase = log
51
+ self ._trunc = trunc
52
+
53
+ self ._cd_low = None
54
+ self ._cd_high = None
55
+ self ._truncation_normalizer = 1
56
+
57
+ if self ._trunc is not None :
58
+ try :
59
+ # the cumulative density of the transformed distribution at the
60
+ # truncation limits
61
+ self ._cd_low = self ._cdf_transformed_untruncated (
62
+ self .trunc_low
63
+ )
64
+ self ._cd_high = self ._cdf_transformed_untruncated (
65
+ self .trunc_high
66
+ )
67
+ # normalization factor for the PDF of the transformed
68
+ # distribution to account for truncation
69
+ self ._truncation_normalizer = 1 / (
70
+ self ._cd_high - self ._cd_low
71
+ )
72
+ except NotImplementedError :
73
+ pass
74
+
75
+ @property
76
+ def trunc_low (self ) -> float :
77
+ """The lower truncation limit of the transformed distribution."""
78
+ return self ._trunc [0 ] if self ._trunc else - np .inf
79
+
80
+ @property
81
+ def trunc_high (self ) -> float :
82
+ """The upper truncation limit of the transformed distribution."""
83
+ return self ._trunc [1 ] if self ._trunc else np .inf
37
84
38
- def _undo_log (self , x : np .ndarray | float ) -> np .ndarray | float :
39
- """Undo the log transformation.
85
+ def _exp (self , x : np .ndarray | float ) -> np .ndarray | float :
86
+ """Exponentiate / undo the log transformation according.
87
+
88
+ Exponentiate if a log transformation is applied to the distribution.
89
+ Otherwise, return the input.
40
90
41
91
:param x: The sample to transform.
42
92
:return: The transformed sample
@@ -45,9 +95,12 @@ def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float:
45
95
return x
46
96
return self ._logbase ** x
47
97
48
- def _apply_log (self , x : np .ndarray | float ) -> np .ndarray | float :
98
+ def _log (self , x : np .ndarray | float ) -> np .ndarray | float :
49
99
"""Apply the log transformation.
50
100
101
+ Compute the log of x with the specified base if a log transformation
102
+ is applied to the distribution. Otherwise, return the input.
103
+
51
104
:param x: The value to transform.
52
105
:return: The transformed value.
53
106
"""
@@ -61,12 +114,17 @@ def sample(self, shape=None) -> np.ndarray:
61
114
:param shape: The shape of the sample.
62
115
:return: A sample from the distribution.
63
116
"""
64
- sample = self ._sample (shape )
65
- return self ._undo_log (sample )
117
+ sample = (
118
+ self ._exp (self ._sample (shape ))
119
+ if self ._trunc is None
120
+ else self ._inverse_transform_sample (shape )
121
+ )
122
+
123
+ return sample
66
124
67
125
@abc .abstractmethod
68
126
def _sample (self , shape = None ) -> np .ndarray :
69
- """Sample from the underlying distribution.
127
+ """Sample from the underlying distribution, accounting for truncation .
70
128
71
129
:param shape: The shape of the sample.
72
130
:return: A sample from the underlying distribution,
@@ -85,7 +143,11 @@ def pdf(self, x):
85
143
chain_rule_factor = (
86
144
(1 / (x * np .log (self ._logbase ))) if self ._logbase else 1
87
145
)
88
- return self ._pdf (self ._apply_log (x )) * chain_rule_factor
146
+ return (
147
+ self ._pdf (self ._log (x ))
148
+ * chain_rule_factor
149
+ * self ._truncation_normalizer
150
+ )
89
151
90
152
@abc .abstractmethod
91
153
def _pdf (self , x ):
@@ -104,13 +166,71 @@ def logbase(self) -> bool | float:
104
166
"""
105
167
return self ._logbase
106
168
169
+ def cdf (self , x ):
170
+ """Cumulative distribution function at x.
171
+
172
+ :param x: The value at which to evaluate the CDF.
173
+ :return: The value of the CDF at ``x``.
174
+ """
175
+ return self ._cdf_transformed_untruncated (x ) - self ._cd_low
176
+
177
+ def _cdf_transformed_untruncated (self , x ):
178
+ """Cumulative distribution function of the transformed, but untruncated
179
+ distribution at x.
180
+
181
+ :param x: The value at which to evaluate the CDF.
182
+ :return: The value of the CDF at ``x``.
183
+ """
184
+ return self ._cdf_untransformed_untruncated (self ._log (x ))
185
+
186
+ def _cdf_untransformed_untruncated (self , x ):
187
+ """Cumulative distribution function of the underlying
188
+ (untransformed, untruncated) distribution at x.
189
+
190
+ :param x: The value at which to evaluate the CDF.
191
+ :return: The value of the CDF at ``x``.
192
+ """
193
+ raise NotImplementedError
194
+
195
+ def _ppf_untransformed_untruncated (self , q ):
196
+ """Percent point function of the underlying
197
+ (untransformed, untruncated) distribution at q.
198
+
199
+ :param q: The quantile at which to evaluate the PPF.
200
+ :return: The value of the PPF at ``q``.
201
+ """
202
+ raise NotImplementedError
203
+
204
+ def _ppf_transformed_untruncated (self , q ):
205
+ """Percent point function of the transformed, but untruncated
206
+ distribution at q.
207
+
208
+ :param q: The quantile at which to evaluate the PPF.
209
+ :return: The value of the PPF at ``q``.
210
+ """
211
+ return self ._exp (self ._ppf_untransformed_untruncated (q ))
212
+
213
+ def _inverse_transform_sample (self , shape ):
214
+ """Generate an inverse transform sample from the transformed and
215
+ truncated distribution.
216
+
217
+ :param shape: The shape of the sample.
218
+ :return: The sample.
219
+ """
220
+ uniform_sample = np .random .uniform (
221
+ low = self ._cd_low , high = self ._cd_high , size = shape
222
+ )
223
+ return self ._ppf_transformed_untruncated (uniform_sample )
224
+
107
225
108
226
class Normal (Distribution ):
109
227
"""A (log-)normal distribution.
110
228
111
229
:param loc: The location parameter of the distribution.
112
230
:param scale: The scale parameter of the distribution.
113
- :param truncation: The truncation limits of the distribution.
231
+ :param trunc: The truncation limits of the distribution.
232
+ ``None`` if the distribution is not truncated. The truncation limits
233
+ are the truncation limits of the transformed distribution.
114
234
:param log: If ``True``, the distribution is transformed to a log-normal
115
235
distribution. If a float, the distribution is transformed to a
116
236
log-normal distribution with the given base.
@@ -124,19 +244,15 @@ def __init__(
124
244
self ,
125
245
loc : float ,
126
246
scale : float ,
127
- truncation : tuple [float , float ] | None = None ,
247
+ trunc : tuple [float , float ] | None = None ,
128
248
log : bool | float = False ,
129
249
):
130
- super ().__init__ (log = log )
131
250
self ._loc = loc
132
251
self ._scale = scale
133
- self ._truncation = truncation
134
-
135
- if truncation is not None :
136
- raise NotImplementedError ("Truncation is not yet implemented." )
252
+ super ().__init__ (log = log , trunc = trunc )
137
253
138
254
def __repr__ (self ):
139
- trunc = f", truncation ={ self ._truncation } " if self ._truncation else ""
255
+ trunc = f", trunc ={ self ._trunc } " if self ._trunc else ""
140
256
log = f", log={ self ._logbase } " if self ._logbase else ""
141
257
return f"Normal(loc={ self ._loc } , scale={ self ._scale } { trunc } { log } )"
142
258
@@ -146,6 +262,12 @@ def _sample(self, shape=None):
146
262
def _pdf (self , x ):
147
263
return norm .pdf (x , loc = self ._loc , scale = self ._scale )
148
264
265
+ def _cdf_untransformed_untruncated (self , x ):
266
+ return norm .cdf (x , loc = self ._loc , scale = self ._scale )
267
+
268
+ def _ppf_untransformed_untruncated (self , q ):
269
+ return norm .ppf (q , loc = self ._loc , scale = self ._scale )
270
+
149
271
@property
150
272
def loc (self ):
151
273
"""The location parameter of the underlying distribution."""
@@ -177,9 +299,9 @@ def __init__(
177
299
* ,
178
300
log : bool | float = False ,
179
301
):
180
- super ().__init__ (log = log )
181
302
self ._low = low
182
303
self ._high = high
304
+ super ().__init__ (log = log )
183
305
184
306
def __repr__ (self ):
185
307
log = f", log={ self ._logbase } " if self ._logbase else ""
@@ -191,13 +313,21 @@ def _sample(self, shape=None):
191
313
def _pdf (self , x ):
192
314
return uniform .pdf (x , loc = self ._low , scale = self ._high - self ._low )
193
315
316
+ def _cdf_untransformed_untruncated (self , x ):
317
+ return uniform .cdf (x , loc = self ._low , scale = self ._high - self ._low )
318
+
319
+ def _ppf_untransformed_untruncated (self , q ):
320
+ return uniform .ppf (q , loc = self ._low , scale = self ._high - self ._low )
321
+
194
322
195
323
class Laplace (Distribution ):
196
324
"""A (log-)Laplace distribution.
197
325
198
326
:param loc: The location parameter of the distribution.
199
327
:param scale: The scale parameter of the distribution.
200
- :param truncation: The truncation limits of the distribution.
328
+ :param trunc: The truncation limits of the distribution.
329
+ ``None`` if the distribution is not truncated. The truncation limits
330
+ are the truncation limits of the transformed distribution.
201
331
:param log: If ``True``, the distribution is transformed to a log-Laplace
202
332
distribution. If a float, the distribution is transformed to a
203
333
log-Laplace distribution with the given base.
@@ -211,18 +341,15 @@ def __init__(
211
341
self ,
212
342
loc : float ,
213
343
scale : float ,
214
- truncation : tuple [float , float ] | None = None ,
344
+ trunc : tuple [float , float ] | None = None ,
215
345
log : bool | float = False ,
216
346
):
217
- super ().__init__ (log = log )
218
347
self ._loc = loc
219
348
self ._scale = scale
220
- self ._truncation = truncation
221
- if truncation is not None :
222
- raise NotImplementedError ("Truncation is not yet implemented." )
349
+ super ().__init__ (log = log , trunc = trunc )
223
350
224
351
def __repr__ (self ):
225
- trunc = f", truncation ={ self ._truncation } " if self ._truncation else ""
352
+ trunc = f", trunc ={ self ._trunc } " if self ._trunc else ""
226
353
log = f", log={ self ._logbase } " if self ._logbase else ""
227
354
return f"Laplace(loc={ self ._loc } , scale={ self ._scale } { trunc } { log } )"
228
355
@@ -232,6 +359,12 @@ def _sample(self, shape=None):
232
359
def _pdf (self , x ):
233
360
return laplace .pdf (x , loc = self ._loc , scale = self ._scale )
234
361
362
+ def _cdf_untransformed_untruncated (self , x ):
363
+ return laplace .cdf (x , loc = self ._loc , scale = self ._scale )
364
+
365
+ def _ppf_untransformed_untruncated (self , q ):
366
+ return laplace .ppf (q , loc = self ._loc , scale = self ._scale )
367
+
235
368
@property
236
369
def loc (self ):
237
370
"""The location parameter of the underlying distribution."""
0 commit comments