@@ -400,13 +400,6 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
400
400
return expr
401
401
402
402
403
- RVS_IN_JOINT_LOGP_GRAPH_MSG = (
404
- "Random variables detected in the logp graph: %s.\n "
405
- "This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,\n "
406
- "or when not all rvs have a corresponding value variable."
407
- )
408
-
409
-
410
403
def conditional_logp (
411
404
rv_values : dict [TensorVariable , TensorVariable ],
412
405
warn_rvs = None ,
@@ -563,7 +556,11 @@ def conditional_logp(
563
556
if warn_rvs :
564
557
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph (logprobs )
565
558
if rvs_in_logp_expressions :
566
- warnings .warn (RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions , UserWarning )
559
+ warnings .warn (
560
+ f"Random variables detected in the logp graph: { rvs_in_logp_expressions } .\n "
561
+ "This can happen when not all random variables have a corresponding value variable." ,
562
+ UserWarning ,
563
+ )
567
564
568
565
return values_to_logprobs
569
566
@@ -611,7 +608,11 @@ def transformed_conditional_logp(
611
608
612
609
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph (logp_terms_list )
613
610
if rvs_in_logp_expressions :
614
- raise ValueError (RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions )
611
+ raise ValueError (
612
+ f"Random variables detected in the logp graph: { rvs_in_logp_expressions } .\n "
613
+ "This can happen when mixing variables from different models, "
614
+ "or when CustomDist logp or Interval transform functions reference nonlocal variables."
615
+ )
615
616
616
617
return logp_terms_list
617
618
0 commit comments