Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ad embedded expression tests to expect_ad testing framework #2837

Open
wants to merge 78 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
78 commits
Select commit Hold shift + click to select a range
778bf84
adds expression tests to all expect_ad calls
SteveBronder Oct 25, 2022
f22c705
add tests and docs for embedded expression tests
SteveBronder Oct 25, 2022
832c17e
Merge commit '224e5812c0c4885f5420caa72275f64f0b2568ed' into HEAD
yashikno Oct 25, 2022
868aedf
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 25, 2022
6edc218
update embedded expr tests to support complex types
SteveBronder Oct 25, 2022
9320cdb
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 25, 2022
20faba8
add way to filter out bad values for the common values in the ad tests
SteveBronder Oct 27, 2022
a88372a
fix double eval bug related to lub_constraints
SteveBronder Oct 27, 2022
981737c
fix fma to eval in its holder
SteveBronder Nov 1, 2022
ccaa037
Merge commit '4d48730cde2f8e7cd420c8227b49433153196d46' into HEAD
yashikno Nov 1, 2022
4c343d8
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Nov 1, 2022
7860110
fix pretty printing in expr tests for complex types
SteveBronder Nov 1, 2022
1a3b224
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Nov 1, 2022
9002b49
fix static init for scalar_type helper
SteveBronder Nov 3, 2022
fafc419
Merge commit 'f4c68170c1f1e9bc1284520560204ea9e76c62cd' into HEAD
yashikno Nov 3, 2022
cf01e32
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Nov 3, 2022
ef23692
fix char_scalar_type static init
SteveBronder Nov 3, 2022
30602ca
merge
SteveBronder Nov 3, 2022
f0cc4f7
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Nov 3, 2022
ac26fe7
catch throws from expect_ad
SteveBronder Nov 3, 2022
d142171
Merge branch 'feature/inbedded-expression-tests' of github.com:stan-d…
SteveBronder Nov 3, 2022
f5a6f5a
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Nov 3, 2022
70df44c
Merge remote-tracking branch 'origin/develop' into feature/inbedded-e…
SteveBronder Jul 20, 2023
e98b2b9
Pass through all exceptions during the expression tests as that is ha…
SteveBronder Jul 24, 2023
5097faa
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 24, 2023
e21fe54
move around expr_test so that they run inside of the try where the no…
SteveBronder Jul 24, 2023
85359bc
move around expr_test so that they run inside of the try where the no…
SteveBronder Jul 24, 2023
df09a70
breakup eigenvalues test for mingw max size
SteveBronder Jul 24, 2023
c268474
fix fma test expression usage
SteveBronder Jul 24, 2023
7f7a874
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 24, 2023
fd2910c
fix functions not catching expressions correctly
SteveBronder Jul 26, 2023
e6112b0
Merge remote-tracking branch 'origin' into feature/inbedded-expressio…
SteveBronder Jul 26, 2023
eaf4cd2
merge
SteveBronder Jul 26, 2023
269e52b
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 26, 2023
f1fb408
newline
SteveBronder Jul 26, 2023
5e3c227
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 26, 2023
361877c
Merge remote-tracking branch 'origin/develop' into feature/inbedded-e…
SteveBronder Aug 10, 2023
492a7ae
Fix a lot of functions that fail the expr test
SteveBronder Aug 11, 2023
ee8393f
Merge commit '38289cd2c731c0458a5437ab687638de8d35fe50' into HEAD
yashikno Aug 11, 2023
d27e5a2
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 11, 2023
d16255e
update hmm
SteveBronder Aug 14, 2023
b6ff641
Merge commit 'd4eab2773347ca6fbe03d49f70828c08ff248269' into HEAD
yashikno Aug 14, 2023
86b0594
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 14, 2023
1e08001
bug fix in multi_student_t mix test
SteveBronder Aug 14, 2023
3588b46
Fix matrix_exp etc functions for expression passing
SteveBronder Aug 17, 2023
db69ecf
Merge branch 'feature/reverse-mode-move-semantics' into feature/inbed…
SteveBronder Aug 17, 2023
bdb5f0f
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 17, 2023
8443f8e
remove extra log1p include
SteveBronder Aug 17, 2023
33ae9e1
Merge remote-tracking branch 'origin/develop' into feature/inbedded-e…
SteveBronder Aug 18, 2023
cd620dc
Call recover_memory() after calling the var expression tests
SteveBronder Aug 18, 2023
8aeccb2
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 18, 2023
45e42b3
Merge remote-tracking branch 'origin/develop' into feature/inbedded-e…
SteveBronder Sep 6, 2023
9230fbf
update decompose_test
SteveBronder Sep 6, 2023
59792ee
adds test framework to all mix functions to cleanup memory after they…
SteveBronder Sep 14, 2023
a2e32ab
Merge commit '4cf25de56d29ef39c93eb2595d13dcfd65f97818' into HEAD
yashikno Sep 14, 2023
fa269bb
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Sep 14, 2023
ed5de89
Fix tests
SteveBronder Sep 15, 2023
545cdee
Merge branch 'develop' into feature/inbedded-expression-tests
andrjohns Sep 19, 2023
a7bf561
Fix test name collisions
andrjohns Sep 19, 2023
29c1239
Merge branch 'develop' into feature/inbedded-expression-tests
andrjohns Mar 24, 2024
8d2d1f2
Fix rev fill_test
andrjohns Mar 24, 2024
457b066
Fixup more test failures
andrjohns Mar 25, 2024
aa3ef18
Fix ASAN errors
andrjohns Mar 25, 2024
8056d25
Fix threading tests
andrjohns Mar 26, 2024
d00240e
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 26, 2024
5a0e897
Merge branch 'develop' into feature/inbedded-expression-tests
andrjohns Mar 27, 2024
b8e496d
Fix hmm_marginal
andrjohns Mar 27, 2024
a43932b
cpplint
andrjohns Mar 27, 2024
8a9eb58
update
SteveBronder May 14, 2024
df47fbc
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot May 14, 2024
6daa488
fix arena matrix includes
SteveBronder May 14, 2024
462475c
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot May 14, 2024
be3734e
update needed requires
SteveBronder May 14, 2024
1721a83
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot May 14, 2024
9b9f77d
fix hyper pfq
SteveBronder May 17, 2024
bc9f900
Merge remote-tracking branch 'origin/feature/inbedded-expression-test…
SteveBronder May 17, 2024
cf2d194
update to develop
SteveBronder Oct 2, 2024
ea2735b
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
13 changes: 7 additions & 6 deletions stan/math/fwd/fun/mdivide_left_ldlt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ mdivide_left_ldlt(LDLT_factor<T>& A, const EigMat& b) {
check_multiplicable("mdivide_left_ldlt", "A", A.matrix(), "b", b);

const auto& b_ref = to_ref(b);
Eigen::Matrix<EigMatValueScalar, R2, C2> b_val(b.rows(), b.cols());
Eigen::Matrix<EigMatValueScalar, R2, C2> b_der(b.rows(), b.cols());
for (int j = 0; j < b.cols(); j++) {
for (int i = 0; i < b.rows(); i++) {
Eigen::Matrix<EigMatValueScalar, R2, C2> b_val(b_ref.rows(), b_ref.cols());
Eigen::Matrix<EigMatValueScalar, R2, C2> b_der(b_ref.rows(), b_ref.cols());
for (int j = 0; j < b_ref.cols(); j++) {
for (int i = 0; i < b_ref.rows(); i++) {
b_val.coeffRef(i, j) = b_ref.coeff(i, j).val_;
b_der.coeffRef(i, j) = b_ref.coeff(i, j).d_;
}
}

return to_fvar(mdivide_left_ldlt(A, b_val), mdivide_left_ldlt(A, b_der));
auto&& A_ref = to_ref(A);
return to_fvar(mdivide_left_ldlt(A_ref, b_val),
mdivide_left_ldlt(A_ref, b_der));
}

} // namespace math
Expand Down
8 changes: 4 additions & 4 deletions stan/math/prim/constraint/cov_matrix_constrain_lkj.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ namespace math {
*/
template <typename T, require_eigen_vector_t<T>* = nullptr>
inline Eigen::Matrix<value_type_t<T>, Eigen::Dynamic, Eigen::Dynamic>
cov_matrix_constrain_lkj(const T& x, size_t k) {
cov_matrix_constrain_lkj(T&& x, size_t k) {
size_t k_choose_2 = (k * (k - 1)) / 2;
const auto& x_ref = to_ref(x);
auto&& x_ref = to_ref(std::forward<T>(x));
return read_cov_matrix(corr_constrain(x_ref.head(k_choose_2)),
positive_constrain(x_ref.tail(k)));
}
Expand All @@ -56,9 +56,9 @@ cov_matrix_constrain_lkj(const T& x, size_t k) {
*/
template <typename T, require_eigen_vector_t<T>* = nullptr>
inline Eigen::Matrix<value_type_t<T>, Eigen::Dynamic, Eigen::Dynamic>
cov_matrix_constrain_lkj(const T& x, size_t k, return_type_t<T>& lp) {
cov_matrix_constrain_lkj(T&& x, size_t k, return_type_t<T>& lp) {
size_t k_choose_2 = (k * (k - 1)) / 2;
const auto& x_ref = x;
auto&& x_ref = to_ref(std::forward<T>(x));
return read_cov_matrix(corr_constrain(x_ref.head(k_choose_2)),
positive_constrain(x_ref.tail(k)), lp);
}
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/constraint/lb_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ inline auto lb_constrain(const T& x, const L& lb, return_type_t<T, L>& lp) {
template <typename T, typename L, require_not_std_vector_t<L>* = nullptr>
inline auto lb_constrain(const std::vector<T>& x, const L& lb) {
std::vector<plain_type_t<decltype(lb_constrain(x[0], lb))>> ret(x.size());
auto&& lb_ref = to_ref(lb);
for (size_t i = 0; i < x.size(); ++i) {
ret[i] = lb_constrain(x[i], lb);
ret[i] = lb_constrain(x[i], lb_ref);
}
return ret;
}
Expand All @@ -173,8 +174,9 @@ template <typename T, typename L, require_not_std_vector_t<L>* = nullptr>
inline auto lb_constrain(const std::vector<T>& x, const L& lb,
return_type_t<T, L>& lp) {
std::vector<plain_type_t<decltype(lb_constrain(x[0], lb))>> ret(x.size());
auto&& lb_ref = to_ref(lb);
for (size_t i = 0; i < x.size(); ++i) {
ret[i] = lb_constrain(x[i], lb, lp);
ret[i] = lb_constrain(x[i], lb_ref, lp);
}
return ret;
}
Expand Down
40 changes: 25 additions & 15 deletions stan/math/prim/constraint/lub_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ template <typename T, typename L, typename U, require_eigen_t<T>* = nullptr,
require_all_stan_scalar_t<L, U>* = nullptr,
require_not_var_t<return_type_t<T, L, U>>* = nullptr>
inline auto lub_constrain(const T& x, const L& lb, const U& ub) {
return eval(
x.unaryExpr([ub, lb](auto&& xx) { return lub_constrain(xx, lb, ub); }));
return eval(to_ref(x).unaryExpr(
[ub, lb](auto&& xx) { return lub_constrain(xx, lb, ub); }));
}

/**
Expand All @@ -131,7 +131,7 @@ template <typename T, typename L, typename U, require_eigen_t<T>* = nullptr,
require_not_var_t<return_type_t<T, L, U>>* = nullptr>
inline auto lub_constrain(const T& x, const L& lb, const U& ub,
return_type_t<T, L, U>& lp) {
return eval(x.unaryExpr(
return eval(to_ref(x).unaryExpr(
[lb, ub, &lp](auto&& xx) { return lub_constrain(xx, lb, ub, lp); }));
}

Expand All @@ -145,8 +145,9 @@ template <typename T, typename L, typename U,
require_not_var_t<return_type_t<T, L, U>>* = nullptr>
inline auto lub_constrain(const T& x, const L& lb, const U& ub) {
check_matching_dims("lub_constrain", "x", x, "lb", lb);
return eval(x.binaryExpr(
lb, [ub](auto&& x, auto&& lb) { return lub_constrain(x, lb, ub); }));
return eval(to_ref(x).binaryExpr(to_ref(lb), [ub](auto&& x, auto&& lb) {
return lub_constrain(x, lb, ub);
}));
}

/**
Expand All @@ -160,7 +161,7 @@ template <typename T, typename L, typename U,
inline auto lub_constrain(const T& x, const L& lb, const U& ub,
return_type_t<T, L, U>& lp) {
check_matching_dims("lub_constrain", "x", x, "lb", lb);
return eval(x.binaryExpr(lb, [ub, &lp](auto&& x, auto&& lb) {
return eval(to_ref(x).binaryExpr(to_ref(lb), [ub, &lp](auto&& x, auto&& lb) {
return lub_constrain(x, lb, ub, lp);
}));
}
Expand All @@ -175,8 +176,9 @@ template <typename T, typename L, typename U,
require_not_var_t<return_type_t<T, L, U>>* = nullptr>
inline auto lub_constrain(const T& x, const L& lb, const U& ub) {
check_matching_dims("lub_constrain", "x", x, "ub", ub);
return eval(x.binaryExpr(
ub, [lb](auto&& x, auto&& ub) { return lub_constrain(x, lb, ub); }));
return eval(to_ref(x).binaryExpr(to_ref(ub), [lb](auto&& x, auto&& ub) {
return lub_constrain(x, lb, ub);
}));
}

/**
Expand All @@ -190,7 +192,7 @@ template <typename T, typename L, typename U,
inline auto lub_constrain(const T& x, const L& lb, const U& ub,
return_type_t<T, L, U>& lp) {
check_matching_dims("lub_constrain", "x", x, "ub", ub);
return eval(x.binaryExpr(ub, [lb, &lp](auto&& x, auto&& ub) {
return eval(to_ref(x).binaryExpr(to_ref(ub), [lb, &lp](auto&& x, auto&& ub) {
return lub_constrain(x, lb, ub, lp);
}));
}
Expand Down Expand Up @@ -248,8 +250,10 @@ template <typename T, typename L, typename U,
inline auto lub_constrain(const std::vector<T>& x, const L& lb, const U& ub) {
std::vector<plain_type_t<decltype(lub_constrain(x[0], lb, ub))>> ret(
x.size());
auto&& lb_ref = to_ref(lb);
auto&& ub_ref = to_ref(ub);
for (size_t i = 0; i < x.size(); ++i) {
ret[i] = lub_constrain(x[i], lb, ub);
ret[i] = lub_constrain(x[i], lb_ref, ub_ref);
}
return ret;
}
Expand All @@ -263,8 +267,10 @@ inline auto lub_constrain(const std::vector<T>& x, const L& lb, const U& ub,
return_type_t<T, L, U>& lp) {
std::vector<plain_type_t<decltype(lub_constrain(x[0], lb, ub))>> ret(
x.size());
auto&& lb_ref = to_ref(lb);
auto&& ub_ref = to_ref(ub);
for (size_t i = 0; i < x.size(); ++i) {
ret[i] = lub_constrain(x[i], lb, ub, lp);
ret[i] = lub_constrain(x[i], lb_ref, ub_ref, lp);
}
return ret;
}
Expand All @@ -279,8 +285,9 @@ inline auto lub_constrain(const std::vector<T>& x, const L& lb,
check_matching_dims("lub_constrain", "x", x, "ub", ub);
std::vector<plain_type_t<decltype(lub_constrain(x[0], lb, ub[0]))>> ret(
x.size());
auto&& lb_ref = to_ref(lb);
for (size_t i = 0; i < x.size(); ++i) {
ret[i] = lub_constrain(x[i], lb, ub[i]);
ret[i] = lub_constrain(x[i], lb_ref, ub[i]);
}
return ret;
}
Expand All @@ -296,8 +303,9 @@ inline auto lub_constrain(const std::vector<T>& x, const L& lb,
check_matching_dims("lub_constrain", "x", x, "ub", ub);
std::vector<plain_type_t<decltype(lub_constrain(x[0], lb, ub[0]))>> ret(
x.size());
auto&& lb_ref = to_ref(lb);
for (size_t i = 0; i < x.size(); ++i) {
ret[i] = lub_constrain(x[i], lb, ub[i], lp);
ret[i] = lub_constrain(x[i], lb_ref, ub[i], lp);
}
return ret;
}
Expand All @@ -312,8 +320,9 @@ inline auto lub_constrain(const std::vector<T>& x, const std::vector<L>& lb,
check_matching_dims("lub_constrain", "x", x, "lb", lb);
std::vector<plain_type_t<decltype(lub_constrain(x[0], lb[0], ub))>> ret(
x.size());
auto&& ub_ref = to_ref(ub);
for (size_t i = 0; i < x.size(); ++i) {
ret[i] = lub_constrain(x[i], lb[i], ub);
ret[i] = lub_constrain(x[i], lb[i], ub_ref);
}
return ret;
}
Expand All @@ -328,8 +337,9 @@ inline auto lub_constrain(const std::vector<T>& x, const std::vector<L>& lb,
check_matching_dims("lub_constrain", "x", x, "lb", lb);
std::vector<plain_type_t<decltype(lub_constrain(x[0], lb[0], ub))>> ret(
x.size());
auto&& ub_ref = to_ref(ub);
for (size_t i = 0; i < x.size(); ++i) {
ret[i] = lub_constrain(x[i], lb[i], ub, lp);
ret[i] = lub_constrain(x[i], lb[i], ub_ref, lp);
}
return ret;
}
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/constraint/ub_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,9 @@ inline auto ub_constrain(const T& x, const U& ub,
template <typename T, typename U, require_not_std_vector_t<U>* = nullptr>
inline auto ub_constrain(const std::vector<T>& x, const U& ub) {
std::vector<plain_type_t<decltype(ub_constrain(x[0], ub))>> ret(x.size());
auto&& ub_ref = to_ref(ub);
for (size_t i = 0; i < x.size(); ++i) {
ret[i] = ub_constrain(x[i], ub);
ret[i] = ub_constrain(x[i], ub_ref);
}
return ret;
}
Expand All @@ -183,8 +184,9 @@ template <typename T, typename U, require_not_std_vector_t<U>* = nullptr>
inline auto ub_constrain(const std::vector<T>& x, const U& ub,
return_type_t<T, U>& lp) {
std::vector<plain_type_t<decltype(ub_constrain(x[0], ub))>> ret(x.size());
auto&& ub_ref = to_ref(ub);
for (size_t i = 0; i < x.size(); ++i) {
ret[i] = ub_constrain(x[i], ub, lp);
ret[i] = ub_constrain(x[i], ub_ref, lp);
}
return ret;
}
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/diag_post_multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ template <typename T1, typename T2, require_eigen_t<T1>* = nullptr,
auto diag_post_multiply(const T1& m1, const T2& m2) {
check_size_match("diag_post_multiply", "m2.size()", m2.size(), "m1.cols()",
m1.cols());
return m1 * m2.asDiagonal();
return to_ref(to_ref(m1) * to_ref(m2).asDiagonal());
}

} // namespace math
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/diag_pre_multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ auto diag_pre_multiply(const T1& m1, const T2& m2) {
check_size_match("diag_pre_multiply", "m1.size()", m1.size(), "m2.rows()",
m2.rows());

return m1.asDiagonal() * m2;
return to_ref(to_ref(m1).asDiagonal() * to_ref(m2));
}

} // namespace math
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/fun/fma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ inline auto fma(T1&& x, T2&& y, T3&& z) {
[](auto&& x, auto&& y, auto&& z) {
return ((as_array_or_scalar(x) * as_array_or_scalar(y))
+ as_array_or_scalar(z))
.matrix();
.matrix()
.eval();
},
std::forward<T1>(x), std::forward<T2>(y), std::forward<T3>(z));
}
Expand Down
4 changes: 2 additions & 2 deletions stan/math/prim/fun/matrix_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ namespace math {
template <typename T, typename = require_eigen_t<T>>
inline plain_type_t<T> matrix_exp(const T& A_in) {
using std::exp;
const auto& A = A_in.eval();
const auto& A = to_ref(A_in);
check_square("matrix_exp", "input matrix", A);
if (T::RowsAtCompileTime == 1 && T::ColsAtCompileTime == 1) {
plain_type_t<T> res;
res << exp(A(0));
return res;
}
if (A_in.size() == 0) {
if (A.size() == 0) {
return {};
}
return (A.cols() == 2
Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/fun/matrix_exp_2x2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ matrix_exp_2x2(const EigMat& A) {
using std::exp;
using std::sinh;
using std::sqrt;

auto&& A_ref = to_ref(A);
using T = value_type_t<EigMat>;
T a = A(0, 0), b = A(0, 1), c = A(1, 0), d = A(1, 1), delta;
T a = A_ref(0, 0), b = A_ref(0, 1), c = A_ref(1, 0), d = A_ref(1, 1), delta;
delta = sqrt(square(a - d) + 4 * b * c);

Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> B(2, 2);
Expand All @@ -49,7 +49,7 @@ matrix_exp_2x2(const EigMat& A) {

// use pade approximation if cosh & sinh ops overflow to NaN
if (B.hasNaN()) {
return matrix_exp_pade(A);
return matrix_exp_pade(A_ref);
} else {
return B / delta;
}
Expand Down
12 changes: 7 additions & 5 deletions stan/math/prim/fun/matrix_exp_multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ template <typename EigMat1, typename EigMat2,
require_all_st_same<double, EigMat1, EigMat2>* = nullptr>
inline Eigen::Matrix<double, Eigen::Dynamic, EigMat2::ColsAtCompileTime>
matrix_exp_multiply(const EigMat1& A, const EigMat2& B) {
check_square("matrix_exp_multiply", "input matrix", A);
check_multiplicable("matrix_exp_multiply", "A", A, "B", B);
if (A.size() == 0) {
return {0, B.cols()};
auto&& A_ref = to_ref(A);
auto&& B_ref = to_ref(B);
check_square("matrix_exp_multiply", "input matrix", A_ref);
check_multiplicable("matrix_exp_multiply", "A", A_ref, "B", B_ref);
if (A_ref.size() == 0) {
return {0, B_ref.cols()};
}

return matrix_exp_action_handler().action(A, B);
return matrix_exp_action_handler().action(A_ref, B_ref);
}

} // namespace math
Expand Down
8 changes: 5 additions & 3 deletions stan/math/prim/fun/matrix_exp_pade.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,20 @@ template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
EigMat::ColsAtCompileTime>
matrix_exp_pade(const EigMat& arg) {
auto&& arg_ref = to_ref(arg);
using MatrixType
= Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
EigMat::ColsAtCompileTime>;
check_square("matrix_exp_pade", "arg", arg);
if (arg.size() == 0) {
check_square("matrix_exp_pade", "arg", arg_ref);
if (arg_ref.size() == 0) {
return {};
}

MatrixType U, V;
int squarings;

Eigen::matrix_exp_computeUV<MatrixType>::run(arg, U, V, squarings, arg(0, 0));
Eigen::matrix_exp_computeUV<MatrixType>::run(arg_ref, U, V, squarings,
arg_ref(0, 0));
// Pade approximant is
// (U+V) / (-U+V)
MatrixType numer = U + V;
Expand Down
9 changes: 5 additions & 4 deletions stan/math/prim/fun/pseudo_eigenvalues.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
namespace stan {
namespace math {

template <typename T>
Eigen::Matrix<T, -1, -1> pseudo_eigenvalues(const Eigen::Matrix<T, -1, -1>& m) {
template <typename EigMat, require_eigen_matrix_base_t<EigMat>* = nullptr>
inline Eigen::Matrix<scalar_type_t<EigMat>, -1, -1> pseudo_eigenvalues(
EigMat&& m) {
check_nonzero_size("pseudo_eigenvalues", "m", m);
check_square("pseudo_eigenvalues", "m", m);

Eigen::EigenSolver<Eigen::Matrix<T, -1, -1>> solver(m);
Eigen::EigenSolver<Eigen::Matrix<scalar_type_t<EigMat>, -1, -1>> solver(
to_ref(std::forward<EigMat>(m)));
return solver.pseudoEigenvalueMatrix();
}

Expand Down
10 changes: 5 additions & 5 deletions stan/math/prim/fun/pseudo_eigenvectors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
namespace stan {
namespace math {

template <typename T>
Eigen::Matrix<T, -1, -1> pseudo_eigenvectors(
const Eigen::Matrix<T, -1, -1>& m) {
template <typename EigMat, require_eigen_matrix_base_t<EigMat>* = nullptr>
inline Eigen::Matrix<scalar_type_t<EigMat>, -1, -1> pseudo_eigenvectors(
EigMat&& m) {
check_nonzero_size("pseudo_eigenvectors", "m", m);
check_square("pseudo_eigenvectors", "m", m);

Eigen::EigenSolver<Eigen::Matrix<T, -1, -1>> solver(m);
Eigen::EigenSolver<Eigen::Matrix<scalar_type_t<EigMat>, -1, -1>> solver(
to_ref(std::forward<EigMat>(m)));
return solver.pseudoEigenvectors();
}

Expand Down
Loading