Skip to content

Commit e0048bf

Browse files
committed
ENH: Add jitter_scale parameter for initial point generation (pymc-devs#7555)
1 parent 671d704 commit e0048bf

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

pymc/initial_point.py

+7-1
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,7 @@ def make_initial_point_fns_per_chain(
6666
model,
6767
overrides: StartDict | Sequence[StartDict | None] | None,
6868
jitter_rvs: set[TensorVariable] | None = None,
69+
jitter_scale: float = 1.0,
6970
chains: int,
7071
) -> list[Callable]:
7172
"""Create an initial point function for each chain, as defined by initvals.
@@ -96,6 +97,7 @@ def make_initial_point_fns_per_chain(
9697
model=model,
9798
overrides=overrides,
9899
jitter_rvs=jitter_rvs,
100+
jitter_scale=jitter_scale,
99101
return_transformed=True,
100102
)
101103
] * chains
@@ -104,6 +106,7 @@ def make_initial_point_fns_per_chain(
104106
make_initial_point_fn(
105107
model=model,
106108
jitter_rvs=jitter_rvs,
109+
jitter_scale=jitter_scale,
107110
overrides=chain_overrides,
108111
return_transformed=True,
109112
)
@@ -122,6 +125,7 @@ def make_initial_point_fn(
122125
model,
123126
overrides: StartDict | None = None,
124127
jitter_rvs: set[TensorVariable] | None = None,
128+
jitter_scale: float = 1.0,
125129
default_strategy: str = "support_point",
126130
return_transformed: bool = True,
127131
) -> Callable:
@@ -150,6 +154,7 @@ def make_initial_point_fn(
150154
rvs_to_transforms=model.rvs_to_transforms,
151155
initval_strategies=initval_strats,
152156
jitter_rvs=jitter_rvs,
157+
jitter_scale=jitter_scale,
153158
default_strategy=default_strategy,
154159
return_transformed=return_transformed,
155160
)
@@ -188,6 +193,7 @@ def make_initial_point_expression(
188193
rvs_to_transforms: dict[TensorVariable, Transform],
189194
initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None],
190195
jitter_rvs: set[TensorVariable] | None = None,
196+
jitter_scale: float = 1.0,
191197
default_strategy: str = "support_point",
192198
return_transformed: bool = False,
193199
) -> list[TensorVariable]:
@@ -265,7 +271,7 @@ def make_initial_point_expression(
265271
value = transform.forward(value, *variable.owner.inputs)
266272

267273
if variable in jitter_rvs:
268-
jitter = pt.random.uniform(-1, 1, size=value.shape)
274+
jitter = pt.random.uniform(-jitter_scale, jitter_scale, size=value.shape)
269275
jitter.name = f"{variable.name}_jitter"
270276
value = value + jitter
271277

pymc/sampling/mcmc.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -1423,10 +1423,11 @@ def _init_jitter(
14231423
initvals: StartDict | Sequence[StartDict | None] | None,
14241424
seeds: Sequence[int] | np.ndarray,
14251425
jitter: bool,
1426+
jitter_scale: float,
14261427
jitter_max_retries: int,
14271428
logp_dlogp_func=None,
14281429
) -> list[PointType]:
1429-
"""Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
1430+
"""Apply a uniform jitter in [-jitter_scale, jitter_scale] to the test value as starting point in each chain.
14301431
14311432
``model.check_start_vals`` is used to test whether the jittered starting
14321433
values produce a finite log probability. Invalid values are resampled
@@ -1449,6 +1450,7 @@ def _init_jitter(
14491450
model=model,
14501451
overrides=initvals,
14511452
jitter_rvs=set(model.free_RVs) if jitter else set(),
1453+
jitter_scale=jitter_scale if jitter else 1.0,
14521454
chains=len(seeds),
14531455
)
14541456

tests/test_initial_point.py

+28
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,34 @@ def test_adds_jitter(self):
152152
assert fn(0) == fn(0)
153153
assert fn(0) != fn(1)
154154

155+
def test_jitter_scale(self):
    """Check that ``jitter_scale`` controls the width of the initial-point jitter.

    One initial-point function is built per scale in ``jitter_scale_tests``
    and evaluated for ``n_draws`` different seeds.  After dividing the draws
    by their scale, three properties are asserted for every scale:

    * the mean is close to zero (the jitter is symmetric),
    * no draw lies outside ``[-1, 1]`` (the jitter never exceeds the scale),
    * the largest magnitude is close to 1 (the jitter really spans the
      requested range, so the test fails if ``jitter_scale`` is ignored).

    NOTE(review): the two range checks assume the un-jittered ``A_log__``
    value is exactly 0, as the zero-mean check already suggests -- confirm.
    """
    with pm.Model() as pmodel:
        A = pm.HalfFlat("A", initval="support_point")

    jitter_scale_tests = np.array([1.0, 2.0, 5.0])
    fns = [
        make_initial_point_fn(
            model=pmodel,
            jitter_rvs=set(pmodel.free_RVs),
            jitter_scale=jitter_scale,
            return_transformed=True,
        )
        for jitter_scale in jitter_scale_tests
    ]

    n_draws = 1000
    jitter_samples = np.empty((n_draws, len(fns)))
    for j, fn in enumerate(fns):
        # Give every scale its own disjoint seed range; reusing the same seeds
        # would make the columns exact multiples of each other instead of
        # independent samples.
        start = j * n_draws
        end = start + n_draws
        jitter_samples[:, j] = np.asarray([fn(i)["A_log__"] for i in range(start, end)])

    # Dividing by the scale should map every column onto uniform(-1, 1).
    init_standardised = jitter_samples / jitter_scale_tests
    standardised_mean = np.mean(init_standardised, axis=0)
    standardised_max = np.max(np.abs(init_standardised), axis=0)

    assert np.all((-0.05 < standardised_mean) & (standardised_mean < 0.05))
    # A zero mean alone would also hold if jitter_scale were ignored, so the
    # range of the draws is checked as well.
    assert np.all(standardised_max <= 1.0)
    assert np.all(standardised_max > 0.95)
182+
155183
def test_respects_overrides(self):
156184
with pm.Model() as pmodel:
157185
A = pm.Flat("A", initval="support_point")

0 commit comments

Comments
 (0)