Skip to content

Commit

Permalink
Merge pull request #3112 from stan-dev/fix/3111-wiener_lpdf-expressio…
Browse files Browse the repository at this point in the history
…n-tests

Fix new wiener_lpdfs evaluating arguments twice
  • Loading branch information
WardBrian authored Oct 7, 2024
2 parents 9c7c3ff + 9e7af77 commit 2fdd3ed
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 17 deletions.
1 change: 1 addition & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ doxygen:
clean:
@echo ' removing generated test files'
@$(RM) $(wildcard test/prob/generate_tests$(EXE))
@$(RM) $(EXPRESSION_TESTS) $(call findfiles,test/expressions,*_test.cpp)
@$(RM) $(call findfiles,test/prob,*_generated_v_test.cpp)
@$(RM) $(call findfiles,test/prob,*_generated_vv_test.cpp)
@$(RM) $(call findfiles,test/prob,*_generated_fd_test.cpp)
Expand Down
2 changes: 1 addition & 1 deletion runTests.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def checkToolchainPathWindows():
universal_newlines=True,
)
out, err = p1.communicate()
if re.search(" |\(|\)", out):
if re.search(r" |\(|\)", out):
stopErr(
"The RTools toolchain is installed in a path with spaces or bracket. Please reinstall to a valid path.",
-1,
Expand Down
12 changes: 6 additions & 6 deletions stan/math/prim/prob/wiener5_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,12 +679,12 @@ inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
if (!include_summand<propto, T_y, T_a, T_t0, T_w, T_v, T_sv>::value) {
return ret_t(0.0);
}
using T_y_ref = ref_type_if_t<!is_constant<T_y>::value, T_y>;
using T_a_ref = ref_type_if_t<!is_constant<T_a>::value, T_a>;
using T_t0_ref = ref_type_if_t<!is_constant<T_t0>::value, T_t0>;
using T_w_ref = ref_type_if_t<!is_constant<T_w>::value, T_w>;
using T_v_ref = ref_type_if_t<!is_constant<T_v>::value, T_v>;
using T_sv_ref = ref_type_if_t<!is_constant<T_sv>::value, T_sv>;
using T_y_ref = ref_type_t<T_y>;
using T_a_ref = ref_type_t<T_a>;
using T_t0_ref = ref_type_t<T_t0>;
using T_w_ref = ref_type_t<T_w>;
using T_v_ref = ref_type_t<T_v>;
using T_sv_ref = ref_type_t<T_sv>;

static constexpr const char* function_name = "wiener5_lpdf";

Expand Down
19 changes: 11 additions & 8 deletions stan/math/prim/prob/wiener_full_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,14 @@ inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
return ret_t(0);
}

using T_y_ref = ref_type_if_t<!is_constant<T_y>::value, T_y>;
using T_a_ref = ref_type_if_t<!is_constant<T_a>::value, T_a>;
using T_v_ref = ref_type_if_t<!is_constant<T_v>::value, T_v>;
using T_w_ref = ref_type_if_t<!is_constant<T_w>::value, T_w>;
using T_t0_ref = ref_type_if_t<!is_constant<T_t0>::value, T_t0>;
using T_sv_ref = ref_type_if_t<!is_constant<T_sv>::value, T_sv>;
using T_sw_ref = ref_type_if_t<!is_constant<T_sw>::value, T_sw>;
using T_st0_ref = ref_type_if_t<!is_constant<T_st0>::value, T_st0>;
using T_y_ref = ref_type_t<T_y>;
using T_a_ref = ref_type_t<T_a>;
using T_v_ref = ref_type_t<T_v>;
using T_w_ref = ref_type_t<T_w>;
using T_t0_ref = ref_type_t<T_t0>;
using T_sv_ref = ref_type_t<T_sv>;
using T_sw_ref = ref_type_t<T_sw>;
using T_st0_ref = ref_type_t<T_st0>;

using T_partials_return
= partials_return_t<T_y, T_a, T_t0, T_w, T_v, T_sv, T_sw, T_st0>;
Expand Down Expand Up @@ -449,6 +449,9 @@ inline auto wiener_lpdf(const T_y& y, const T_a& a, const T_t0& t0,
// calculate density and partials
for (size_t i = 0; i < N; i++) {
if (sw_vec[i] == 0 && st0_vec[i] == 0) {
// note: because we're delegating to wiener5_lpdf,
// we need to make sure is_constant is consistent between
// our inputs and these
result += wiener_lpdf<propto>(y_vec[i], a_vec[i], t0_vec[i], w_vec[i],
v_vec[i], sv_vec[i], precision_derivatives);
continue;
Expand Down
5 changes: 4 additions & 1 deletion test/generate_expression_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def save_tests_in_files(N_files, tests):
for i in range(N_files):
start = i * len(tests) // N_files
end = (i + 1) * len(tests) // N_files
if start >= end:
# don't try to compile an empty file
continue
with open(src_folder + "tests%d_test.cpp" % i, "w") as out:
out.write("#include <test/expressions/expression_test_helpers.hpp>\n\n")
for test in tests[start:end]:
Expand Down Expand Up @@ -125,5 +128,5 @@ def main(functions=(), j=1):
code = cg.cpp(),
)
)

save_tests_in_files(j, tests)
2 changes: 1 addition & 1 deletion test/sig_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def get_cpp_type(stan_type):
"uniform_lcdf": [None, 0.2, 0.9],
"uniform_lpdf": [None, 0.2, 0.9],
"uniform_rng": [0.2, 1.9, None],
"wiener_lpdf": [0.8, None, 0.4, None, None],
"wiener_lpdf": [0.8, None, 0.4, None, None, None, None, None],
}

# list of functions we do not test. These are mainly functions implemented in compiler
Expand Down

0 comments on commit 2fdd3ed

Please sign in to comment.