Skip to content

Commit 0fba15c

Browse files
committed
Allow debug evaling IR logp graphs
1 parent 268e13b commit 0fba15c

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

pymc/logprob/abstract.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -236,13 +236,16 @@ class ValuedRV(Op):
236236
and breaking the dependency of `b` on `a`. The new nodes isolate the graphs between conditioning points.
237237
"""
238238

239+
view_map = {0: [0]}
240+
239241
def make_node(self, rv, value):
240242
assert isinstance(rv, Variable)
241243
assert isinstance(value, Variable)
242244
return Apply(self, [rv, value], [rv.type(name=rv.name)])
243245

244246
def perform(self, node, inputs, out):
245-
raise NotImplementedError("ValuedVar should not be present in the final graph!")
247+
warnings.warn("ValuedVar should not be present in the final graph!")
248+
out[0][0] = inputs[0]
246249

247250
def infer_shape(self, fgraph, node, input_shapes):
248251
return [input_shapes[0]]

pymc/logprob/transform_value.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
1415

1516
from collections.abc import Sequence
1617

@@ -40,7 +41,8 @@ def make_node(self, tran_value: TensorVariable, value: TensorVariable):
4041
return Apply(self, [tran_value, value], [tran_value.type()])
4142

4243
def perform(self, node, inputs, outputs):
43-
raise NotImplementedError("These `Op`s should be removed from graphs used for computation.")
44+
warnings.warn("TransformedValue should not be present in the final graph!")
45+
outputs[0][0] = inputs[0]
4446

4547
def infer_shape(self, fgraph, node, input_shapes):
4648
return [input_shapes[0]]

tests/logprob/test_basic.py

+23
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,26 @@ def test_ir_rewrite_does_not_disconnect_valued_rvs():
436436
logp_b.eval({a_value: np.pi, b_value: np.e}),
437437
stats.norm.logpdf(np.e, np.pi * 8, 1),
438438
)
439+
440+
441+
def test_ir_ops_can_be_evaluated_with_warning():
442+
_eval_values = [None, None]
443+
444+
def my_logp(value, lam):
445+
nonlocal _eval_values
446+
_eval_values[0] = value.eval()
447+
_eval_values[1] = lam.eval({"lam_log__": -1.5})
448+
return value * lam
449+
450+
with pm.Model() as m:
451+
lam = pm.Exponential("lam")
452+
pm.CustomDist("y", lam, logp=my_logp, observed=[0, 1, 2])
453+
454+
with pytest.warns(
455+
UserWarning, match="TransformedValue should not be present in the final graph"
456+
):
457+
with pytest.warns(UserWarning, match="ValuedVar should not be present in the final graph"):
458+
m.logp()
459+
460+
assert _eval_values[0].sum() == 3
461+
assert _eval_values[1] == np.exp(-1.5)

0 commit comments

Comments
 (0)