@@ -66,6 +66,7 @@ def make_initial_point_fns_per_chain(
66
66
model ,
67
67
overrides : StartDict | Sequence [StartDict | None ] | None ,
68
68
jitter_rvs : set [TensorVariable ] | None = None ,
69
+ jitter_scale : float = 1.0 ,
69
70
chains : int ,
70
71
) -> list [Callable ]:
71
72
"""Create an initial point function for each chain, as defined by initvals.
@@ -96,6 +97,7 @@ def make_initial_point_fns_per_chain(
96
97
model = model ,
97
98
overrides = overrides ,
98
99
jitter_rvs = jitter_rvs ,
100
+ jitter_scale = jitter_scale ,
99
101
return_transformed = True ,
100
102
)
101
103
] * chains
@@ -104,6 +106,7 @@ def make_initial_point_fns_per_chain(
104
106
make_initial_point_fn (
105
107
model = model ,
106
108
jitter_rvs = jitter_rvs ,
109
+ jitter_scale = jitter_scale ,
107
110
overrides = chain_overrides ,
108
111
return_transformed = True ,
109
112
)
@@ -122,6 +125,7 @@ def make_initial_point_fn(
122
125
model ,
123
126
overrides : StartDict | None = None ,
124
127
jitter_rvs : set [TensorVariable ] | None = None ,
128
+ jitter_scale : float = 1.0 ,
125
129
default_strategy : str = "support_point" ,
126
130
return_transformed : bool = True ,
127
131
) -> Callable :
@@ -150,6 +154,7 @@ def make_initial_point_fn(
150
154
rvs_to_transforms = model .rvs_to_transforms ,
151
155
initval_strategies = initval_strats ,
152
156
jitter_rvs = jitter_rvs ,
157
+ jitter_scale = jitter_scale ,
153
158
default_strategy = default_strategy ,
154
159
return_transformed = return_transformed ,
155
160
)
@@ -188,6 +193,7 @@ def make_initial_point_expression(
188
193
rvs_to_transforms : dict [TensorVariable , Transform ],
189
194
initval_strategies : dict [TensorVariable , np .ndarray | Variable | str | None ],
190
195
jitter_rvs : set [TensorVariable ] | None = None ,
196
+ jitter_scale : float = 1.0 ,
191
197
default_strategy : str = "support_point" ,
192
198
return_transformed : bool = False ,
193
199
) -> list [TensorVariable ]:
@@ -265,7 +271,7 @@ def make_initial_point_expression(
265
271
value = transform .forward (value , * variable .owner .inputs )
266
272
267
273
if variable in jitter_rvs :
268
- jitter = pt .random .uniform (- 1 , 1 , size = value .shape )
274
+ jitter = pt .random .uniform (- jitter_scale , jitter_scale , size = value .shape )
269
275
jitter .name = f"{ variable .name } _jitter"
270
276
value = value + jitter
271
277
0 commit comments