Skip to content

Commit 4a0cbd1

Browse files
committed
Better guesses for why logp has RVs
1 parent fa43eba commit 4a0cbd1

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

pymc/logprob/basic.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -400,13 +400,6 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
400400
return expr
401401

402402

403-
RVS_IN_JOINT_LOGP_GRAPH_MSG = (
404-
"Random variables detected in the logp graph: %s.\n"
405-
"This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,\n"
406-
"or when not all rvs have a corresponding value variable."
407-
)
408-
409-
410403
def conditional_logp(
411404
rv_values: dict[TensorVariable, TensorVariable],
412405
warn_rvs=None,
@@ -563,7 +556,11 @@ def conditional_logp(
563556
if warn_rvs:
564557
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs)
565558
if rvs_in_logp_expressions:
566-
warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning)
559+
warnings.warn(
560+
f"Random variables detected in the logp graph: {rvs_in_logp_expressions}.\n"
561+
"This can happen when not all random variables have a corresponding value variable.",
562+
UserWarning,
563+
)
567564

568565
return values_to_logprobs
569566

@@ -611,7 +608,11 @@ def transformed_conditional_logp(
611608

612609
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logp_terms_list)
613610
if rvs_in_logp_expressions:
614-
raise ValueError(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions)
611+
raise ValueError(
612+
f"Random variables detected in the logp graph: {rvs_in_logp_expressions}.\n"
613+
"This can happen when mixing variables from different models, "
614+
"or when CustomDist logp or Interval transform functions reference nonlocal variables."
615+
)
615616

616617
return logp_terms_list
617618

0 commit comments

Comments
 (0)