Skip to content

Commit

Permalink
Fix a lot of functions that fail the expr test
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Aug 11, 2023
1 parent 361877c commit 492a7ae
Show file tree
Hide file tree
Showing 32 changed files with 147 additions and 86 deletions.
12 changes: 6 additions & 6 deletions stan/math/fwd/fun/mdivide_left_ldlt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ mdivide_left_ldlt(LDLT_factor<T>& A, const EigMat& b) {
check_multiplicable("mdivide_left_ldlt", "A", A.matrix(), "b", b);

const auto& b_ref = to_ref(b);
Eigen::Matrix<EigMatValueScalar, R2, C2> b_val(b.rows(), b.cols());
Eigen::Matrix<EigMatValueScalar, R2, C2> b_der(b.rows(), b.cols());
for (int j = 0; j < b.cols(); j++) {
for (int i = 0; i < b.rows(); i++) {
Eigen::Matrix<EigMatValueScalar, R2, C2> b_val(b_ref.rows(), b_ref.cols());
Eigen::Matrix<EigMatValueScalar, R2, C2> b_der(b_ref.rows(), b_ref.cols());
for (int j = 0; j < b_ref.cols(); j++) {
for (int i = 0; i < b_ref.rows(); i++) {
b_val.coeffRef(i, j) = b_ref.coeff(i, j).val_;
b_der.coeffRef(i, j) = b_ref.coeff(i, j).d_;
}
}

return to_fvar(mdivide_left_ldlt(A, b_val), mdivide_left_ldlt(A, b_der));
auto&& A_ref = to_ref(A);
return to_fvar(mdivide_left_ldlt(A_ref, b_val), mdivide_left_ldlt(A_ref, b_der));
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/diag_post_multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ template <typename T1, typename T2, require_eigen_t<T1>* = nullptr,
auto diag_post_multiply(const T1& m1, const T2& m2) {
check_size_match("diag_post_multiply", "m2.size()", m2.size(), "m1.cols()",
m1.cols());
return m1 * m2.asDiagonal();
return to_ref(m1) * to_ref(m2).asDiagonal();
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/diag_pre_multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ auto diag_pre_multiply(const T1& m1, const T2& m2) {
check_size_match("diag_pre_multiply", "m1.size()", m1.size(), "m2.rows()",
m2.rows());

return m1.asDiagonal() * m2;
return to_ref(m1).asDiagonal() * to_ref(m2);
}

} // namespace math
Expand Down
20 changes: 11 additions & 9 deletions stan/math/prim/fun/trace_gen_inv_quad_form_ldlt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ template <typename EigMat1, typename T2, typename EigMat3,
require_all_not_st_var<EigMat1, T2, EigMat3>* = nullptr>
inline return_type_t<EigMat1, T2, EigMat3> trace_gen_inv_quad_form_ldlt(
const EigMat1& D, LDLT_factor<T2>& A, const EigMat3& B) {
check_square("trace_gen_inv_quad_form_ldlt", "D", D);
auto&& D_ref = to_ref(D);
check_square("trace_gen_inv_quad_form_ldlt", "D", D_ref);
check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D_ref);

if (D.size() == 0 || A.matrix().size() == 0) {
if (D_ref.size() == 0 || A.matrix().size() == 0) {
return 0;
}

return multiply(B, D.transpose()).cwiseProduct(mdivide_left_ldlt(A, B)).sum();
auto&& B_ref = to_ref(B);
return multiply(B_ref, D_ref.transpose()).cwiseProduct(mdivide_left_ldlt(A, B_ref)).sum();
}

/**
Expand All @@ -68,14 +69,15 @@ template <typename EigVec, typename T, typename EigMat,
require_all_not_st_var<EigVec, T, EigMat>* = nullptr>
inline return_type_t<EigVec, T, EigMat> trace_gen_inv_quad_form_ldlt(
const EigVec& D, LDLT_factor<T>& A, const EigMat& B) {
auto&& D_ref = to_ref(D);
check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D_ref);

if (D.size() == 0 || A.matrix().size() == 0) {
if (D_ref.size() == 0 || A.matrix().size() == 0) {
return 0;
}

return (B * D.asDiagonal()).cwiseProduct(mdivide_left_ldlt(A, B)).sum();
auto&& B_ref = to_ref(B);
return (B_ref * D_ref.asDiagonal()).cwiseProduct(mdivide_left_ldlt(A, B_ref)).sum();
}

} // namespace math
Expand Down
4 changes: 2 additions & 2 deletions stan/math/prim/fun/trace_inv_quad_form_ldlt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ inline return_type_t<T, EigMat2> trace_inv_quad_form_ldlt(LDLT_factor<T>& A,
if (A.matrix().size() == 0) {
return 0;
}

return B.cwiseProduct(mdivide_left_ldlt(A, B)).sum();
auto&& B_ref = to_ref(B);
return B_ref.cwiseProduct(mdivide_left_ldlt(A, B_ref)).sum();
}

} // namespace math
Expand Down
26 changes: 15 additions & 11 deletions stan/math/prim/prob/matrix_normal_prec_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,35 @@ namespace math {
template <bool propto, typename T_y, typename T_Mu, typename T_Sigma,
typename T_D,
require_all_matrix_t<T_y, T_Mu, T_Sigma, T_D>* = nullptr>
return_type_t<T_y, T_Mu, T_Sigma, T_D> matrix_normal_prec_lpdf(
inline return_type_t<T_y, T_Mu, T_Sigma, T_D> matrix_normal_prec_lpdf(
const T_y& y, const T_Mu& Mu, const T_Sigma& Sigma, const T_D& D) {
static const char* function = "matrix_normal_prec_lpdf";
auto&& y_ref = to_ref(y);
auto&& Mu_ref = to_ref(Mu);
auto&& Sigma_ref = to_ref(Sigma);
auto&& D_ref = to_ref(D);
check_positive(function, "Sigma rows", Sigma.rows());
check_finite(function, "Sigma", Sigma);
check_symmetric(function, "Sigma", Sigma);
check_finite(function, "Sigma", Sigma_ref);
check_symmetric(function, "Sigma", Sigma_ref);

auto ldlt_Sigma = make_ldlt_factor(Sigma);
auto ldlt_Sigma = make_ldlt_factor(Sigma_ref);
check_ldlt_factor(function, "LDLT_Factor of Sigma", ldlt_Sigma);
check_positive(function, "D rows", D.rows());
check_finite(function, "D", D);
check_symmetric(function, "D", D);
check_finite(function, "D", D_ref);
check_symmetric(function, "D", D_ref);

auto ldlt_D = make_ldlt_factor(D);
auto ldlt_D = make_ldlt_factor(D_ref);
check_ldlt_factor(function, "LDLT_Factor of D", ldlt_D);
check_size_match(function, "Rows of random variable", y.rows(),
"Rows of location parameter", Mu.rows());
check_size_match(function, "Columns of random variable", y.cols(),
"Columns of location parameter", Mu.cols());
"Columns of location parameter", Mu_ref.cols());
check_size_match(function, "Rows of random variable", y.rows(),
"Rows of Sigma", Sigma.rows());
check_size_match(function, "Columns of random variable", y.cols(),
"Rows of D", D.rows());
check_finite(function, "Location parameter", Mu);
check_finite(function, "Random variable", y);
check_finite(function, "Location parameter", Mu_ref);
check_finite(function, "Random variable", y_ref);

return_type_t<T_y, T_Mu, T_Sigma, T_D> lp(0.0);

Expand All @@ -74,7 +78,7 @@ return_type_t<T_y, T_Mu, T_Sigma, T_D> matrix_normal_prec_lpdf(
}

if (include_summand<propto, T_y, T_Mu, T_Sigma, T_D>::value) {
lp -= 0.5 * trace_gen_quad_form(D, Sigma, subtract(y, Mu));
lp -= 0.5 * trace_gen_quad_form(D_ref, Sigma_ref, subtract(y_ref, Mu_ref));
}
return lp;
}
Expand Down
13 changes: 7 additions & 6 deletions stan/math/prim/prob/multi_normal_prec_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_prec_lpdf(
return 0;
}
check_consistent_sizes_mvt(function, "y", y, "mu", mu);

auto&& y_ref = to_ref(y);
auto&& mu_ref = to_ref(mu);
lp_type lp(0);
vector_seq_view<T_y> y_vec(y);
vector_seq_view<T_loc> mu_vec(mu);
size_t size_vec = max_size_mvt(y, mu);
vector_seq_view<T_y> y_vec(y_ref);
vector_seq_view<T_loc> mu_vec(mu_ref);
size_t size_vec = max_size_mvt(y_ref, mu_ref);

int size_y = y_vec[0].size();
int size_mu = mu_vec[0].size();
if (size_vec > 1) {
for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) {
for (size_t i = 1, size_mvt_y = size_mvt(y_ref); i < size_mvt_y; i++) {
check_size_match(function,
"Size of one of the vectors "
"of the random variable",
Expand All @@ -50,7 +51,7 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_prec_lpdf(
"the random variable",
size_y);
}
for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) {
for (size_t i = 1, size_mvt_mu = size_mvt(mu_ref); i < size_mvt_mu; i++) {
check_size_match(function,
"Size of one of the vectors "
"of the location variable",
Expand Down
1 change: 1 addition & 0 deletions test/unit/math/mix/fun/atanh_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ TEST(mathMixMatFun, atanh) {
stan::test::expect_ad(f, std::complex<double>{re, im});
}
}

}

TEST(mathMixMatFun, atanh_varmat) {
Expand Down
6 changes: 4 additions & 2 deletions test/unit/math/mix/fun/eigenvalues_sym_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
TEST(MathMixMatFun, eigenvaluesSym) {
auto f = [](const auto& y) {
// need to maintain symmetry for finite diffs
auto a = ((y + y.transpose()) * 0.5).eval();
auto&& y_ref = stan::math::to_ref(y);
auto a = ((y_ref + y_ref.transpose()) * 0.5).eval();
return stan::math::eigenvalues_sym(a);
};

Expand Down Expand Up @@ -38,7 +39,8 @@ TEST(MathMixMatFun, eigenvaluesSym) {
TEST(MathMixMatFun, eigenvaluesSym_varmat) {
auto f = [](const auto& y) {
// need to maintain symmetry for finite diffs
auto a = stan::math::multiply((y + y.transpose()), 0.5).eval();
auto&& y_ref = stan::math::to_ref(y);
auto a = stan::math::multiply((y_ref + y_ref.transpose()), 0.5).eval();
return stan::math::eigenvalues_sym(a);
};

Expand Down
6 changes: 4 additions & 2 deletions test/unit/math/mix/fun/eigenvectors_sym_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ TEST(MathMixMatFun, eigenvectorsSym) {
if (y.rows() != y.cols()) {
return stan::math::eigenvectors_sym(y);
}
auto a = ((y + y.transpose()) * 0.5).eval();
auto&& y_ref = stan::math::to_ref(y);
auto a = ((y_ref + y_ref.transpose()) * 0.5).eval();
return stan::math::eigenvectors_sym(a);
};

Expand Down Expand Up @@ -37,7 +38,8 @@ TEST(MathMixMatFun, eigenvectorsSym_varmat) {
if (y.rows() != y.cols()) {
return stan::math::eigenvectors_sym(y);
}
auto a = stan::math::multiply((y + y.transpose()), 0.5).eval();
auto&& y_ref = stan::math::to_ref(y);
auto a = stan::math::multiply((y_ref + y_ref.transpose()), 0.5).eval();
return stan::math::eigenvectors_sym(a);
};

Expand Down
5 changes: 4 additions & 1 deletion test/unit/math/mix/fun/elt_divide_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

TEST(MathMixMatFun, elt_divide_transpose_test) {
auto f
= [](const auto& x) { return stan::math::elt_divide(x, x.transpose()); };
= [](const auto& x) {
auto&& x_ref = stan::math::to_ref(x);
return stan::math::elt_divide(x_ref, x_ref.transpose());
};

Eigen::MatrixXd x(2, 2);

Expand Down
3 changes: 2 additions & 1 deletion test/unit/math/mix/fun/elt_multiply_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

TEST(MathMixMatFun, elt_multiply_transpose_test) {
auto f = [](const auto& x) {
return stan::math::elt_multiply(x, x.transpose());
auto x_ref = stan::math::to_ref(x);
return stan::math::elt_multiply(x_ref, x_ref.transpose());
};

Eigen::MatrixXd x = Eigen::MatrixXd::Random(2, 2);
Expand Down
4 changes: 3 additions & 1 deletion test/unit/math/mix/fun/inverse_spd_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ TEST(MathMixMatFun, inverseSpd) {
auto f = [](const auto& x) {
if (x.rows() != x.cols())
return stan::math::inverse_spd(x);
auto y = ((x + x.transpose()) * 0.5).eval(); // symmetry for finite diffs

auto x_ref = stan::math::to_ref(x);
auto y = ((x_ref + x_ref.transpose()) * 0.5).eval(); // symmetry for finite diffs
return stan::math::inverse_spd(y);
};

Expand Down
3 changes: 0 additions & 3 deletions test/unit/math/mix/fun/lb_constrain_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,9 @@ TEST(mathMixMatFun, lb_stdvec_mat_mat_constrain) {
std::vector<Eigen::MatrixXd> A;
A.push_back(A_inner);
A.push_back(A_inner2);
std::cout << "111111 " << std::endl;
lb_constrain_test::expect_vec(A, lbm_inner);
std::cout << "2222222 " << std::endl;
lb_constrain_test::expect_vec(A, lbm_inner_bad);
double lbd = 6.0;
std::cout << "333333 " << std::endl;
lb_constrain_test::expect_vec(A, lbd);
}

Expand Down
3 changes: 2 additions & 1 deletion test/unit/math/mix/fun/log_determinant_ldlt_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

TEST(MathMixMatFun, logDeterminantLdlt) {
auto f = [](const auto& x) {
auto x_sym = stan::math::multiply(0.5, x + x.transpose());
auto&& x_ref = stan::math::to_ref(x);
auto x_sym = stan::math::multiply(0.5, x_ref + x_ref.transpose());
auto y = stan::math::make_ldlt_factor(x_sym);
return stan::math::log_determinant_ldlt(y);
};
Expand Down
3 changes: 2 additions & 1 deletion test/unit/math/mix/fun/log_determinant_spd_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

TEST(MathMixMatFun, logDeterminantSpd) {
auto f = [](const auto& x) {
auto z = stan::math::multiply(x + x.transpose(), 0.5);
auto&& x_ref = stan::math::to_ref(x);
auto z = stan::math::multiply(x_ref + x_ref.transpose(), 0.5);
return stan::math::log_determinant_spd(z);
};

Expand Down
3 changes: 2 additions & 1 deletion test/unit/math/mix/fun/mdivide_left_ldlt_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

TEST(MathMixMatFun, mdivideLeftLdlt) {
auto f = [](const auto& x, const auto& y) {
auto x_sym = stan::math::multiply(0.5, x + x.transpose());
auto&& x_ref = stan::math::to_ref(x);
auto x_sym = stan::math::multiply(0.5, x_ref + x_ref.transpose());
auto ldlt = stan::math::make_ldlt_factor(x_sym);
return stan::math::mdivide_left_ldlt(ldlt, y);
};
Expand Down
4 changes: 3 additions & 1 deletion test/unit/math/mix/fun/mdivide_left_spd_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ TEST(MathMixMatFun, mdivideLeftSpd) {
auto f = [](const auto& x, const auto& y) {
if (x.rows() != x.cols())
return stan::math::mdivide_left_spd(x, y);

auto x_ref = stan::math::to_ref(x);
auto x_sym = stan::math::eval(
stan::math::multiply(x + x.transpose(), 0.5)); // sym for finite diffs
stan::math::multiply(x_ref + x_ref.transpose(), 0.5)); // sym for finite diffs
return stan::math::mdivide_left_spd(x_sym, y);
};

Expand Down
3 changes: 2 additions & 1 deletion test/unit/math/mix/fun/mdivide_right_ldlt_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
TEST(MathMixMatFun, mdivideRightLdlt) {
using stan::test::relative_tolerance;
auto f = [](const auto& x, const auto& y) {
auto y_sym = stan::math::multiply(0.5, y + y.transpose()).eval();
auto&& y_ref = stan::math::to_ref(y);
auto y_sym = stan::math::multiply(0.5, y_ref + y_ref.transpose()).eval();
auto ldlt = stan::math::make_ldlt_factor(y_sym);
return stan::math::mdivide_right_ldlt(x, ldlt);
};
Expand Down
4 changes: 3 additions & 1 deletion test/unit/math/mix/fun/mdivide_right_spd_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ TEST(MathMixMatFun, mdivideRightSpd) {
auto f = [](const auto& x, const auto& y) {
if (y.rows() != y.cols())
return stan::math::mdivide_right_spd(x, y);
auto y_sym = ((y + y.transpose()) * 0.5).eval(); // sym for finite diffs

auto&& y_ref = stan::math::to_ref(y);
auto y_sym = ((y_ref + y_ref.transpose()) * 0.5).eval(); // sym for finite diffs
return stan::math::mdivide_right_spd(x, y_sym);
};

Expand Down
7 changes: 5 additions & 2 deletions test/unit/math/mix/fun/quad_form_sym_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

TEST(MathMixMatFun, quadFormSym) {
auto f = [](const auto& x, const auto& y) {
auto&& x_ref = stan::math::to_ref(x);
// symmetrize the input matrix
auto x_sym = ((x + x.transpose()) * 0.5).eval();
auto x_sym = ((x_ref + x_ref.transpose()) * 0.5).eval();
return stan::math::quad_form_sym(x_sym, y);
};

Expand Down Expand Up @@ -96,7 +97,9 @@ TEST(MathMixMatFun, quad_form_sym_2095) {
auto f = [](const auto& x, const auto& y) {
// symmetrize the input matrix;
// expect_ad will perturb elements and cause it not to be symmetric
auto x_sym = ((x + x.transpose()) * 0.5).eval();
auto&& x_ref = stan::math::to_ref(x);
// symmetrize the input matrix
auto x_sym = ((x_ref + x_ref.transpose()) * 0.5).eval();
return stan::math::quad_form_sym(x_sym, y);
};
stan::test::expect_ad(f, ad, bd);
Expand Down
10 changes: 4 additions & 6 deletions test/unit/math/mix/fun/trace_gen_inv_quad_form_ldlt_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

TEST(mathMixMatFun, traceGenInvQuadForm) {
auto f = [](const auto& c, const auto& a, const auto& b) {
auto x_sym = stan::math::multiply(0.5, a + a.transpose());
auto&& a_ref = stan::math::to_ref(a);
auto x_sym = stan::math::multiply(0.5, a_ref + a_ref.transpose());
auto ldlt_a = stan::math::make_ldlt_factor(x_sym);
return stan::math::trace_gen_inv_quad_form_ldlt(c, ldlt_a, b);
};
Expand Down Expand Up @@ -76,15 +77,12 @@ TEST(mathMixMatFun, traceGenInvQuadForm) {

TEST(mathMixMatFun, traceGenInvQuadForm_vec) {
auto f = [](const auto& c, const auto& a, const auto& b) {
auto x_sym = stan::math::multiply(0.5, a + a.transpose());
auto&& a_ref = stan::math::to_ref(a);
auto x_sym = stan::math::multiply(0.5, a_ref + a_ref.transpose()).eval();
auto ldlt_a = stan::math::make_ldlt_factor(x_sym);
return stan::math::trace_gen_inv_quad_form_ldlt(c, ldlt_a, b);
};

auto f1 = [&](const auto& c) {
return [&](const auto& a, const auto& b) { return f(c, a, b); };
};

Eigen::MatrixXd a00(0, 0);
Eigen::MatrixXd b00(0, 0);
Eigen::VectorXd c0(0);
Expand Down
3 changes: 2 additions & 1 deletion test/unit/math/mix/fun/trace_inv_quad_form_ldlt_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

TEST(MathMixMatFun, traceInvQuadFormLdlt) {
auto f = [](const auto& x, const auto& y) {
auto x_sym = stan::math::multiply(0.5, x + x.transpose());
auto&& x_ref = stan::math::to_ref(x);
auto x_sym = stan::math::multiply(0.5, x_ref + x_ref.transpose());
auto ldlt = stan::math::make_ldlt_factor(x_sym);
return stan::math::trace_inv_quad_form_ldlt(ldlt, y);
};
Expand Down
Loading

0 comments on commit 492a7ae

Please sign in to comment.