Skip to content

Commit f4b5153

Browse files
committed
Implement proper truncation for prior distributions
Currently, when sampled startpoints are outside the bounds, their value is set to the upper/lower bounds. This may put too much probability mass on the bounds. With these changes, we properly sample from the respective truncated distributions. Closes #330.
1 parent 8456635 commit f4b5153

File tree

4 files changed

+247
-72
lines changed

4 files changed

+247
-72
lines changed

Diff for: doc/example/distributions.ipynb

+19-9
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
},
2424
{
2525
"metadata": {
26-
"collapsed": true
26+
"collapsed": true,
27+
"jupyter": {
28+
"is_executing": true
29+
}
2730
},
2831
"cell_type": "code",
2932
"source": [
@@ -42,7 +45,7 @@
4245
" if ax is None:\n",
4346
" fig, ax = plt.subplots()\n",
4447
"\n",
45-
" sample = prior.sample(10000)\n",
48+
" sample = prior.sample(20_000)\n",
4649
"\n",
4750
" # pdf\n",
4851
" xmin = min(sample.min(), prior.lb_scaled if prior.bounds is not None else sample.min())\n",
@@ -138,11 +141,13 @@
138141
"metadata": {},
139142
"cell_type": "code",
140143
"source": [
144+
"# different, because transformation!=LIN\n",
141145
"plot(Prior(UNIFORM, (0.01, 2), transformation=LOG10))\n",
142146
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LOG10))\n",
143147
"\n",
148+
"# same, because transformation=LIN\n",
144149
"plot(Prior(UNIFORM, (0.01, 2), transformation=LIN))\n",
145-
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))\n"
150+
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))"
146151
],
147152
"id": "5ca940bc24312fc6",
148153
"outputs": [],
@@ -151,15 +156,18 @@
151156
{
152157
"metadata": {},
153158
"cell_type": "markdown",
154-
"source": "To prevent the sampled parameters from exceeding the bounds, the sampled parameters are clipped to the bounds. The bounds are defined in the parameter table. Note that the current implementation does not support sampling from a truncated distribution. Instead, the samples are clipped to the bounds. This may introduce unwanted bias, and thus, should only be used with caution (i.e., the bounds should be chosen wide enough):",
159+
"source": "The given distributions are truncated at the bounds defined in the parameter table:",
155160
"id": "b1a8b17d765db826"
156161
},
157162
{
158163
"metadata": {},
159164
"cell_type": "code",
160165
"source": [
161-
"plot(Prior(NORMAL, (0, 1), bounds=(-4, 4))) # negligible clipping-bias at 4 sigma\n",
162-
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9))) # significant clipping-bias"
166+
"plot(Prior(NORMAL, (0, 1), bounds=(-2, 2)))\n",
167+
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9)))\n",
168+
"plot(Prior(UNIFORM, (1e-8, 1), bounds=(0.1, 0.9), transformation=LOG10))\n",
169+
"plot(Prior(LAPLACE, (0, 1), bounds=(-0.5, 0.5)))\n",
170+
"plot(Prior(PARAMETER_SCALE_UNIFORM, (-3, 1), bounds=(1e-2, 1), transformation=LOG10))\n"
163171
],
164172
"id": "4ac42b1eed759bdd",
165173
"outputs": [],
@@ -175,9 +183,11 @@
175183
"metadata": {},
176184
"cell_type": "code",
177185
"source": [
178-
"plot(Prior(NORMAL, (10, 1), bounds=(6, 14), transformation=\"log10\"))\n",
179-
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**6, 10**14), transformation=\"log10\"))\n",
180-
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))"
186+
"plot(Prior(NORMAL, (10, 1), bounds=(6, 11), transformation=\"log10\"))\n",
187+
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**9, 10**14), transformation=\"log10\"))\n",
188+
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n",
189+
"plot(Prior(LOG_LAPLACE, (1, 0.5), bounds=(0.5, 8)))\n",
190+
"plot(Prior(LOG_NORMAL, (2, 1), bounds=(0.5, 8)))"
181191
],
182192
"id": "581e1ac431860419",
183193
"outputs": [],

Diff for: petab/v1/distributions.py

+157-24
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,65 @@ class Distribution(abc.ABC):
2828
If a float, the distribution is transformed to its corresponding
2929
log distribution with the given base (e.g., Normal -> Log10Normal).
3030
If ``False``, no transformation is applied.
31+
:param trunc: The truncation points (lower, upper) of the distribution
32+
or ``None`` if the distribution is not truncated.
3133
"""
3234

33-
def __init__(self, log: bool | float = False):
35+
def __init__(
36+
self, *, log: bool | float = False, trunc: tuple[float, float] = None
37+
):
3438
if log is True:
3539
log = np.exp(1)
40+
41+
if trunc == (-np.inf, np.inf):
42+
trunc = None
43+
44+
if trunc is not None and trunc[0] > trunc[1]:
45+
raise ValueError(
46+
"The lower truncation limit must be smaller "
47+
"than the upper truncation limit."
48+
)
49+
3650
self._logbase = log
51+
self._trunc = trunc
52+
53+
self._cd_low = None
54+
self._cd_high = None
55+
self._truncation_normalizer = 1
56+
57+
if self._trunc is not None:
58+
try:
59+
# the cumulative probability of the transformed distribution at the
60+
# truncation limits
61+
self._cd_low = self._cdf_transformed_untruncated(
62+
self.trunc_low
63+
)
64+
self._cd_high = self._cdf_transformed_untruncated(
65+
self.trunc_high
66+
)
67+
# normalization factor for the PDF of the transformed
68+
# distribution to account for truncation
69+
self._truncation_normalizer = 1 / (
70+
self._cd_high - self._cd_low
71+
)
72+
except NotImplementedError:
73+
pass
74+
75+
@property
76+
def trunc_low(self) -> float:
77+
"""The lower truncation limit of the transformed distribution."""
78+
return self._trunc[0] if self._trunc else -np.inf
79+
80+
@property
81+
def trunc_high(self) -> float:
82+
"""The upper truncation limit of the transformed distribution."""
83+
return self._trunc[1] if self._trunc else np.inf
3784

38-
def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float:
39-
"""Undo the log transformation.
85+
def _exp(self, x: np.ndarray | float) -> np.ndarray | float:
86+
"""Exponentiate / undo the log transformation, if applicable.
87+
88+
Exponentiate if a log transformation is applied to the distribution.
89+
Otherwise, return the input.
4090
4191
:param x: The sample to transform.
4292
:return: The transformed sample
@@ -45,9 +95,12 @@ def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float:
4595
return x
4696
return self._logbase**x
4797

48-
def _apply_log(self, x: np.ndarray | float) -> np.ndarray | float:
98+
def _log(self, x: np.ndarray | float) -> np.ndarray | float:
4999
"""Apply the log transformation.
50100
101+
Compute the log of x with the specified base if a log transformation
102+
is applied to the distribution. Otherwise, return the input.
103+
51104
:param x: The value to transform.
52105
:return: The transformed value.
53106
"""
@@ -61,12 +114,17 @@ def sample(self, shape=None) -> np.ndarray:
61114
:param shape: The shape of the sample.
62115
:return: A sample from the distribution.
63116
"""
64-
sample = self._sample(shape)
65-
return self._undo_log(sample)
117+
sample = (
118+
self._exp(self._sample(shape))
119+
if self._trunc is None
120+
else self._inverse_transform_sample(shape)
121+
)
122+
123+
return sample
66124

67125
@abc.abstractmethod
68126
def _sample(self, shape=None) -> np.ndarray:
69-
"""Sample from the underlying distribution.
127+
"""Sample from the underlying distribution, accounting for truncation.
70128
71129
:param shape: The shape of the sample.
72130
:return: A sample from the underlying distribution,
@@ -85,7 +143,11 @@ def pdf(self, x):
85143
chain_rule_factor = (
86144
(1 / (x * np.log(self._logbase))) if self._logbase else 1
87145
)
88-
return self._pdf(self._apply_log(x)) * chain_rule_factor
146+
return (
147+
self._pdf(self._log(x))
148+
* chain_rule_factor
149+
* self._truncation_normalizer
150+
)
89151

90152
@abc.abstractmethod
91153
def _pdf(self, x):
@@ -104,13 +166,71 @@ def logbase(self) -> bool | float:
104166
"""
105167
return self._logbase
106168

169+
def cdf(self, x):
170+
"""Cumulative distribution function at x.
171+
172+
:param x: The value at which to evaluate the CDF.
173+
:return: The value of the CDF at ``x``.
174+
"""
175+
return self._cdf_transformed_untruncated(x) - self._cd_low
176+
177+
def _cdf_transformed_untruncated(self, x):
178+
"""Cumulative distribution function of the transformed, but untruncated
179+
distribution at x.
180+
181+
:param x: The value at which to evaluate the CDF.
182+
:return: The value of the CDF at ``x``.
183+
"""
184+
return self._cdf_untransformed_untruncated(self._log(x))
185+
186+
def _cdf_untransformed_untruncated(self, x):
187+
"""Cumulative distribution function of the underlying
188+
(untransformed, untruncated) distribution at x.
189+
190+
:param x: The value at which to evaluate the CDF.
191+
:return: The value of the CDF at ``x``.
192+
"""
193+
raise NotImplementedError
194+
195+
def _ppf_untransformed_untruncated(self, q):
196+
"""Percent point function of the underlying
197+
(untransformed, untruncated) distribution at q.
198+
199+
:param q: The quantile at which to evaluate the PPF.
200+
:return: The value of the PPF at ``q``.
201+
"""
202+
raise NotImplementedError
203+
204+
def _ppf_transformed_untruncated(self, q):
205+
"""Percent point function of the transformed, but untruncated
206+
distribution at q.
207+
208+
:param q: The quantile at which to evaluate the PPF.
209+
:return: The value of the PPF at ``q``.
210+
"""
211+
return self._exp(self._ppf_untransformed_untruncated(q))
212+
213+
def _inverse_transform_sample(self, shape):
214+
"""Generate an inverse transform sample from the transformed and
215+
truncated distribution.
216+
217+
:param shape: The shape of the sample.
218+
:return: The sample.
219+
"""
220+
uniform_sample = np.random.uniform(
221+
low=self._cd_low, high=self._cd_high, size=shape
222+
)
223+
return self._ppf_transformed_untruncated(uniform_sample)
224+
107225

108226
class Normal(Distribution):
109227
"""A (log-)normal distribution.
110228
111229
:param loc: The location parameter of the distribution.
112230
:param scale: The scale parameter of the distribution.
113-
:param truncation: The truncation limits of the distribution.
231+
:param trunc: The truncation limits of the distribution.
232+
``None`` if the distribution is not truncated. The truncation limits
233+
are the truncation limits of the transformed distribution.
114234
:param log: If ``True``, the distribution is transformed to a log-normal
115235
distribution. If a float, the distribution is transformed to a
116236
log-normal distribution with the given base.
@@ -124,19 +244,15 @@ def __init__(
124244
self,
125245
loc: float,
126246
scale: float,
127-
truncation: tuple[float, float] | None = None,
247+
trunc: tuple[float, float] | None = None,
128248
log: bool | float = False,
129249
):
130-
super().__init__(log=log)
131250
self._loc = loc
132251
self._scale = scale
133-
self._truncation = truncation
134-
135-
if truncation is not None:
136-
raise NotImplementedError("Truncation is not yet implemented.")
252+
super().__init__(log=log, trunc=trunc)
137253

138254
def __repr__(self):
139-
trunc = f", truncation={self._truncation}" if self._truncation else ""
255+
trunc = f", trunc={self._trunc}" if self._trunc else ""
140256
log = f", log={self._logbase}" if self._logbase else ""
141257
return f"Normal(loc={self._loc}, scale={self._scale}{trunc}{log})"
142258

@@ -146,6 +262,12 @@ def _sample(self, shape=None):
146262
def _pdf(self, x):
147263
return norm.pdf(x, loc=self._loc, scale=self._scale)
148264

265+
def _cdf_untransformed_untruncated(self, x):
266+
return norm.cdf(x, loc=self._loc, scale=self._scale)
267+
268+
def _ppf_untransformed_untruncated(self, q):
269+
return norm.ppf(q, loc=self._loc, scale=self._scale)
270+
149271
@property
150272
def loc(self):
151273
"""The location parameter of the underlying distribution."""
@@ -177,9 +299,9 @@ def __init__(
177299
*,
178300
log: bool | float = False,
179301
):
180-
super().__init__(log=log)
181302
self._low = low
182303
self._high = high
304+
super().__init__(log=log)
183305

184306
def __repr__(self):
185307
log = f", log={self._logbase}" if self._logbase else ""
@@ -191,13 +313,21 @@ def _sample(self, shape=None):
191313
def _pdf(self, x):
192314
return uniform.pdf(x, loc=self._low, scale=self._high - self._low)
193315

316+
def _cdf_untransformed_untruncated(self, x):
317+
return uniform.cdf(x, loc=self._low, scale=self._high - self._low)
318+
319+
def _ppf_untransformed_untruncated(self, q):
320+
return uniform.ppf(q, loc=self._low, scale=self._high - self._low)
321+
194322

195323
class Laplace(Distribution):
196324
"""A (log-)Laplace distribution.
197325
198326
:param loc: The location parameter of the distribution.
199327
:param scale: The scale parameter of the distribution.
200-
:param truncation: The truncation limits of the distribution.
328+
:param trunc: The truncation limits of the distribution.
329+
``None`` if the distribution is not truncated. The truncation limits
330+
are the truncation limits of the transformed distribution.
201331
:param log: If ``True``, the distribution is transformed to a log-Laplace
202332
distribution. If a float, the distribution is transformed to a
203333
log-Laplace distribution with the given base.
@@ -211,18 +341,15 @@ def __init__(
211341
self,
212342
loc: float,
213343
scale: float,
214-
truncation: tuple[float, float] | None = None,
344+
trunc: tuple[float, float] | None = None,
215345
log: bool | float = False,
216346
):
217-
super().__init__(log=log)
218347
self._loc = loc
219348
self._scale = scale
220-
self._truncation = truncation
221-
if truncation is not None:
222-
raise NotImplementedError("Truncation is not yet implemented.")
349+
super().__init__(log=log, trunc=trunc)
223350

224351
def __repr__(self):
225-
trunc = f", truncation={self._truncation}" if self._truncation else ""
352+
trunc = f", trunc={self._trunc}" if self._trunc else ""
226353
log = f", log={self._logbase}" if self._logbase else ""
227354
return f"Laplace(loc={self._loc}, scale={self._scale}{trunc}{log})"
228355

@@ -232,6 +359,12 @@ def _sample(self, shape=None):
232359
def _pdf(self, x):
233360
return laplace.pdf(x, loc=self._loc, scale=self._scale)
234361

362+
def _cdf_untransformed_untruncated(self, x):
363+
return laplace.cdf(x, loc=self._loc, scale=self._scale)
364+
365+
def _ppf_untransformed_untruncated(self, q):
366+
return laplace.ppf(q, loc=self._loc, scale=self._scale)
367+
235368
@property
236369
def loc(self):
237370
"""The location parameter of the underlying distribution."""

0 commit comments

Comments
 (0)